コード例 #1
0
ファイル: inpaint_model.py プロジェクト: Ir1d/tf-models
    def call(self, x, mask, training):
        """Inpaint network.

        Runs the coarse stage, pastes its output into the masked region of
        the input, then runs the refinement stage (a convolutional branch
        and a contextual-attention branch whose outputs are concatenated).

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            training: part of the call signature; this method does not
                pass it on to any of the sub-layers it invokes.
        Returns:
            Tuple ``(x_stage1, x_stage2, offset_flow)``: the tanh output of
            stage 1, the tanh output of stage 2 ([-1, 1] as predicted
            image), and the second value returned by
            ``contextual_attention``.
        """
        xin = x
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)

        # stage1
        x = self.s1_1(x)
        # resize_mask_like receives the s1_1 output as its reference tensor;
        # the result is only consumed by contextual_attention below.
        mask_s = resize_mask_like(mask, x)
        x = self.s1_2(x)
        x = tf.keras.activations.tanh(x)
        x_stage1 = x

        # stage2: paste result as input -- blend the stage-1 prediction
        # (weighted by mask) with the first three input channels
        # (weighted by 1 - mask).
        x = x * mask + xin[:, :, :, 0:3] * (1. - mask)
        x.set_shape(xin[:, :, :, 0:3].get_shape().as_list())

        # conv branch
        xnow = x
        x_hallu = self.s2(xnow)

        # attention branch
        x = self.attn(xnow)
        x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
        pm = self.attn_2(x)

        # final part: both branches are merged along the channel axis
        x_stage2 = self.final(tf.concat([x_hallu, pm], axis=3))
        x_stage2 = tf.keras.activations.tanh(x_stage2)

        return x_stage1, x_stage2, offset_flow
コード例 #2
0
    def build_inpaint_net(self, x, mask, reuse=False,
                          training=True, padding='SAME', name='inpaint_net'):
        """Build the two-stage (coarse, then refinement) inpainting graph.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            reuse: passed to ``tf.variable_scope``.
            training: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            padding: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            name: variable scope enclosing every layer.
        Returns:
            ``(x_stage1, x_stage2, offset_flow)``; both stage outputs pass
            through tanh ([-1, 1] as predicted image) and ``offset_flow`` is
            the second value returned by ``contextual_attention``.
        """
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        net = tf.concat([x, ones_x, ones_x*mask], axis=3)

        # Channel widths shared by both stages.
        base = 48
        half, double, quad = base//2, 2*base, 4*base

        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # ---- stage 1: coarse prediction ----
            net = gen_conv(net, base, 5, 1, name='conv1')
            net = gen_conv(net, double, 3, 2, name='conv2_downsample')
            net = gen_conv(net, double, 3, 1, name='conv3')
            net = gen_conv(net, quad, 3, 2, name='conv4_downsample')
            net = gen_conv(net, quad, 3, 1, name='conv5')
            net = gen_conv(net, quad, 3, 1, name='conv6')
            # The conv6 output is the reference tensor for the resized mask,
            # which is used later by the attention branch.
            mask_s = resize_mask_like(mask, net)
            for dilation, layer in ((2, 'conv7_atrous'),
                                    (4, 'conv8_atrous'),
                                    (8, 'conv9_atrous'),
                                    (16, 'conv10_atrous')):
                net = gen_conv(net, quad, 3, rate=dilation, name=layer)
            net = gen_conv(net, quad, 3, 1, name='conv11')
            net = gen_conv(net, quad, 3, 1, name='conv12')
            net = gen_deconv(net, double, name='conv13_upsample')
            net = gen_conv(net, double, 3, 1, name='conv14')
            net = gen_deconv(net, base, name='conv15_upsample')
            net = gen_conv(net, half, 3, 1, name='conv16')
            net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
            x_stage1 = tf.nn.tanh(net)

            # ---- stage 2: paste the coarse result into the input ----
            xnow = x_stage1*mask + x[:, :, :, 0:3]*(1.-mask)
            xnow.set_shape(x[:, :, :, 0:3].get_shape().as_list())

            # conv branch (fed the pasted image only, without mask channels)
            net = gen_conv(xnow, base, 5, 1, name='xconv1')
            net = gen_conv(net, base, 3, 2, name='xconv2_downsample')
            net = gen_conv(net, double, 3, 1, name='xconv3')
            net = gen_conv(net, double, 3, 2, name='xconv4_downsample')
            net = gen_conv(net, quad, 3, 1, name='xconv5')
            net = gen_conv(net, quad, 3, 1, name='xconv6')
            for dilation, layer in ((2, 'xconv7_atrous'),
                                    (4, 'xconv8_atrous'),
                                    (8, 'xconv9_atrous'),
                                    (16, 'xconv10_atrous')):
                net = gen_conv(net, quad, 3, rate=dilation, name=layer)
            x_hallu = net

            # attention branch
            net = gen_conv(xnow, base, 5, 1, name='pmconv1')
            net = gen_conv(net, base, 3, 2, name='pmconv2_downsample')
            net = gen_conv(net, double, 3, 1, name='pmconv3')
            net = gen_conv(net, quad, 3, 2, name='pmconv4_downsample')
            net = gen_conv(net, quad, 3, 1, name='pmconv5')
            net = gen_conv(net, quad, 3, 1, name='pmconv6',
                           activation=tf.nn.relu)
            net, offset_flow = contextual_attention(
                net, net, mask_s, 3, 1, rate=2)
            net = gen_conv(net, quad, 3, 1, name='pmconv9')
            pm = gen_conv(net, quad, 3, 1, name='pmconv10')

            # merge both branches along channels and decode
            net = tf.concat([x_hallu, pm], axis=3)
            net = gen_conv(net, quad, 3, 1, name='allconv11')
            net = gen_conv(net, quad, 3, 1, name='allconv12')
            net = gen_deconv(net, double, name='allconv13_upsample')
            net = gen_conv(net, double, 3, 1, name='allconv14')
            net = gen_deconv(net, base, name='allconv15_upsample')
            net = gen_conv(net, half, 3, 1, name='allconv16')
            net = gen_conv(net, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.nn.tanh(net)
        return x_stage1, x_stage2, offset_flow
コード例 #3
0
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net'):
        """Build the two-stage inpainting generator.

        Args:
            x: input image tensor; a ones channel and a mask channel are
                appended along axis 3 before the first layer.
            mask: mask tensor used to blend the stage-1 prediction with
                the input before stage 2.
            config: not used by this method.
            reuse: passed to ``tf.variable_scope``.
            training: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            padding: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            name: variable scope enclosing every layer.
        Returns:
            ``(x_stage1, x_stage2, offset_flow)``; both stage outputs are
            clipped to [-1, 1] and ``offset_flow`` is the second value
            returned by ``contextual_attention``.
        """
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        net = tf.concat([x, ones_x, ones_x * mask], axis=3)

        # Channel widths shared by both stages.
        width = 32
        w_half, w2, w4 = width // 2, 2 * width, 4 * width

        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # ---- stage 1 ----
            net = gen_conv(net, width, 5, 1, name='conv1')
            net = gen_conv(net, w2, 3, 2, name='conv2_downsample')
            net = gen_conv(net, w2, 3, 1, name='conv3')
            net = gen_conv(net, w4, 3, 2, name='conv4_downsample')
            net = gen_conv(net, w4, 3, 1, name='conv5')
            net = gen_conv(net, w4, 3, 1, name='conv6')
            # The conv6 output is the reference tensor for the resized mask.
            mask_s = resize_mask_like(mask, net)
            for step, scope_name in ((2, 'conv7_atrous'),
                                     (4, 'conv8_atrous'),
                                     (8, 'conv9_atrous'),
                                     (16, 'conv10_atrous')):
                net = gen_conv(net, w4, 3, rate=step, name=scope_name)
            net = gen_conv(net, w4, 3, 1, name='conv11')
            net = gen_conv(net, w4, 3, 1, name='conv12')
            net = gen_deconv(net, w2, name='conv13_upsample')
            net = gen_conv(net, w2, 3, 1, name='conv14')
            net = gen_deconv(net, width, name='conv15_upsample')
            net = gen_conv(net, w_half, 3, 1, name='conv16')
            net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
            x_stage1 = tf.clip_by_value(net, -1., 1.)

            # ---- stage 2: blend prediction with the input, re-attach mask
            # channels ----
            pasted = x_stage1 * (mask) + x * (1. - mask)
            pasted.set_shape(x.get_shape().as_list())
            xnow = tf.concat([pasted, ones_x, ones_x * mask], axis=3)

            # conv branch
            net = gen_conv(xnow, width, 5, 1, name='xconv1')
            net = gen_conv(net, width, 3, 2, name='xconv2_downsample')
            net = gen_conv(net, w2, 3, 1, name='xconv3')
            net = gen_conv(net, w2, 3, 2, name='xconv4_downsample')
            net = gen_conv(net, w4, 3, 1, name='xconv5')
            net = gen_conv(net, w4, 3, 1, name='xconv6')
            for step, scope_name in ((2, 'xconv7_atrous'),
                                     (4, 'xconv8_atrous'),
                                     (8, 'xconv9_atrous'),
                                     (16, 'xconv10_atrous')):
                net = gen_conv(net, w4, 3, rate=step, name=scope_name)
            x_hallu = net

            # attention branch
            net = gen_conv(xnow, width, 5, 1, name='pmconv1')
            net = gen_conv(net, width, 3, 2, name='pmconv2_downsample')
            net = gen_conv(net, w2, 3, 1, name='pmconv3')
            net = gen_conv(net, w4, 3, 2, name='pmconv4_downsample')
            net = gen_conv(net, w4, 3, 1, name='pmconv5')
            net = gen_conv(net, w4, 3, 1,
                           name='pmconv6', activation=tf.nn.relu)
            net, offset_flow = contextual_attention(
                net, net, mask_s, 3, 1, rate=2)
            net = gen_conv(net, w4, 3, 1, name='pmconv9')
            pm = gen_conv(net, w4, 3, 1, name='pmconv10')

            # merge and decode
            net = tf.concat([x_hallu, pm], axis=3)
            net = gen_conv(net, w4, 3, 1, name='allconv11')
            net = gen_conv(net, w4, 3, 1, name='allconv12')
            net = gen_deconv(net, w2, name='allconv13_upsample')
            net = gen_conv(net, w2, 3, 1, name='allconv14')
            net = gen_deconv(net, width, name='allconv15_upsample')
            net = gen_conv(net, w_half, 3, 1, name='allconv16')
            # The last layer calls tf.layers.conv2d directly rather than
            # gen_conv, so the arg_scope defaults above do not apply to it.
            net = tf.layers.conv2d(
                net, 3, 3, 1,
                activation=None,
                dilation_rate=1,
                padding='SAME',
                name='allconv17')
            x_stage2 = tf.clip_by_value(net, -1., 1.)
        return x_stage1, x_stage2, offset_flow
コード例 #4
0
    def build_inpaint_net(self, x, mask, config=None, reuse=False,
                          training=True, padding='SAME', name='inpaint_net'):
        """Build the two-stage inpainting graph, printing shapes on the way.

        The ``print`` calls run while this method executes (i.e. while the
        graph is being constructed) and report the static shapes of the
        intermediate tensors.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: not used by this method.
            reuse: passed to ``tf.variable_scope``.
            training: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            padding: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            name: variable scope enclosing every layer.
        Returns:
            ``(x_stage1, x_stage2, offset_flow)``; both stage outputs are
            clipped to [-1, 1] as predicted image, and ``offset_flow`` is
            the second value returned by ``contextual_attention``.
        """
        print("original_x", x.shape)
        print("original_mask", mask.shape)
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        print("ones_x", ones_x.shape)
        net = tf.concat([x, ones_x, ones_x*mask], axis=3)
        print("xxxx", net.shape)

        # Channel widths shared by both stages.
        unit = 32
        unit_half, unit2, unit4 = unit//2, 2*unit, 4*unit

        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # ---- stage 1 ----
            net = gen_conv(net, unit, 5, 1, name='conv1')
            print("conv1", net.shape)
            net = gen_conv(net, unit2, 3, 2, name='conv2_downsample')
            net = gen_conv(net, unit2, 3, 1, name='conv3')
            net = gen_conv(net, unit4, 3, 2, name='conv4_downsample')
            net = gen_conv(net, unit4, 3, 1, name='conv5')
            net = gen_conv(net, unit4, 3, 1, name='conv6')
            # The conv6 output is the reference tensor for the resized mask.
            mask_s = resize_mask_like(mask, net)
            for dilation, layer in ((2, 'conv7_atrous'),
                                    (4, 'conv8_atrous'),
                                    (8, 'conv9_atrous'),
                                    (16, 'conv10_atrous')):
                net = gen_conv(net, unit4, 3, rate=dilation, name=layer)
            net = gen_conv(net, unit4, 3, 1, name='conv11')
            net = gen_conv(net, unit4, 3, 1, name='conv12')
            net = gen_deconv(net, unit2, name='conv13_upsample')
            net = gen_conv(net, unit2, 3, 1, name='conv14')
            net = gen_deconv(net, unit, name='conv15_upsample')
            net = gen_conv(net, unit_half, 3, 1, name='conv16')
            print("conv16", net.shape)
            net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
            print("conv17", net.shape)
            # Original author's note on the shape here: (32 28 28 3)
            x_stage1 = tf.clip_by_value(net, -1., 1.)
            print("x_clip", x_stage1.shape)
            print("x_stage", x_stage1.shape)

            # ---- stage 2: paste result as input ----
            pasted = x_stage1*mask + x*(1.-mask)
            print("x_mask", pasted.shape)
            # NOTE (translated from the original comment): "there is a
            # problem here" -- the original author did not say what.
            pasted.set_shape(x.get_shape().as_list())

            # conv branch
            xnow = tf.concat([pasted, ones_x, ones_x*mask], axis=3)
            print('xnow', xnow)
            net = gen_conv(xnow, unit, 5, 1, name='xconv1')
            net = gen_conv(net, unit, 3, 2, name='xconv2_downsample')
            net = gen_conv(net, unit2, 3, 1, name='xconv3')
            net = gen_conv(net, unit2, 3, 2, name='xconv4_downsample')
            net = gen_conv(net, unit4, 3, 1, name='xconv5')
            net = gen_conv(net, unit4, 3, 1, name='xconv6')
            for dilation, layer in ((2, 'xconv7_atrous'),
                                    (4, 'xconv8_atrous'),
                                    (8, 'xconv9_atrous'),
                                    (16, 'xconv10_atrous')):
                net = gen_conv(net, unit4, 3, rate=dilation, name=layer)
            x_hallu = net

            # attention branch
            net = gen_conv(xnow, unit, 5, 1, name='pmconv1')
            net = gen_conv(net, unit, 3, 2, name='pmconv2_downsample')
            net = gen_conv(net, unit2, 3, 1, name='pmconv3')
            net = gen_conv(net, unit4, 3, 2, name='pmconv4_downsample')
            net = gen_conv(net, unit4, 3, 1, name='pmconv5')
            net = gen_conv(net, unit4, 3, 1, name='pmconv6',
                           activation=tf.nn.relu)
            net, offset_flow = contextual_attention(
                net, net, mask_s, 3, 1, rate=2)
            net = gen_conv(net, unit4, 3, 1, name='pmconv9')
            pm = gen_conv(net, unit4, 3, 1, name='pmconv10')

            # merge and decode
            net = tf.concat([x_hallu, pm], axis=3)
            net = gen_conv(net, unit4, 3, 1, name='allconv11')
            net = gen_conv(net, unit4, 3, 1, name='allconv12')
            net = gen_deconv(net, unit2, name='allconv13_upsample')
            net = gen_conv(net, unit2, 3, 1, name='allconv14')
            net = gen_deconv(net, unit, name='allconv15_upsample')
            net = gen_conv(net, unit_half, 3, 1, name='allconv16')
            net = gen_conv(net, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(net, -1., 1.)
        return x_stage1, x_stage2, offset_flow
コード例 #5
0
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net'):
        """Inpaint network with an optional gradient branch in stage 2.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: object whose ``ADD_GRADIENT_BRANCH`` attribute switches
                the gradient branch on. ``None`` (the default) is treated
                as "gradient branch off".
            reuse: passed to ``tf.variable_scope``.
            training: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            padding: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            name: variable scope enclosing every layer.
        Returns:
            Without the gradient branch:
            ``(x_stage1, x_stage2, offset_flow)``.
            With the gradient branch:
            ``(x_stage1, x_stage2, x_gb, grad_x_stage1, offset_flow)``.
            ``x_stage1``, ``x_stage2`` and ``x_gb`` are clipped to [-1, 1];
            ``grad_x_stage1`` is the output of
            ``self.get_grad.get_gradient_tf`` on the stage-2 input, and
            ``offset_flow`` is the second value returned by
            ``contextual_attention``.
        """
        xin = x
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)

        # two stage network
        cnum = 32
        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # stage1
            x = gen_conv(x, cnum, 5, 1, name='conv1')
            x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv6')
            # The conv6 output is the reference tensor for the resized mask,
            # which is used later by the attention branch.
            mask_s = resize_mask_like(mask, x)
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv12')
            x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv14')
            x = gen_deconv(x, cnum, name='conv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='conv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
            x = tf.clip_by_value(x, -1., 1.)
            x_stage1 = x

            # stage2, paste result as input
            x = x * mask + xin * (1. - mask)
            x_gb_in = x
            x.set_shape(xin.get_shape().as_list())

            # conv branch
            xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
            x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
            x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
            x_feat1 = x
            x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
            x_feat2 = x
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
            x_hallu = x

            # attention branch
            x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
            x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         1,
                         name='pmconv6',
                         activation=tf.nn.relu)
            x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
            pm = x

            # `config` defaults to None, so guard the attribute access
            # instead of raising AttributeError on the default call.
            use_gradient_branch = (config is not None
                                   and config.ADD_GRADIENT_BRANCH)

            # Feature maps merged before the final decoder.
            merged = [x_hallu, pm]

            # gradient branch
            if use_gradient_branch:
                grad_x_stage1 = self.get_grad.get_gradient_tf(x_gb_in)
                x = gen_conv(grad_x_stage1, cnum, 5, 1, name='gbconv1')
                x = gen_conv(x, 2 * cnum, 3, 2, name='gbconv2_downsample')
                x = gen_conv(x, 2 * cnum, 3, 1, name='gbconv3')
                # fuse in the inpainting features from the main network
                x = tf.concat([x, x_feat1], axis=3)
                # NOTE: the 'gbfushion*'/'gbfushiuon*' spellings are passed
                # to gen_conv as layer names and are kept unchanged.
                x = gen_conv(x, 2 * cnum, 3, 1, name='gbfushion1')

                x = gen_conv(x, 4 * cnum, 3, 2, name='gbconv4_downsample')
                x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv5')
                x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv6')
                x = tf.concat([x, x_feat2], axis=3)
                x = gen_conv(x, 4 * cnum, 3, 1, name='gbfushiuon2')

                x = gen_conv(x, 4 * cnum, 3, rate=2, name='gbconv7_atrous')
                x = gen_conv(x, 4 * cnum, 3, rate=4, name='gbconv8_atrous')
                x = gen_conv(x, 4 * cnum, 3, rate=8, name='gbconv9_atrous')
                x = gen_conv(x, 4 * cnum, 3, rate=16, name='gbconv10_atrous')
                # third fusion point: the conv-branch output (x_hallu)
                x = tf.concat([x, x_hallu], axis=3)
                x = gen_conv(x, 4 * cnum, 3, 1, name='gbfushiuon3')
                gb = x

                # decoder of the gradient branch itself
                x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv11')
                x = gen_conv(x, 4 * cnum, 3, 1, name='gbconv12')
                x = gen_deconv(x, 2 * cnum, name='gbconv13_upsample')
                x = gen_conv(x, 2 * cnum, 3, 1, name='gbconv14')
                x = gen_deconv(x, cnum, name='gbconv15_upsample')
                x = gen_conv(x, cnum // 2, 3, 1, name='gbconv16')
                x = gen_conv(x, 3, 3, 1, activation=None, name='gbconv17')
                x_gb = tf.clip_by_value(x, -1., 1.)

                merged.append(gb)

            # Final decoder, shared by both configurations; only the list
            # of concatenated feature maps differs.
            x = tf.concat(merged, axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
            x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
            x = gen_deconv(x, cnum, name='allconv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(x, -1., 1.)

            if use_gradient_branch:
                return x_stage1, x_stage2, x_gb, grad_x_stage1, offset_flow
            return x_stage1, x_stage2, offset_flow
コード例 #6
0
ファイル: inpaint_model.py プロジェクト: semuse25/SikMachUI
    def build_v2_inpaint_net(self, x, mask, config=None, reuse=False,
                          training=True, padding='SAME', name='inpaint_net'):
        """Build the two-stage inpainting graph from gated layers.

        Every layer is a ``gated_conv``/``gated_deconv`` except the last
        layer of each stage, which is a plain ``gen_conv`` without
        activation.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: not used by this method.
            reuse: passed to ``tf.variable_scope``.
            training: default for the conv helpers (via arg_scope).
            padding: default for the conv helpers (via arg_scope).
            name: variable scope enclosing every layer.
        Returns:
            ``(x_stage1, x_stage2, offset_flow)``; both stage outputs are
            clipped to [-1, 1] as predicted image, and ``offset_flow`` is
            the second value returned by ``contextual_attention``.
        """
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        net = tf.concat([x, ones_x, ones_x*mask], axis=3)

        # Base width. Original author's note: "reduce? 25% of weights
        # (according to paper)".
        ch = 24
        ch_half, ch2, ch4 = ch//2, 2*ch, 4*ch

        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gated_conv, gated_deconv, gen_conv],
                          training=training, padding=padding):
            # ---- stage 1 ----
            net = gated_conv(net, ch, 5, 1, name='conv1')
            net = gated_conv(net, ch2, 3, 2, name='conv2_downsample')
            net = gated_conv(net, ch2, 3, 1, name='conv3')
            net = gated_conv(net, ch4, 3, 2, name='conv4_downsample')
            net = gated_conv(net, ch4, 3, 1, name='conv5')
            net = gated_conv(net, ch4, 3, 1, name='conv6')
            # The conv6 output is the reference tensor for the resized mask.
            mask_s = resize_mask_like(mask, net)
            for dilation, layer in ((2, 'conv7_atrous'),
                                    (4, 'conv8_atrous'),
                                    (8, 'conv9_atrous'),
                                    (16, 'conv10_atrous')):
                net = gated_conv(net, ch4, 3, rate=dilation, name=layer)
            net = gated_conv(net, ch4, 3, 1, name='conv11')
            net = gated_conv(net, ch4, 3, 1, name='conv12')
            net = gated_deconv(net, ch2, name='conv13_upsample')
            net = gated_conv(net, ch2, 3, 1, name='conv14')
            net = gated_deconv(net, ch, name='conv15_upsample')
            net = gated_conv(net, ch_half, 3, 1, name='conv16')
            # last layer: no activation, plain conv rather than gated
            net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
            x_stage1 = tf.clip_by_value(net, -1., 1.)

            # ---- stage 2: paste result as input ----
            pasted = x_stage1*mask + x*(1.-mask)
            pasted.set_shape(x.get_shape().as_list())

            # conv branch
            xnow = tf.concat([pasted, ones_x, ones_x*mask], axis=3)
            net = gated_conv(xnow, ch, 5, 1, name='xconv1')
            net = gated_conv(net, ch, 3, 2, name='xconv2_downsample')
            net = gated_conv(net, ch2, 3, 1, name='xconv3')
            net = gated_conv(net, ch2, 3, 2, name='xconv4_downsample')
            net = gated_conv(net, ch4, 3, 1, name='xconv5')
            net = gated_conv(net, ch4, 3, 1, name='xconv6')
            for dilation, layer in ((2, 'xconv7_atrous'),
                                    (4, 'xconv8_atrous'),
                                    (8, 'xconv9_atrous'),
                                    (16, 'xconv10_atrous')):
                net = gated_conv(net, ch4, 3, rate=dilation, name=layer)
            x_hallu = net

            # attention branch
            net = gated_conv(xnow, ch, 5, 1, name='pmconv1')
            net = gated_conv(net, ch, 3, 2, name='pmconv2_downsample')
            net = gated_conv(net, ch2, 3, 1, name='pmconv3')
            net = gated_conv(net, ch4, 3, 2, name='pmconv4_downsample')
            net = gated_conv(net, ch4, 3, 1, name='pmconv5')
            net = gated_conv(net, ch4, 3, 1, name='pmconv6',
                             activation=tf.nn.relu)
            net, offset_flow = contextual_attention(
                net, net, mask_s, 3, 1, rate=2)
            net = gated_conv(net, ch4, 3, 1, name='pmconv9')
            pm = gated_conv(net, ch4, 3, 1, name='pmconv10')

            # merge and decode
            net = tf.concat([x_hallu, pm], axis=3)
            net = gated_conv(net, ch4, 3, 1, name='allconv11')
            net = gated_conv(net, ch4, 3, 1, name='allconv12')
            net = gated_deconv(net, ch2, name='allconv13_upsample')
            net = gated_conv(net, ch2, 3, 1, name='allconv14')
            net = gated_deconv(net, ch, name='allconv15_upsample')
            net = gated_conv(net, ch_half, 3, 1, name='allconv16')
            # last layer: no activation, plain conv rather than gated
            net = gen_conv(net, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(net, -1., 1.)
        return x_stage1, x_stage2, offset_flow
コード例 #7
0
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net'):
        """Build the two-stage inpainting graph with a split stage-1 middle.

        In stage 1 the conv5 output feeds two parallel chains of dilated
        convolutions (rates 1/2/4/8 and 2/4/8/16) whose outputs are
        concatenated along axis 3 before the decoder.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: not used by this method.
            reuse: passed to ``tf.variable_scope``.
            training: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            padding: default for ``gen_conv``/``gen_deconv`` (via arg_scope).
            name: variable scope enclosing every layer.
        Returns:
            ``(x_stage1, x_stage2, offset_flow)``; both stage outputs are
            clipped to [-1, 1] as predicted image, and ``offset_flow`` is
            the second value returned by ``contextual_attention``.
        """
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        net = tf.concat([x, ones_x, ones_x * mask], axis=3)

        # Original note (translated): cnum is the number of output
        # dimensions.
        feat = 32
        feat_half, feat2, feat4 = feat // 2, 2 * feat, 4 * feat

        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # ---- stage 1 encoder (this variant has no 'conv6' layer) ----
            net = gen_conv(net, feat, 5, 1, name='conv1')
            net = gen_conv(net, feat2, 3, 2, name='conv2_downsample')
            net = gen_conv(net, feat2, 3, 1, name='conv3')
            net = gen_conv(net, feat4, 3, 2, name='conv4_downsample')
            net = gen_conv(net, feat4, 3, 1, name='conv5')
            # The conv5 output is the reference tensor for the resized mask.
            mask_s = resize_mask_like(mask, net)

            # Parallel convolution chains -- start. (The original comment
            # translates as "parallel strided convolution"; the layers
            # below vary `rate`, not stride.)
            x_first = net
            first_top = x_first
            for dilation, layer in ((1, 'conv7_atrous_top'),
                                    (2, 'conv8_atrous_top'),
                                    (4, 'conv9_atrous_top'),
                                    (8, 'conv10_atrous_top')):
                first_top = gen_conv(
                    first_top, feat4, 3, rate=dilation, name=layer)
            first_bottom = x_first
            for dilation, layer in ((2, 'conv7_atrous_bottom'),
                                    (4, 'conv8_atrous_bottom'),
                                    (8, 'conv9_atrous_bottom'),
                                    (16, 'conv10_atrous_bottom')):
                first_bottom = gen_conv(
                    first_bottom, feat4, 3, rate=dilation, name=layer)
            net = tf.concat([first_top, first_bottom], 3)
            # Parallel convolution chains -- end.

            # ---- stage 1 decoder ----
            net = gen_conv(net, feat4, 3, 1, name='conv11')
            net = gen_conv(net, feat4, 3, 1, name='conv12')
            net = gen_deconv(net, feat2, name='conv13_upsample')
            net = gen_conv(net, feat2, 3, 1, name='conv14')
            net = gen_deconv(net, feat, name='conv15_upsample')
            net = gen_conv(net, feat_half, 3, 1, name='conv16')
            net = gen_conv(net, 3, 3, 1, activation=None, name='conv17')
            x_stage1 = tf.clip_by_value(net, -1., 1.)

            # ---- stage 2: paste result as input ----
            pasted = x_stage1 * mask + x * (1. - mask)
            pasted.set_shape(x.get_shape().as_list())

            # conv branch
            xnow = tf.concat([pasted, ones_x, ones_x * mask], axis=3)
            net = gen_conv(xnow, feat, 5, 1, name='xconv1')
            net = gen_conv(net, feat, 3, 2, name='xconv2_downsample')
            net = gen_conv(net, feat2, 3, 1, name='xconv3')
            net = gen_conv(net, feat2, 3, 2, name='xconv4_downsample')
            net = gen_conv(net, feat4, 3, 1, name='xconv5')
            net = gen_conv(net, feat4, 3, 1, name='xconv6')
            for dilation, layer in ((2, 'xconv7_atrous'),
                                    (4, 'xconv8_atrous'),
                                    (8, 'xconv9_atrous'),
                                    (16, 'xconv10_atrous')):
                net = gen_conv(net, feat4, 3, rate=dilation, name=layer)
            x_hallu = net

            # attention branch
            net = gen_conv(xnow, feat, 5, 1, name='pmconv1')
            net = gen_conv(net, feat, 3, 2, name='pmconv2_downsample')
            net = gen_conv(net, feat2, 3, 1, name='pmconv3')
            net = gen_conv(net, feat4, 3, 2, name='pmconv4_downsample')
            net = gen_conv(net, feat4, 3, 1, name='pmconv5')
            net = gen_conv(net, feat4, 3, 1,
                           name='pmconv6', activation=tf.nn.relu)
            net, offset_flow = contextual_attention(
                net, net, mask_s, 3, 1, rate=2)
            net = gen_conv(net, feat4, 3, 1, name='pmconv9')
            pm = gen_conv(net, feat4, 3, 1, name='pmconv10')

            # merge and decode
            net = tf.concat([x_hallu, pm], axis=3)
            net = gen_conv(net, feat4, 3, 1, name='allconv11')
            net = gen_conv(net, feat4, 3, 1, name='allconv12')
            net = gen_deconv(net, feat2, name='allconv13_upsample')
            net = gen_conv(net, feat2, 3, 1, name='allconv14')
            net = gen_deconv(net, feat, name='allconv15_upsample')
            net = gen_conv(net, feat_half, 3, 1, name='allconv16')
            net = gen_conv(net, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(net, -1., 1.)
        return x_stage1, x_stage2, offset_flow
コード例 #8
0
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net',
                          exclusionmask=None):
        """Build the two-stage inpainting generator graph.

        Stage 1 predicts an image from the masked input. Stage 2 pastes that
        prediction into the hole and runs two parallel branches (a dilated
        convolution branch and a contextual-attention branch) whose outputs
        are concatenated and decoded. When ``config.MULTIRES`` is truthy the
        attention is additionally run over a pyramid of down-scaled feature
        maps before the final full-resolution attention pass.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: settings object. Despite the ``None`` default it is
                required: MULTIRES, GATING, BATCH_SIZE, PATCH_KSIZE,
                PATCH_STRIDE and PATCH_RATE are read, plus LEVELS when
                MULTIRES is truthy.
            reuse: forwarded to ``tf.variable_scope``.
            training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
            padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
            name: name of the variable scope holding all layers.
            exclusionmask: currently ignored because ``hasmask`` is
                hard-coded to False (see the TODO in the body).
        Returns:
            Tuple ``(x_stage1, x_stage2, flows)``. Both stage outputs are
            clipped to [-1, 1]; ``flows`` is the list of flow tensors returned
            by every ``contextual_attention`` call, the last entry coming
            from the final (non-pyramid) pass.
        """
        multires = config.MULTIRES
        use_gating = config.GATING

        def conv(inputs, *args, **kwargs):
            # gen_conv with this network's gating setting applied.
            return gen_conv(inputs, *args, gating=use_gating, **kwargs)

        def deconv(inputs, *args, **kwargs):
            # gen_deconv with this network's gating setting applied.
            return gen_deconv(inputs, *args, gating=use_gating, **kwargs)

        xin = x
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)
        hasmask = False  # TODO: should become `exclusionmask is not None`
        if hasmask:
            exclusionmask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                    tf.float32)

        # two stage network
        cnum = 24 if use_gating else 32
        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # stage1
            x = conv(x, cnum, 5, 1, name='conv1')
            x = conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
            x = conv(x, 2 * cnum, 3, 1, name='conv3')
            x = conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
            x = conv(x, 4 * cnum, 3, 1, name='conv5')
            x = conv(x, 4 * cnum, 3, 1, name='conv6')
            # Mask resized to the spatial size of the bottleneck features;
            # reused by the stage-2 attention branch.
            mask_s = resize_mask_like(mask, x)
            x = conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
            x = conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
            x = conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
            x = conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
            x = conv(x, 4 * cnum, 3, 1, name='conv11')
            x = conv(x, 4 * cnum, 3, 1, name='conv12')
            x = deconv(x, 2 * cnum, name='conv13_upsample')
            x = conv(x, 2 * cnum, 3, 1, name='conv14')
            x = deconv(x, cnum, name='conv15_upsample')
            x = conv(x, cnum // 2, 3, 1, name='conv16')
            # The output layer is deliberately called without `gating`.
            x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
            x_stage1 = tf.clip_by_value(x, -1., 1.)

            # stage2, paste result as input: keep known pixels from the
            # input and take the stage-1 prediction inside the hole.
            x = x_stage1 * mask + xin * (1. - mask)
            x.set_shape(xin.get_shape().as_list())

            # conv branch
            xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
            x = conv(xnow, cnum, 5, 1, name='xconv1')
            x = conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = conv(x, 2 * cnum, 3, 1, name='xconv3')
            x = conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
            x = conv(x, 4 * cnum, 3, 1, name='xconv5')
            x = conv(x, 4 * cnum, 3, 1, name='xconv6')
            x = conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
            x = conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
            x = conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
            x = conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
            x_hallu = x

            # attention branch
            x = conv(xnow, cnum, 5, 1, name='pmconv1')
            x = conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = conv(x, 2 * cnum, 3, 1, name='pmconv3')
            x = conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
            x = conv(x, 4 * cnum, 3, 1, name='pmconv5')
            x = conv(x, 4 * cnum, 3, 1, name='pmconv6',
                     activation=tf.nn.relu)

            flows = []
            use_attentionmask = hasmask and config.ATTENTION_MASK
            if use_attentionmask:
                ex_mask_s = resize_mask_like(exclusionmask, x)

            if multires:
                # Scale the feature map down, run contextual attention at the
                # smallest scale, then repeatedly upsample and paste the
                # attended hole region into the next larger feature map.
                logger.info('USING MULTIRES')
                logger.info('original x shape: ' + str(x.shape))
                logger.info('original mask shape: ' + str(mask_s.shape))
                x_multi = [x]
                mask_multi = [mask_s]
                if use_attentionmask:
                    exclusion_mask_multi = [ex_mask_s]
                for i in range(config.LEVELS - 1):
                    x = resize(x, scale=0.5)
                    x_multi.append(x)
                    mask_multi.append(resize_mask_like(mask_s, x))
                    if use_attentionmask:
                        exclusion_mask_multi.append(
                            resize_mask_like(ex_mask_s, x))
                        logger.info('exclusionmask shape: ' +
                                    str(exclusion_mask_multi[i + 1].shape))
                    logger.info('x shape: ' + str(x_multi[i + 1].shape))
                    logger.info('mask shape: ' + str(mask_multi[i + 1].shape))

                # Index 0 is now the smallest scale; `x` already holds it.
                x_multi.reverse()
                mask_multi.reverse()
                if use_attentionmask:
                    exclusion_mask_multi.reverse()

                for i in range(config.LEVELS - 1):
                    if use_attentionmask:
                        totalmask = mask_multi[i] + exclusion_mask_multi[i]
                        # Was a bare print(); routed through the logger like
                        # every other diagnostic in this method.
                        logger.info('total mask shape: ' +
                                    str(totalmask.shape))
                    else:
                        totalmask = tf.tile(mask_multi[i],
                                            [config.BATCH_SIZE, 1, 1, 1])
                    x, flow = contextual_attention(x,
                                                   x,
                                                   totalmask,
                                                   ksize=config.PATCH_KSIZE,
                                                   stride=config.PATCH_STRIDE,
                                                   rate=config.PATCH_RATE)
                    flows.append(flow)
                    # TODO: look into using deconv instead of just upsampling
                    x = resize(x, scale=2)
                    x = x * mask_multi[i + 1] + x_multi[i + 1] * (
                        1. - mask_multi[i + 1])
                    logger.info('upsampled x shape: ' + str(x.shape))

            # Final attention pass at the resolution of the pmconv6 features.
            if use_attentionmask:
                final_mask = mask_s + ex_mask_s
            else:
                final_mask = tf.tile(mask_s, [config.BATCH_SIZE, 1, 1, 1])
            x, offset_flow = contextual_attention(x,
                                                  x,
                                                  final_mask,
                                                  ksize=config.PATCH_KSIZE,
                                                  stride=config.PATCH_STRIDE,
                                                  rate=config.PATCH_RATE)
            flows.append(offset_flow)
            x = conv(x, 4 * cnum, 3, 1, name='pmconv9')
            x = conv(x, 4 * cnum, 3, 1, name='pmconv10')
            pm = x

            # join branches together
            x = tf.concat([x_hallu, pm], axis=3)
            x = conv(x, 4 * cnum, 3, 1, name='allconv11')
            x = conv(x, 4 * cnum, 3, 1, name='allconv12')
            x = deconv(x, 2 * cnum, name='allconv13_upsample')
            x = conv(x, 2 * cnum, 3, 1, name='allconv14')
            x = deconv(x, cnum, name='allconv15_upsample')
            x = conv(x, cnum // 2, 3, 1, name='allconv16')
            # As in stage 1, the output layer is called without `gating`.
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(x, -1., 1.)
        return x_stage1, x_stage2, flows
コード例 #9
0
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net'):
        """Build the two-stage inpainting generator with ablation switches.

        The convolution op is ``gen_gated_conv`` when
        ``config.GATED_CONVOLUTIONS`` is truthy and ``gen_conv`` otherwise.
        ``config.NO_HALLUC`` / ``config.NO_ATTENTION`` zero out the output of
        the stage-2 convolution branch / attention branch respectively
        before the two are concatenated.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: settings object. Despite the ``None`` default it is
                required: GATED_CONVOLUTIONS, NO_HALLUC and NO_ATTENTION are
                always read.
            reuse: forwarded to ``tf.variable_scope``.
            training: forwarded to the conv / deconv ops via ``arg_scope``.
            padding: forwarded to the conv / deconv ops via ``arg_scope``.
            name: name of the variable scope holding all layers.
        Returns:
            Tuple ``(x_stage1, x_stage2, offset_flow)``; both stage outputs
            are clipped to [-1, 1] and ``offset_flow`` is the second value
            returned by ``contextual_attention``.
        """
        xin = x
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)

        used_conv = gen_gated_conv if config.GATED_CONVOLUTIONS else gen_conv

        # two stage network
        cnum = 32
        with tf.variable_scope(name, reuse=reuse), \
             arg_scope([used_conv, gen_deconv],
                       training=training, padding=padding):
            # stage1
            x = used_conv(x, cnum, 5, 1, name='conv1')
            x = used_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
            x = used_conv(x, 2 * cnum, 3, 1, name='conv3')
            x = used_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
            x = used_conv(x, 4 * cnum, 3, 1, name='conv5')
            x = used_conv(x, 4 * cnum, 3, 1, name='conv6')
            # Mask resized to the bottleneck feature size, used later by the
            # attention branch.
            mask_s = resize_mask_like(mask, x)
            x = used_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
            x = used_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
            x = used_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
            x = used_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
            x = used_conv(x, 4 * cnum, 3, 1, name='conv11')
            x = used_conv(x, 4 * cnum, 3, 1, name='conv12')
            x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
            x = used_conv(x, 2 * cnum, 3, 1, name='conv14')
            x = gen_deconv(x, cnum, name='conv15_upsample')
            x = used_conv(x, cnum // 2, 3, 1, name='conv16')
            x = used_conv(x, 3, 3, 1, activation=None, name='conv17')
            x = tf.clip_by_value(x, -1., 1.)
            x_stage1 = x

            # stage2, paste result as input: known pixels come from the
            # input, hole pixels from the stage-1 prediction.
            x = x * mask + xin * (1. - mask)
            x.set_shape(xin.get_shape().as_list())

            # conv branch
            xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
            x = used_conv(xnow, cnum, 5, 1, name='xconv1')
            x = used_conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = used_conv(x, 2 * cnum, 3, 1, name='xconv3')
            x = used_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
            x = used_conv(x, 4 * cnum, 3, 1, name='xconv5')
            x = used_conv(x, 4 * cnum, 3, 1, name='xconv6')
            x = used_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
            x = used_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
            x = used_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
            x = used_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')

            if config.NO_HALLUC:
                # Turn off the hallucination pathway for the ablation study
                # via element-wise multiplication by zero. zeros_like is used
                # instead of tf.zeros(shape=x.shape) so this also works when
                # the static shape is only partially defined.
                x_hallu = tf.multiply(x, tf.zeros_like(x))
            else:
                x_hallu = x

            # attention branch
            x = used_conv(xnow, cnum, 5, 1, name='pmconv1')
            x = used_conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = used_conv(x, 2 * cnum, 3, 1, name='pmconv3')
            x = used_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
            x = used_conv(x, 4 * cnum, 3, 1, name='pmconv5')
            x = used_conv(x,
                          4 * cnum,
                          3,
                          1,
                          name='pmconv6',
                          activation=tf.nn.relu)
            x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
            x = used_conv(x, 4 * cnum, 3, 1, name='pmconv9')
            x = used_conv(x, 4 * cnum, 3, 1, name='pmconv10')

            if config.NO_ATTENTION:
                # Turn off the attention pathway for the ablation study
                # (same zeroing approach as above).
                pm = tf.multiply(x, tf.zeros_like(x))
            else:
                pm = x

            x = tf.concat([x_hallu, pm], axis=3)
            x = used_conv(x, 4 * cnum, 3, 1, name='allconv11')
            x = used_conv(x, 4 * cnum, 3, 1, name='allconv12')
            x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
            x = used_conv(x, 2 * cnum, 3, 1, name='allconv14')
            x = gen_deconv(x, cnum, name='allconv15_upsample')
            x = used_conv(x, cnum // 2, 3, 1, name='allconv16')
            x = used_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(x, -1., 1.)
        return x_stage1, x_stage2, offset_flow
コード例 #10
0
    def __call__(self, x, mask, return_offset=False):
        """Run the two-stage inpainting generator.

        Args:
            x: input image batch; concatenated with ``mask`` along axis 1.
            mask: mask batch. When it has a single channel on axis 1 ("no
                edge image"), an all-zero second channel shaped like
                ``x[:, :1]`` is appended. Only its first channel is used for
                blending and for the attention mask.
            return_offset: forwarded to ``contextual_attention`` as
                ``return_flow``.

        Returns:
            Tuple ``(x_stage1, x_stage2, offset_flow)``; both stage outputs
            are clipped to [-1, 1] and ``offset_flow`` is the second value
            returned by ``contextual_attention``.
        """

        def apply_layers(h, layer_names):
            # Feed `h` through the named child links, in order.
            for layer_name in layer_names:
                h = getattr(self, layer_name)(h)
            return h

        xin = x
        if mask.shape[1] == 1:  # no edge image
            mask = F.concat([mask, self.xp.zeros_like(x[:, :1])])
        hole = mask[:, :1]

        # stage1
        x = F.concat([x, mask], axis=1)
        x = apply_layers(x, (
            'conv1', 'conv2_downsample', 'conv3',
            'conv4_downsample', 'conv5', 'conv6',
        ))
        # Hole mask resized to the bottleneck features, for the attention
        # step in stage 2.
        mask_s = resize_mask_like(hole, x)
        x = apply_layers(x, (
            'conv7_atrous', 'conv8_atrous', 'conv9_atrous', 'conv10_atrous',
            'conv11', 'conv12', 'conv13_upsample', 'conv14',
            'conv15_upsample', 'conv16', 'conv17',
        ))
        x_stage1 = F.clip(x, -1., 1.)

        # stage2, paste result as input: prediction inside the hole,
        # original pixels elsewhere.
        x = x_stage1 * hole + xin * (1. - hole)
        xnow = F.concat([x, mask], axis=1)

        # conv branch
        x_hallu = apply_layers(xnow, (
            'xconv1', 'xconv2_downsample', 'xconv3',
            'xconv4_downsample', 'xconv5', 'xconv6',
            'xconv7_atrous', 'xconv8_atrous',
            'xconv9_atrous', 'xconv10_atrous',
        ))

        # attention branch
        x = apply_layers(xnow, (
            'pmconv1', 'pmconv2_downsample', 'pmconv3',
            'pmconv4_downsample', 'pmconv5', 'pmconv6',
        ))
        x, offset_flow = contextual_attention(
            x, x, mask_s, 3, 1, rate=2, return_flow=return_offset)
        x = apply_layers(x, ('pmconv9', 'pmconv10'))

        # merge both branches and decode
        x = F.concat([x_hallu, x], axis=1)
        x = apply_layers(x, (
            'allconv11', 'allconv12', 'allconv13_upsample', 'allconv14',
            'allconv15_upsample', 'allconv16', 'allconv17',
        ))

        x_stage2 = F.clip(x, -1., 1.)
        return x_stage1, x_stage2, offset_flow
コード例 #11
0
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net'):
        """Build the two-stage inpainting generator with a SURF-fed
        attention branch.

        Stage 2's attention branch does not see the RGB composite directly:
        it receives the grayscale composite, the output of ``surf_conv`` on
        that grayscale image rescaled with ``(g + 1) * 127.5``, and the two
        mask channels. Attention itself is done by ``surface_attention``.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            config: not used by this method.
            reuse: not used; the variable scope is always opened with
                ``tf.AUTO_REUSE``.
            training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
            padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
            name: name of the variable scope holding all layers.
        Returns:
            Tuple ``(x_stage1, x_stage2, offset_flow)``; both stage outputs
            are clipped to [-1, 1] and ``offset_flow`` is the second value
            returned by ``surface_attention``.
        """
        xin = x
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)

        # two stage network
        cnum = 32
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # stage1
            # NOTE(review): the original author left a note here about
            # building the input with a map function over RGB + disparity;
            # nothing in this method implements that - confirm intent.
            x = gen_conv(x, cnum, 5, 1, name='conv1')
            x = gen_conv(x, 2 * cnum, 3, 2, name='conv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='conv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv6')
            # Mask resized to the bottleneck feature size, used later by the
            # attention branch.
            mask_s = resize_mask_like(mask, x)
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='conv7_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='conv8_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='conv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='conv10_atrous')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv12')
            x = gen_deconv(x, 2 * cnum, name='conv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv14')
            x = gen_deconv(x, cnum, name='conv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='conv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
            x = tf.clip_by_value(x, -1., 1.)
            x_stage1 = x

            # stage2, paste result as input: known pixels come from the
            # input, hole pixels from the stage-1 prediction.
            x = x * mask + xin * (1. - mask)
            x.set_shape(xin.get_shape().as_list())

            # Inputs of the two branches.
            xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
            gray = tf.image.rgb_to_grayscale(x)
            x_surf = surf_conv((gray + 1.) * 127.5)
            xnow_surf = tf.concat(
                [gray, x_surf, ones_x, ones_x * mask], axis=3)

            # conv branch
            x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
            x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
            x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
            x_hallu = x

            # attention branch (fed with the SURF-augmented input)
            x = gen_conv(xnow_surf, cnum, 5, 1, name='pmconv1')
            x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         1,
                         name='pmconv6',
                         activation=tf.nn.relu)

            x, offset_flow = surface_attention(x, x, mask_s, 3, 1, rate=2)

            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
            pm = x

            # merge both branches and decode
            x = tf.concat([x_hallu, pm], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
            x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
            x = gen_deconv(x, cnum, name='allconv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(x, -1., 1.)

        return x_stage1, x_stage2, offset_flow
    def build_inpaint_net(self,
                          x,
                          x2,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net'):
        """Build the refinement (stage-2) network on a two-image composite.

        The stage-1 network is disabled in this variant. Instead of a
        predicted image, the masked region is filled with pixels taken from
        ``x2``, and the composite is refined by a dilated convolution branch
        plus a contextual-attention branch.

        Several shape diagnostics are written to stdout with ``print`` while
        the graph is being built.

        Args:
            x: incomplete image, [-1, 1]
            x2: image supplying the pixels inside the masked region.
            mask: mask region {0, 1}
            config: not used by this method.
            reuse: forwarded to ``tf.variable_scope``.
            training: forwarded to gen_conv / gen_deconv via ``arg_scope``.
            padding: forwarded to gen_conv / gen_deconv via ``arg_scope``.
            name: name of the variable scope holding all layers.
        Returns:
            Tuple ``(x_stage1, x_stage2, offset_flow)``. ``x_stage1`` is the
            composite ``x2 * mask + x * (1 - mask)`` (not a network output),
            ``x_stage2`` is the refined image clipped to [-1, 1], and
            ``offset_flow`` is the second value returned by
            ``contextual_attention``.
        """
        xin = x
        x2in = x2

        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        print("ones_x size", (ones_x.shape))
        # With stage 1 disabled, this 5-channel tensor only feeds the
        # diagnostic print below.
        stage1_input = tf.concat([x, ones_x, ones_x * mask], axis=3)
        print("x_input", (stage1_input.shape.dims))

        cnum = 32
        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                         training=training, padding=padding):
            # stage2, paste result as input: x2 inside the hole, x elsewhere.
            x = x2in * mask + xin * (1. - mask)
            x.set_shape(xin.get_shape().as_list())
            x_stage1 = x
            print("x input_stage2", x.shape.dims)

            # conv branch
            xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
            print("xnow input_stage2", xnow.shape)
            x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
            x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
            x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv7_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv8_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv10_atrous')
            x_hallu = x
            # Mask resized to the conv-branch output size, for the attention
            # step below.
            mask_s = resize_mask_like(mask, x)
            print("Here is the mASK_S", mask_s)
            print("x_hallu or conv layer output", (x.shape.dims))

            # attention branch; takes the same input (xnow) as the conv
            # branch
            x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
            x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         1,
                         name='pmconv6',
                         activation=tf.nn.relu)
            print("x input to context layer", x.shape.dims)

            x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
            print("x output to context layer", x.shape.dims)

            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
            pm = x

            # merge both branches and decode
            x = tf.concat([x_hallu, pm], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
            x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
            x = gen_deconv(x, cnum, name='allconv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(x, -1., 1.)
            print("x_stage2", (x.shape.dims))
            print(
                '*******************\n***************build model  network***************\n************************'
            )
        return x_stage1, x_stage2, offset_flow
コード例 #13
0
    def build_inpaint_net(self, x, mask, reuse=False,
                          training=True, padding='SAME', name='inpaint_net'):
        """Build the inpainting generator graph.

        Layout, in the order the layers are created:

        1. an encoder ('conv1'..'conv10', 'convd') whose intermediate
           outputs are kept as skip connections;
        2. three residual refinement blocks ('r1*', 'r2*', 'r3*'), each fed
           with the concatenation of the running features and one skip
           connection, with a contextual attention step after the first;
        3. a two-branch sub-network ('1xconv*' conv branch, '1pmconv*'
           attention branch, merged by '1allconv*');
        4. a last refinement block ('r4*') that produces the stage-1 image;
        5. a second two-branch sub-network ('xconv*', 'pmconv*',
           'allconv*') applied to the stage-1 image pasted into the input,
           producing the stage-2 image.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
            reuse: forwarded to ``tf.variable_scope``.
            training: forwarded to every ``gen_conv`` / ``gen_deconv`` call
                through ``arg_scope``.
            padding: forwarded to every ``gen_conv`` / ``gen_deconv`` call
                through ``arg_scope`` (original author's note: 'SAME' pads
                with zeros automatically).
            name: name of the variable scope all layers are created in.

        Returns:
            Tuple ``(x_stage1, x_stage2, offset_flow)``:
            ``x_stage1`` is the tanh output of 'r4conv7', ``x_stage2`` is
            the tanh output of 'allconv17' ([-1, 1] as predicted image),
            and ``offset_flow`` is the second value returned by the
            contextual attention call of the stage-2 attention branch.
        """
        xin = x
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        # Concatenate along the last axis: image, a ones channel, the mask.
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)

        cnum = 48
        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # ---- encoder -------------------------------------------------
            # NOTE(review): the 4th positional argument of gen_conv looks
            # like the stride (layers passing 2 are the downsampling ones)
            # -- confirm against gen_conv's definition.
            x = gen_conv(x, cnum, 5, 1, name='conv1')
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv2')
            conv2 = x
            x = gen_conv(x, 4 * cnum, 5, 2, name='conv3')
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv4')
            conv4 = x
            x = gen_conv(x, 8 * cnum, 3, 2, name='conv5')
            x = gen_conv(x, 8 * cnum, 3, 1, name='conv6')
            conv6 = x
            # Resize the mask to match the 'conv6' feature map; this same
            # mask_s is reused by all three contextual attention calls.
            # NOTE(original author): whether this is the right place to
            # take the resized mask is still to be confirmed.
            mask_s = resize_mask_like(mask, x)
            x = gen_conv(x, 8 * cnum, 3, 2, name='conv7')
            x = gen_conv(x, 8 * cnum, 3, 1, name='conv8')
            conv8 = x
            x = gen_conv(x, 16 * cnum, 3, 2, name='conv9')
            x = gen_conv(x, 16 * cnum, 3, 1, name='conv10')
            x = gen_deconv(x, 8 * cnum, name='convd')
            convd = x

            # ---- refinenet1 (skip: conv8) ----------------------------------
            x = tf.concat([convd, conv8], axis=3)
            x = gen_conv(x, 8 * cnum, 3, 1, name='r1conv1')
            x = gen_conv(x, 8 * cnum, 3, 1, name='r1conv2')
            x = gen_conv(x, 8 * cnum, 3, 1, name='r1conv3')
            r1conv3 = x
            x = gen_conv(x, 8 * cnum, 3, 1, name='r1conv4')
            x = gen_conv(x, 8 * cnum, 3, 1, name='r1conv5')
            x = tf.add(x, r1conv3)  # residual connection
            x = gen_conv(x, 8 * cnum, 3, 1, name='r1conv6')
            x = gen_deconv(x, 4 * cnum, name='r1convd')

            # Only the attended features are used here; the flow returned
            # by this call is not part of the method's result.
            x, _ = contextual_attention(x, x, mask_s, 3, 1, rate=2)

            # ---- refinenet2 (skip: conv6) ----------------------------------
            x = tf.concat([x, conv6], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='r2conv1')
            x = gen_conv(x, 4 * cnum, 3, 1, name='r2conv2')
            x = gen_conv(x, 4 * cnum, 3, 1, name='r2conv3')
            r2conv3 = x
            x = gen_conv(x, 4 * cnum, 3, 1, name='r2conv4')
            x = gen_conv(x, 4 * cnum, 3, 1, name='r2conv5')
            x = tf.add(x, r2conv3)  # residual connection
            x = gen_conv(x, 4 * cnum, 3, 1, name='r2conv6')
            x = gen_deconv(x, 2 * cnum, name='r2convd')

            # ---- refinenet3 (skip: conv4) ----------------------------------
            x = tf.concat([x, conv4], axis=3)
            x = gen_conv(x, 2 * cnum, 3, 1, name='r3conv1')
            x = gen_conv(x, 2 * cnum, 3, 1, name='r3conv2')
            x = gen_conv(x, 2 * cnum, 3, 1, name='r3conv3')
            r3conv3 = x
            x = gen_conv(x, 2 * cnum, 3, 1, name='r3conv4')
            x = gen_conv(x, 2 * cnum, 3, 1, name='r3conv5')
            x = tf.add(x, r3conv3)  # residual connection
            x = gen_conv(x, 2 * cnum, 3, 1, name='r3conv6')
            x = gen_deconv(x, cnum, name='r3convd')

            # ---- first two-branch sub-network ('1x*', '1pm*', '1all*') ----
            # Unlike stage 2 below, the features are fed in directly: they
            # are neither pasted into the input image nor concatenated with
            # the ones/mask channels.
            xnow = x
            # conv branch
            x = gen_conv(xnow, cnum, 5, 1, name='1xconv1')
            x = gen_conv(x, cnum, 3, 2, name='1xconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='1xconv3')
            x = gen_conv(x, 2 * cnum, 3, 2, name='1xconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='1xconv5')
            xconv5 = x
            x = gen_conv(x, 4 * cnum, 3, 1, name='1xconv6')
            x = gen_conv(x, 4 * cnum, 3, 1, name='1xconv7')
            x = tf.add(x, xconv5)  # residual connection
            x = gen_conv(x, 4 * cnum, 3, 1, name='1xconv8')
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='1xconv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='1xconv10_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='1xconv11_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='1xconv12_atrous')
            x_hallu = x
            # attention branch
            x = gen_conv(xnow, cnum, 5, 1, name='1pmconv1')
            x = gen_conv(x, cnum, 3, 2, name='1pmconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='1pmconv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='1pmconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='1pmconv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='1pmconv6',
                         activation=tf.nn.relu)
            # As above, the flow of this intermediate attention call is
            # not returned.
            x, _ = contextual_attention(x, x, mask_s, 3, 1, rate=2)
            x = gen_conv(x, 4 * cnum, 3, 1, name='1pmconv9')
            x = gen_conv(x, 4 * cnum, 3, 1, name='1pmconv10')
            pm = x
            # merge both branches
            x = tf.concat([x_hallu, pm], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='1allconv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='1allconv12')
            x = gen_deconv(x, 2 * cnum, name='1allconv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='1allconv14')
            x = gen_deconv(x, cnum, name='1allconv15_upsample')
            # No 'allconv16'/'allconv17' style projection here: the
            # upsampled features go through tanh and on to refinenet4.
            x = tf.nn.tanh(x)

            # ---- refinenet4 (skip: conv2) -> stage-1 image -----------------
            x = tf.concat([x, conv2], axis=3)
            x = gen_conv(x, cnum, 3, 1, name='r4conv1')
            x = gen_conv(x, cnum, 3, 1, name='r4conv2')
            x = gen_conv(x, cnum, 3, 1, name='r4conv3')
            r4conv3 = x
            x = gen_conv(x, cnum, 3, 1, name='r4conv4')
            x = gen_conv(x, cnum, 3, 1, name='r4conv5')
            x = tf.add(x, r4conv3)  # residual connection
            x = gen_conv(x, cnum, 3, 1, name='r4conv6')
            x = gen_conv(x, 3, 3, 1, activation=None, name='r4conv7')
            x = tf.nn.tanh(x)
            x_stage1 = x

            # ---- stage 2: paste the stage-1 result as input ----------------
            # Take the prediction where mask is 1 and the first three input
            # channels where mask is 0.
            x = x * mask + xin[:, :, :, 0:3] * (1. - mask)
            x.set_shape(xin[:, :, :, 0:3].get_shape().as_list())
            xnow = x
            # conv branch
            x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
            x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3')
            x = gen_conv(x, 2 * cnum, 3, 2, name='xconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5')
            xconv5 = x
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6')
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv7')
            x = tf.add(x, xconv5)  # residual connection
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv8')
            x = gen_conv(x, 4 * cnum, 3, rate=2, name='xconv9_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=4, name='xconv10_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=8, name='xconv11_atrous')
            x = gen_conv(x, 4 * cnum, 3, rate=16, name='xconv12_atrous')
            x_hallu = x
            # attention branch
            x = gen_conv(xnow, cnum, 5, 1, name='pmconv1')
            x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3')
            x = gen_conv(x, 4 * cnum, 3, 2, name='pmconv4_downsample')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv6',
                         activation=tf.nn.relu)
            # This is the only attention call whose flow is returned.
            x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2)
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9')
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10')
            pm = x
            # merge both branches
            x = tf.concat([x_hallu, pm], axis=3)
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv11')
            x = gen_conv(x, 4 * cnum, 3, 1, name='allconv12')
            x = gen_deconv(x, 2 * cnum, name='allconv13_upsample')
            x = gen_conv(x, 2 * cnum, 3, 1, name='allconv14')
            x = gen_deconv(x, cnum, name='allconv15_upsample')
            x = gen_conv(x, cnum // 2, 3, 1, name='allconv16')
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x = tf.nn.tanh(x)
            x_stage2 = x
        return x_stage1, x_stage2, offset_flow