def G_paper(self, z, last_resolution, current_resolution, name='G_paper',
                reuse=False):
        """Build graph for generator.
        Returns: tensor of image

        """
        assert last_resolution in [4, 8, 16, 32, 64, 128, 256, 512]
        assert current_resolution == last_resolution * 2
        # Channels per block: halve as resolution doubles (8192 / resolution),
        # capped at MAX_C.
        get_cnum = lambda x: int(min(MAX_C, 2 ** (13 - np.log2(x))))

        x = z
        # with tf.variable_scope(name, reuse=(reuse or current_resolution!=8)):
        with tf.variable_scope(name, reuse=reuse):
            # [-1, 4, 4, 512]
            x = tf.reshape(x, [-1, 1, 1, 512])
            x = tf.layers.conv2d_transpose(
                x, 512, 4, 4, padding="same", activation=act, name='deconv_in')
            x = tf.layers.conv2d(
                x, 512, 3, padding='same', activation=act, name='conv_in')
        block_resolution = 4

        # with tf.variable_scope(name, reuse=True):
        with tf.variable_scope(name, reuse=reuse):
            for i in range(int(np.log2(current_resolution) - 3)):
                cnum = get_cnum(block_resolution)
                logger.info('Restore block, input resolution: {}, cnum: {}, '
                            'output resolution: {}.'.format(
                                block_resolution, cnum, block_resolution*2))
                x = resize(x, 2)
                block_resolution *= 2
                x = nn_block(x, cnum, name='block%s' % block_resolution)
            if current_resolution != 64:
                last_x = tf.layers.conv2d(
                    x, 3, 1, padding='same', name='%s_out' % block_resolution)

        with tf.variable_scope(name, reuse=reuse):
            cnum = get_cnum(block_resolution)
            logger.info('Add block, input resolution: {}, cnum: {}, '
                        'output resolution: {}.'.format(
                            block_resolution, cnum, block_resolution*2))
            x = resize(x, 2)
            block_resolution *= 2
            x = nn_block(x, cnum, name='block%s' % block_resolution)

            x = tf.layers.conv2d(
                x, 3, 1, padding='same', name='%s_out' % block_resolution)
            kt = progressive_kt('%s_kt' % block_resolution)

        if current_resolution != 64:
            x = kt * x + (1. - kt) * resize(last_x, 2)
        return x
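A note on the blend in the last step: kt ramps from 0 to 1 while the newly added block takes over from the upsampled RGB output of the previous resolution, the standard progressive-growing fade-in. Below is a minimal NumPy sketch of just that blend; nn_upsample_x2 is a hypothetical stand-in for the repository's resize helper.

import numpy as np

def nn_upsample_x2(img):
    # Nearest-neighbor x2 upsampling of an NHWC tensor,
    # standing in for resize(last_x, 2).
    return img.repeat(2, axis=1).repeat(2, axis=2)

rng = np.random.default_rng(0)
last_x = rng.standard_normal((1, 8, 8, 3))    # RGB head at the old resolution
x_new = rng.standard_normal((1, 16, 16, 3))   # RGB head of the new block

for kt in (0.0, 0.5, 1.0):                    # kt ramps up during training
    blended = kt * x_new + (1. - kt) * nn_upsample_x2(last_x)
    print(kt, blended.shape)                  # (1, 16, 16, 3)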
Example #2
def gated_deconv(x, cnum, name='upsample', padding='SAME', training=True):
    """Define deconv for generator: x2 nearest-neighbor resize followed by a
    gated convolution."""
    with tf.variable_scope(name):
        x = resize(x, func=tf.image.resize_nearest_neighbor)
        x = gated_conv(
            x, cnum, 3, 1, name=name+'_conv', padding=padding,
            training=training)
    return x
Example #3
 def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
     """
     """
     config.MAX_DELTA_HEIGHT = 0
     config.MAX_DELTA_WIDTH = 0
     if bbox is None:
         bbox = random_bbox(config)
     mask = bbox2mask(bbox, config, name=name + 'mask_c')
     batch_pos = batch_data / 127.5 - 1.
     edges = None
     batch_incomplete = batch_pos * (1. - mask)
     # inpaint
     x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                  mask,
                                                  config,
                                                  reuse=True,
                                                  training=False,
                                                  padding=config.PADDING)
     if config.PRETRAIN_COARSE_NETWORK:
         batch_predicted = x1
     else:
         batch_predicted = x2
     # apply mask and reconstruct
     batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                   mask)
     # global image visualization
     viz_img = [batch_pos, batch_incomplete, batch_complete]
     if offset_flow is not None:
         viz_img.append(
             resize(offset_flow,
                    scale=4,
                    func=tf.image.resize_nearest_neighbor))
     return batch_complete
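The two pixel-level steps above are easy to verify in isolation: batch_data / 127.5 - 1. maps uint8 images from [0, 255] to [-1, 1], and the final composite keeps the known pixels while filling only the hole (mask == 1). A self-contained NumPy sketch:

import numpy as np

batch_data = np.random.randint(0, 256, size=(1, 4, 4, 3)).astype(np.float32)
batch_pos = batch_data / 127.5 - 1.            # [0, 255] -> [-1, 1]

mask = np.zeros((1, 4, 4, 1), np.float32)
mask[:, 1:3, 1:3, :] = 1.                      # 1 marks the hole
batch_incomplete = batch_pos * (1. - mask)

batch_predicted = np.zeros_like(batch_pos)     # stand-in for x1 / x2
batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
# Known pixels are untouched; only the hole comes from the prediction.
assert np.allclose(batch_complete * (1. - mask), batch_pos * (1. - mask))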
Example #4
def gen_deconv(x, cnum, ksize, stride, rate, name='upsample', padding='same', training=True):
    """Define deconv for generator.
    The deconv is defined to be a x2 resize_nearest_neighbor operation with
    additional gen_conv operation.

    Args:
        x: Input.
        cnum: Channel number.
        ksize, stride, rate: Unused; kept so the signature matches gen_conv.
        name: Name of layers.
        training: If current graph is for training or inference, used for bn.

    Returns:
        tf.Tensor: output

    """
    # Just using these as options to keep the signature the same as gen_conv.
    assert ksize == 3
    assert stride == 1
    assert rate is None
    with tf.variable_scope(name):
        x = resize(x, func=tf.image.resize_nearest_neighbor)
        x, layer = gen_conv(
            x, cnum, 3, 1, name=name+'_conv', padding=padding,
            training=training)
    return x, layer
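The "x2 resize_nearest_neighbor" step in the docstring simply duplicates each pixel into a 2x2 block before the convolution, a common alternative to transposed convolutions that avoids checkerboard artifacts. A tiny NumPy illustration (resize_nearest_x2 is a hypothetical stand-in for the resize helper used above):

import numpy as np

def resize_nearest_x2(x):
    # x2 nearest-neighbor upsampling of an NHWC tensor.
    return x.repeat(2, axis=1).repeat(2, axis=2)

x = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)
print(resize_nearest_x2(x)[0, :, :, 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]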
Example #5
def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
    """Define deconv for generator.
    The deconv is defined to be a x2 resize_nearest_neighbor operation with
    additional gen_conv operation.

    Args:
        x: Input.
        cnum: Channel number.
        name: Name of layers.
        training: If current graph is for training or inference, used for bn.

    Returns:
        tf.Tensor: output

    """
    with tf.variable_scope(name):
        x = resize(x, func=tf.image.resize_nearest_neighbor)
        x = gen_conv(x,
                     cnum,
                     3,
                     1,
                     name=name + '_conv',
                     padding=padding,
                     training=training)
    return x
Example #6
    def build_infer_graph(self, batch_data, config, name='val'):

        config.MAX_DELTA_HEIGHT = 0
        config.MAX_DELTA_WIDTH = 0

        mask = bbox2mask(config, name=name + 'mask_c')
        batch_pos = batch_data / 127.5 - 1.
        edges = None
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=True,
                                                     training=False,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow,
                       scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
        return batch_complete
Example #7
 def build_infer_graph(self, batch_data, batch_mask, batch_guide, config, name='val'):
     """
     validation
     """
     config.MAX_DELTA_HEIGHT = 0
     config.MAX_DELTA_WIDTH = 0
     batch_pos = batch_data / 127.5 - 1.
     batch_incomplete = batch_pos*(1.-batch_mask)
     # inpaint
     x1, x2, offset_flow = self.build_inpaint_net(
         batch_incomplete, batch_mask, batch_guide, config, reuse=True,
         training=False, padding=config.PADDING)
     if config.PRETRAIN_COARSE_NETWORK:
         batch_predicted = x1
         logger.info('Set batch_predicted to x1.')
     else:
         batch_predicted = x2
         logger.info('Set batch_predicted to x2.')
     # apply mask and reconstruct
     batch_complete = batch_predicted*batch_mask + batch_incomplete*(1.-batch_mask)
     # global image visualization
     viz_img = [batch_pos, batch_incomplete, batch_complete]
     if offset_flow is not None:
         viz_img.append(
             resize(offset_flow, scale=4,
                    func=tf.image.resize_nearest_neighbor))
     images_summary(
         tf.concat(viz_img, axis=2),
         name+'_raw_incomplete_complete', config.VIZ_MAX_OUT)
     return batch_complete
Example #8
 def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
     """
     """
     config.MAX_DELTA_HEIGHT = 0
     config.MAX_DELTA_WIDTH = 0
     if bbox is None:
         bbox = random_bbox(config)
     mask = bbox2mask(bbox, config, name=name+'mask_c')
     batch_pos = batch_data / 127.5 - 1.
     edges = None
     batch_incomplete = batch_pos*(1.-mask)
     # inpaint
     x2, offset_flow = self.build_inpaint_net(
         batch_incomplete, mask, config, reuse=True,
         training=False, padding=config.PADDING)
     batch_predicted = x2
     logger.info('Set batch_predicted to x2.')
     # apply mask and reconstruct
     batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
     # global image visualization
     viz_img = [batch_pos, batch_incomplete, batch_complete]
     if offset_flow is not None:
         viz_img.append(
             resize(offset_flow, scale=4,
                    func=tf.image.resize_nearest_neighbor))
     images_summary(
         tf.concat(viz_img, axis=2),
         name+'_raw_incomplete_complete', config.VIZ_MAX_OUT)
     return batch_complete
Example #9
    def downsampleMask_graph(self,
                             FLAGS,
                             batch_data,
                             mask,
                             downsample_rate,
                             name='val'):
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)

        # mask = brush_stroke_mask(name='mask_c')
        # regular_mask = bbox2mask(bbox, name='mask_c')
        # irregular_mask = brush_stroke_mask(name='mask_c')
        # mask = tf.cast(
        #     tf.logical_or(
        #         tf.cast(irregular_mask, tf.bool),
        #         tf.cast(regular_mask, tf.bool),
        #     ),
        #     tf.float32
        # )

        # The input image is in fact still being normalized here; this needs
        # to change.
        # modified
        # batch_pos = batch_data / 127.5 - 1.
        # Mean of the kept pixels: the mask keeps 1 of every
        # downsample_rate**2 pixels, so rescale the global mean accordingly.
        mean = tf.reduce_mean(batch_data * (1. - mask), axis=[1, 2],
                              keepdims=True) * downsample_rate * downsample_rate
        batch_pos = batch_data / mean - 1.
        batch_incomplete = batch_pos * (1. - mask)

        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(xin,
                                                     mask,
                                                     reuse=True,
                                                     training=False,
                                                     padding=FLAGS.padding)
        batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        # global image visualization
        if FLAGS.guided:
            viz_img = [batch_pos, batch_incomplete + edge, batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4, func=tf.image.resize_bilinear))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', FLAGS.viz_max_out)
        return batch_complete
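The mean-based normalization above only makes sense if the mask keeps exactly one pixel per downsample_rate x downsample_rate block, so that rescaling the global mean by downsample_rate**2 recovers the mean of the kept pixels. A sketch under that assumption:

import numpy as np

rate = 2
img = np.full((1, 4, 4, 1), 100., np.float32)
mask = np.ones_like(img)
mask[:, ::rate, ::rate, :] = 0.                # 0 = kept pixel, 1 = dropped

mean = (img * (1. - mask)).mean(axis=(1, 2), keepdims=True) * rate ** 2
print(mean.squeeze())                          # 100.0: mean of kept pixels
batch_pos = img / mean - 1.                    # roughly centered around 0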
Example #10
 def build_infer_graph(self,
                       batch_data,
                       config,
                       bbox=None,
                       name='val',
                       exclusionmask=None):
     """
     """
     config.MAX_DELTA_HEIGHT = 0
     config.MAX_DELTA_WIDTH = 0
     if bbox is None:
         bbox = random_bbox(config)
     mask = bbox2mask(bbox, config, name=name + 'mask_c')
     batch_pos = batch_data / 127.5 - 1.
     edges = None
     batch_incomplete = batch_pos * (1. - mask)
     # inpaint
     x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                  mask,
                                                  config,
                                                  reuse=True,
                                                  training=False,
                                                  padding=config.PADDING)
     if config.PRETRAIN_COARSE_NETWORK:
         batch_predicted = x1
         logger.info('Set batch_predicted to x1.')
     else:
         batch_predicted = x2
         logger.info('Set batch_predicted to x2.')
     # apply mask and reconstruct
     batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                   mask)
     # global image visualization
      # 5-pixel-wide vertical separator between the visualization panels.
      img_size = [dim for dim in batch_incomplete.shape]
      img_size[2] = 5
      border = tf.zeros(tf.TensorShape(img_size))
     viz_img = [
         border, batch_pos, border, batch_incomplete, border,
         batch_complete, border
     ]
     if not config.PRETRAIN_COARSE_NETWORK:
         batch_complete_coarse = x1 * mask + batch_incomplete * (1. - mask)
         viz_img.append(batch_complete_coarse)
     if offset_flow is not None:
         scale = 2 << len(offset_flow)
         for flow in offset_flow:
             viz_img.append(
                 resize(flow,
                        scale=scale,
                        func=tf.image.resize_nearest_neighbor))
             viz_img.append(border)
             scale >>= 1
     images_summary(tf.concat(viz_img, axis=2),
                    name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
     return batch_complete
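The flow pyramid above is pasted back with a halving scale ladder: the first map is resized by 2 << len(offset_flow) and each subsequent one by half of that. For example, with three flow maps the scales are 16, 8 and 4:

offset_flow_levels = 3          # hypothetical pyramid depth
scale = 2 << offset_flow_levels
for _ in range(offset_flow_levels):
    print(scale)                # 16, 8, 4
    scale >>= 1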
Example #11
def resize_mask_like(mask, x):
    """Resize mask like shape of x.

    Args:
        mask: Original mask.
        x: To shape of x.

    Returns:
        tf.Tensor: resized mask

    """
    mask_resize = resize(
        mask, to_shape=x.get_shape().as_list()[1:3],
        func=tf.image.resize_nearest_neighbor)

    # Jaya's code: overrides the result above with a fixed 1/4 dynamic resize.
    mask_resize = resize(
        mask, scale=1./4, dynamic=True)

    return mask_resize
Example #12
    def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
        """
        """
        config.MAX_DELTA_HEIGHT = 0
        config.MAX_DELTA_WIDTH = 0
        if bbox is None:
            bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name=name + 'mask_c')
        batch_pos = batch_data / 127.5 - 1.

        edges = None
        batch_incomplete = batch_pos * (1. - mask)
        # inpaint
        if not config.ADD_GRADIENT_BRANCH:
            x1, x2, offset_flow = self.build_inpaint_net(
                batch_incomplete,
                mask,
                config,
                reuse=True,
                training=False,
                padding=config.PADDING)
        else:
            x1, x2, fake_gb, gb, offset_flow = self.build_inpaint_net(
                batch_incomplete,
                mask,
                config,
                reuse=True,
                training=False,
                padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        # apply mask and reconstruct
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        # global image visualization
        viz_img = [batch_pos, batch_incomplete, batch_complete]

        if config.ADD_GRADIENT_BRANCH:
            gb_gt = self.get_grad.get_gradient_tf(batch_pos)
            batch_complete_gb = fake_gb * mask + gb_gt * (1. - mask)
            viz_img.extend([gb_gt, batch_complete_gb])

        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow,
                       scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
        return batch_complete
Example #13
def resize_mask_like(mask, x):
    """Resize mask like shape of x.

    Args:
        mask: Original mask.
        x: To shape of x.

    Returns:
        tf.Tensor: resized mask

    """
    mask_resize = resize(mask,
                         to_shape=x.get_shape().as_list()[1:3],
                         func=tf.compat.v1.image.resize_nearest_neighbor)
    return mask_resize
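A NumPy stand-in makes the behavior concrete: the full-resolution hole mask is shrunk with nearest-neighbor sampling to the spatial size of a feature map, so it can gate attention at that scale. resize_mask_like_np below is a hypothetical re-implementation for illustration only:

import numpy as np

def resize_mask_like_np(mask, feat):
    # Nearest-neighbor resize of an NHWC mask to feat's spatial size.
    n, h, w, c = mask.shape
    th, tw = feat.shape[1:3]
    rows = np.arange(th) * h // th
    cols = np.arange(tw) * w // tw
    return mask[:, rows][:, :, cols]

mask = np.zeros((1, 256, 256, 1), np.float32)
mask[:, 96:160, 96:160, :] = 1.
feat = np.zeros((1, 64, 64, 128), np.float32)
print(resize_mask_like_np(mask, feat).shape)   # (1, 64, 64, 1)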
Example #14
    def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
        """
        """
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        if bbox is None:
            bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_pos = batch_data / 127.5 - 1.
        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=True,
            training=False, padding=FLAGS.padding)
        batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # global image visualization
        if FLAGS.guided:
            viz_img = [
                batch_pos,
                batch_incomplete + edge,
                batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_bilinear))
        images_summary(
            tf.concat(viz_img, axis=2),
            name+'_raw_incomplete_complete', FLAGS.viz_max_out)
        return batch_complete
Example #15
    def build_infer_graph(self,
                          batch_data,
                          batch_mask,
                          batch_guide,
                          config,
                          name='val'):
        """
        validation
        """
        batch_pos = batch_data / 127.5 - 1.
        if batch_mask is None:
            batch_mask = random_ff_mask(parts=8)
        else:
            pass

        batch_incomplete = batch_pos * (1. - batch_mask)
        ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
        batch_mask = ones_x * batch_mask
        batch_guide = ones_x
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     batch_mask,
                                                     batch_guide,
                                                     config,
                                                     reuse=True,
                                                     training=False,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        # apply mask and complete image
        batch_complete = batch_predicted * batch_mask + batch_incomplete * (
            1. - batch_mask)
        # global image visualization
        viz_img = [
            batch_pos, batch_incomplete, batch_predicted, batch_complete
        ]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow,
                       scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
        return batch_complete
Example #16
def resize_mask_like(mask, x):
    """Resize mask like shape of x.

    Args:
        mask: Original mask.
        x: To shape of x.

    Returns:
        tf.Tensor: resized mask

    """
    mask_resize = resize(
        mask, to_shape=x.get_shape().as_list()[1:3],
        func=tf.image.resize_nearest_neighbor)
    print("resized mask shape:", mask_resize.shape.dims)
    return mask_resize
Example #17
    def build_graph_with_losses(
            self, FLAGS, batch_data, training=True, summary=False,
            reuse=False):
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=reuse, training=training,
            padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # reconstruction (l1) losses on the coarse and refined outputs
        losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
        losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            if FLAGS.guided:
                viz_img = [
                    batch_pos,
                    batch_incomplete + edge,
                    batch_complete]
            else:
                viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_bilinear))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])], axis=3)
        if FLAGS.guided:
            # conditional GANs
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if FLAGS.gan == 'sngan':
            pos_neg = self.build_gan_discriminator(batch_pos_neg, training=training, reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
        else:
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        if summary:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
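For reference, a common hinge formulation behind a gan_hinge_loss-style helper: the discriminator pushes real scores above +1 and fake scores below -1, while the generator maximizes the scores of its fakes. A NumPy sketch; the repository's actual implementation may differ in details:

import numpy as np

def hinge_gan_loss(pos, neg):
    # pos: discriminator scores on real images, neg: scores on fakes.
    relu = lambda v: np.maximum(v, 0.)
    d_loss = np.mean(relu(1. - pos)) + np.mean(relu(1. + neg))
    g_loss = -np.mean(neg)
    return g_loss, d_loss

pos = np.array([1.5, 0.2])
neg = np.array([-0.3, 0.8])
print(hinge_gan_loss(pos, neg))                # (-0.25, 1.65)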
Example #18
    def build_graph_with_losses(self,
                                batch_data,
                                batch_mask,
                                batch_guide,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        batch_pos = batch_data / 127.5 - 1.
        batch_mask = random_ff_mask(parts=8)
        batch_incomplete = batch_pos * (1. - batch_mask)
        ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
        batch_mask = ones_x * batch_mask
        batch_guide = ones_x
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     batch_mask,
                                                     batch_guide,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * batch_mask + batch_incomplete * (
            1. - batch_mask)

        # Local patch is removed in the gated convolution. It's now AE + GAN loss
        # (https://github.com/JiahuiYu/generative_inpainting/issues/62)
        losses['ae_loss'] = config.COARSE_L1_ALPHA * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - batch_mask))
        losses['ae_loss'] += config.COARSE_L1_ALPHA * tf.reduce_mean(
            tf.abs(batch_pos - x2) * (1. - batch_mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)

        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [
                batch_pos, batch_incomplete, batch_predicted, batch_complete
            ]  # I have included the predicted image as well to see the reconstructed image.
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        batch_mask_all = tf.tile(batch_mask, [config.BATCH_SIZE * 2, 1, 1, 1])
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)

        if config.GAN_WITH_GUIDE:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(batch_guide, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # sn-pgan
        if config.GAN == 'sn_pgan':
            # sn path gan
            pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg,
                                                       training=training,
                                                       reuse=reuse)
            pos_global, neg_global = tf.split(pos_neg, 2)

            # SNPGAN Loss

            g_loss_global, d_loss_global = gan_sn_pgan_loss(
                pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = g_loss_global
            losses['d_loss'] = d_loss_global
        else:
            raise NotImplementedError('{} not implemented.'.format(config.GAN))

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['d_loss'], batch_predicted, name='d_loss')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_x2')

        losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
Example #19
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        batch_pos = batch_data / 127.5 - 1.

        mask = bbox2mask(config, name='mask_c')
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        local_patch_batch_pos = local_patch(batch_pos, mask)
        local_patch_batch_predicted = local_patch(batch_predicted, mask)
        local_patch_x1 = local_patch(x1, mask)
        local_patch_x2 = local_patch(x2, mask)
        local_patch_batch_complete = local_patch(batch_complete, mask)
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos -
                   local_patch_x1))  #*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos -
                       local_patch_x2))  #*spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)

        if config.GAN == 'snpatch_gan':
            pos_neg = self.build_SNGAN_discriminator(local_patch_batch_pos_neg,
                                                     training=training,
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            sn_gloss, sn_dloss = self.gan_hinge_loss(pos,
                                                     neg,
                                                     name="gan/hinge_loss")
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * sn_gloss
            losses['d_loss'] = sn_dloss
            interpolates = random_interpolates(local_patch_batch_pos,
                                               local_patch_batch_complete)
            dout = self.build_SNGAN_discriminator(interpolates, reuse=True)
            penalty = gradients_penalty(interpolates, dout, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(sn_gloss,
                                  batch_predicted,
                                  name='g_loss_local')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', sn_dloss)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
Example #20
    def build_graph_with_losses(self, batch_data, config, training=True,
                                summary=False, reuse=False):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        batch_incomplete = batch_pos*(1.-mask)
        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete, mask, config, reuse=reuse, training=training,
            padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2)*spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-mask))
        losses['ae_loss'] /= tf.reduce_mean(1.-mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate gan
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(local_patch_batch_pos_neg, batch_pos_neg, training=training, reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos, batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local, dout_local, mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global, dout_global, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
                gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global', penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
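The wgan_gp branch above combines a score difference with a gradient penalty that pulls the critic's gradient norm toward 1 at random interpolates between real and completed images. A toy NumPy sketch with a linear critic, whose gradient is constant everywhere, so the interpolates merely mirror the structure of random_interpolates:

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal(16)                    # linear critic D(x) = w . x
real = rng.standard_normal((8, 16))
fake = rng.standard_normal((8, 16))

eps = rng.uniform(size=(8, 1))
interp = eps * real + (1. - eps) * fake        # random_interpolates analogue
grad = np.tile(w, (8, 1))                      # dD/dx == w for a linear critic
penalty = np.mean((np.linalg.norm(grad, axis=1) - 1.) ** 2)

d_loss = np.mean(fake @ w) - np.mean(real @ w) + 10. * penalty  # lambda = 10
g_loss = -np.mean(fake @ w)
print(d_loss, g_loss)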
Example #21
def contextual_attention(f,
                         b,
                         mask=None,
                         ksize=3,
                         stride=1,
                         rate=1,
                         fuse_k=3,
                         softmax_scale=10.,
                         training=True,
                         fuse=True):
    """ Contextual attention layer implementation.

    Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        f: Input feature to match (foreground).
        b: Input feature for match (background).
        mask: Input mask for b, indicating patches not available.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
        rate: Dilation for matching.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference.

    Returns:
        tf.Tensor: output

    """
    # get shapes
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()
    # extract patches from background with stride and rate
    kernel = 2 * rate
    raw_w = tf.extract_image_patches(b, [1, kernel, kernel, 1],
                                     [1, rate * stride, rate * stride, 1],
                                     [1, 1, 1, 1],
                                     padding='SAME')
    raw_w = tf.reshape(raw_w,
                       [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # downscaling foreground option: downscaling both foreground and
    # background for matching and use original background for reconstruction.
    f = resize(f, scale=1. / rate, func=tf.image.resize_nearest_neighbor)
    b = resize(b,
               to_shape=[int(raw_int_bs[1] / rate),
                         int(raw_int_bs[2] / rate)],
               func=tf.image.resize_nearest_neighbor
               )  # https://github.com/tensorflow/tensorflow/issues/11651
    if mask is not None:
        mask = resize(mask,
                      scale=1. / rate,
                      func=tf.image.resize_nearest_neighbor)
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    f_groups = tf.split(f, int_fs[0], axis=0)
    # from t(H*W*C) to w(b*k*k*c*h*w)
    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    w = tf.extract_image_patches(b, [1, ksize, ksize, 1],
                                 [1, stride, stride, 1], [1, 1, 1, 1],
                                 padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # process mask
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(mask, [1, ksize, ksize, 1],
                                 [1, stride, stride, 1], [1, 1, 1, 1],
                                 padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    m = m[0]
    mm = tf.cast(
        tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.),
        tf.float32)
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # conv for compare
        wi = wi[0]
        wi_normed = wi / tf.maximum(
            tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")

        # conv implementation for fuse scores to encourage large patches
        if fuse:
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi,
                              fuse_weight,
                              strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi,
                              fuse_weight,
                              strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1] * bs[2]])

        # softmax to match
        yi *= mm  # mask
        yi = tf.nn.softmax(yi * scale, 3)
        yi *= mm  # mask

        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
        # deconv for patch pasting
        # 3.1 paste center
        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(yi,
                                    wi_center,
                                    tf.concat([[1], raw_fs[1:]], axis=0),
                                    strides=[1, rate, rate, 1]) / 4.
        y.append(yi)
        offsets.append(offset)
    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])
    # case1: visualize optical flow: minus current position
    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]),
                    [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]),
                    [bs[0], bs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    # to flow image
    flow = flow_to_image_tf(offsets)
    # # case2: visualize which pixels are attended
    # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
    return y, flow
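Stripped of the patch extraction and the fuse step, the core of contextual attention is: normalize the background patches, correlate them with foreground features, apply a scaled softmax, and reconstruct the foreground as an attention-weighted sum of background patches. A minimal NumPy sketch using 1x1 patches on flattened positions:

import numpy as np

rng = np.random.default_rng(0)
f = rng.standard_normal((32, 64))              # foreground: 32 positions, C=64
b = rng.standard_normal((100, 64))             # background: 100 positions

b_normed = b / np.maximum(np.linalg.norm(b, axis=1, keepdims=True), 1e-4)
scores = f @ b_normed.T                        # (32, 100) matching scores
# Scaled softmax over background positions (shifted for stability).
attn = np.exp(10. * (scores - scores.max(axis=1, keepdims=True)))
attn /= attn.sum(axis=1, keepdims=True)
y = attn @ b                                   # paste: weighted sum of patches
print(y.shape)                                 # (32, 64)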
Example #22
def resize_mask_like(mask, x):
    mask_resize = resize(mask,
                         to_shape=x.get_shape().as_list()[1:3],
                         func=tf.image.resize_nearest_neighbor)
    return mask_resize
Example #23
def contextual_attention(f,
                         b,
                         mask=None,
                         ksize=3,
                         stride=1,
                         rate=1,
                         fuse_k=3,
                         softmax_scale=10.,
                         training=True,
                         fuse=True):
    raw_fs = tf.shape(f)
    raw_int_fs = f.get_shape().as_list()
    raw_int_bs = b.get_shape().as_list()

    kernel = 2 * rate
    raw_w = tf.extract_image_patches(b, [1, kernel, kernel, 1],
                                     [1, rate * stride, rate * stride, 1],
                                     [1, 1, 1, 1],
                                     padding='SAME')
    raw_w = tf.reshape(raw_w,
                       [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])

    f = resize(f, scale=1. / rate, func=tf.image.resize_nearest_neighbor)
    b = resize(b,
               to_shape=[int(raw_int_bs[1] / rate),
                         int(raw_int_bs[2] / rate)],
               func=tf.image.resize_nearest_neighbor)
    if mask is not None:
        mask = resize(mask,
                      scale=1. / rate,
                      func=tf.image.resize_nearest_neighbor)
    fs = tf.shape(f)
    int_fs = f.get_shape().as_list()
    f_groups = tf.split(f, int_fs[0], axis=0)

    bs = tf.shape(b)
    int_bs = b.get_shape().as_list()
    w = tf.extract_image_patches(b, [1, ksize, ksize, 1],
                                 [1, stride, stride, 1], [1, 1, 1, 1],
                                 padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])

    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(mask, [1, ksize, ksize, 1],
                                 [1, stride, stride, 1], [1, 1, 1, 1],
                                 padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])
    m = m[0]
    mm = tf.cast(
        tf.equal(tf.reduce_mean(m, axis=[0, 1, 2], keep_dims=True), 0.),
        tf.float32)
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):

        wi = wi[0]
        wi_normed = wi / tf.maximum(
            tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0, 1, 2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1, 1, 1, 1], padding="SAME")

        if fuse:
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi,
                              fuse_weight,
                              strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1] * fs[2], bs[1] * bs[2], 1])
            yi = tf.nn.conv2d(yi,
                              fuse_weight,
                              strides=[1, 1, 1, 1],
                              padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1] * bs[2]])

        yi *= mm
        yi = tf.nn.softmax(yi * scale, 3)
        yi *= mm

        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)

        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(yi,
                                    wi_center,
                                    tf.concat([[1], raw_fs[1:]], axis=0),
                                    strides=[1, rate, rate, 1]) / 4.
        y.append(yi)
        offsets.append(offset)
    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])

    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]),
                    [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]),
                    [bs[0], bs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)

    flow = flow_to_image_tf(offsets)

    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_bilinear)
    return y, flow
Example #24
def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
                         fuse_k=3, softmax_scale=10., training=True, fuse=True):
    """ Contextual attention layer implementation.

    Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        f: Input feature to match (foreground).
        b: Input feature for match (background).
        mask: Input mask for b, indicating patches not available.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
        rate: Dilation for matching.
        softmax_scale: Scaled softmax for attention.
        training: Indicating if current graph is training or inference.

    Returns:
        tf.Tensor: output

    """
    # get shapes
    raw_fs = tf.shape(f)
    print("raw_fs",raw_fs.shape)
    
    raw_int_fs = f.get_shape().as_list()
    print("raw_int_fs",raw_int_fs)
    #foreground shape
    raw_int_bs = b.get_shape().as_list()
    print("raw_int_bs",raw_int_bs)
    #background shape
    '''
    raw_fs (4,)
    raw_int_fs [2, 64, 64, 128]
    raw_int_bs [2, 64, 64, 128]

    '''
    # extract patches from background with stride and rate
    kernel = 2*rate
    raw_w = tf.extract_image_patches(
        b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
    print("raw_w or extracted patches",raw_w)


    raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
    raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])
    #####background patches. These patches have to be the kernels
    print("transposed raw_w or extracted patches",raw_w.shape)
    # transpose to b*k*k*c*hw


    # downscaling foreground option: downscaling both foreground and
    # background for matching and use original background for reconstruction.

    ######foreground and back ground need to be downscaled because?
    ### both are downscaled in same shape

    f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    print("rdownscaled f",f.shape)
    b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651
    print("rdownscaled b",b.shape)
    '''
    rdownscaled f (2, 32, 32, 128)
    rdownscaled b (2, 32, 32, 128)
    '''

    #######
    if mask is not None:
        mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
    print("again resized mask",mask.shape)
    #again resized mask (1, 32, 32, 1) 

    #########after downscaling

    fs = tf.shape(f)
    print("fs:",fs)
    #int_fs: [2, 32, 32, 128]
    int_fs = f.get_shape().as_list()
    print("int_fs:",int_fs)
    f_groups = tf.split(f, int_fs[0], axis=0)
    print("splitted f_groups:",f_groups)

    # tf.split(value, num, axis) splits a tensor into num pieces along the
    # given axis. For example, for a data set x of size (10, 10),
    # tf.split(x, 2, 0) gives two tensors of size (5, 10), and
    # tf.split(x, 2, 1) gives two tensors of size (10, 5).
    # from t(H*W*C) to w(b*k*k*c*h*w)
    bs = tf.shape(b)
    print("bs:",bs)
    int_bs = b.get_shape().as_list()
    print("int_bs:",int_bs)

    # extract matching patches w from the downscaled background:
    # from b (B*H*W*C) to w (B*k*k*C*HW); example trace for this graph:
    #   w: (2, 32, 32, 1152) -> transposed: (2, 3, 3, 128, 1024)
    #   mask patches m: (1, 32, 32, 9) -> (3, 3, 1, 1024), mm: (1, 1, 1, 1024)
    #   w_groups / raw_w_groups per sample: (1, 3, 3, 128, 1024) / (1, 4, 4, 128, 1024)
    #   yi: (1, 32, 32, 1024); x_hallu: (2, 256, 256, 3)
    w = tf.extract_image_patches(
        b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
    w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
    w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    # process mask: mm flags background patches that contain no masked
    # (unavailable) pixels
    if mask is None:
        mask = tf.zeros([1, bs[1], bs[2], 1])
    m = tf.extract_image_patches(
        mask, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
    m = tf.reshape(m, [1, -1, ksize, ksize, 1])
    m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
    m = m[0]
    mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)
    w_groups = tf.split(w, int_bs[0], axis=0)
    raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
    y = []
    offsets = []
    k = fuse_k
    scale = softmax_scale
    fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
    for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
        # conv for compare: correlate the foreground with each background
        # patch, L2-normalizing every patch over its k*k*c elements
        wi = wi[0]
        wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
        yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")

        # conv implementation for fuse scores to encourage large patches
        if fuse:
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
            yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
            yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
            yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
            yi = tf.transpose(yi, [0, 2, 1, 4, 3])
        yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
        # softmax to match; mm zeroes scores for invalid background patches
        yi *= mm  # mask
        yi = tf.nn.softmax(yi*scale, 3)
        yi *= mm  # mask
        offset = tf.argmax(yi, axis=3, output_type=tf.int32)
        offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
        # deconv for patch pasting
        # 3.1 paste center
        wi_center = raw_wi[0]
        yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
        y.append(yi)
        offsets.append(offset)
    y = tf.concat(y, axis=0)
    y.set_shape(raw_int_fs)
    offsets = tf.concat(offsets, axis=0)
    offsets.set_shape(int_bs[:3] + [2])
    # case1: visualize optical flow: minus current position
    h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
    w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
    offsets = offsets - tf.concat([h_add, w_add], axis=3)
    # to flow image
    flow = flow_to_image_tf(offsets)
    # # case2: visualize which pixels are attended
    # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
    if rate != 1:
        flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
    return y, flow
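A minimal usage sketch, not part of the original example: it wires contextual_attention into a small TF1 graph, assuming the helpers it uses (resize, flow_to_image_tf) are in scope; shapes follow the example trace above.

# Hypothetical smoke test for contextual_attention (assumed setup, TF1).
import tensorflow as tf

f = tf.placeholder(tf.float32, [2, 64, 64, 128])   # foreground features
b = tf.placeholder(tf.float32, [2, 64, 64, 128])   # background features
mask = tf.placeholder(tf.float32, [1, 64, 64, 1])  # 1 marks missing pixels
y, flow = contextual_attention(f, b, mask, ksize=3, stride=1, rate=2)
# y: (2, 64, 64, 128) attended features; flow: attention offset image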
Beispiel #25
    def build_graph_with_losses(self,
                                batch_data,
                                batch_mask,
                                batch_guide,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        if batch_mask is None:
            batch_mask = random_ff_mask(config)
        batch_incomplete = batch_pos * (1. - batch_mask)
        ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
        batch_mask = ones_x * batch_mask
        batch_guide = ones_x  # guide input is currently overridden with ones
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     batch_mask,
                                                     batch_mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * batch_mask + batch_incomplete * (
            1. - batch_mask)

        # local patches
        local_patch_batch_pos = mask_patch(batch_pos, batch_mask)
        local_patch_batch_predicted = mask_patch(batch_predicted, batch_mask)
        local_patch_x1 = mask_patch(x1, batch_mask)
        local_patch_x2 = mask_patch(x2, batch_mask)
        local_patch_batch_complete = mask_patch(batch_complete, batch_mask)
        #local_patch_mask = mask_patch(mask, bbox)

        # local patch l1 loss hole+out same as partial convolution
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos -
                   local_patch_x1))  # *spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos -
                       local_patch_x2))  # *spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - batch_mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - batch_mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if config.MASKFROMFILE:
            batch_mask_all = tf.tile(batch_mask, [2, 1, 1, 1])
            #batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
        else:
            batch_mask_all = tf.tile(batch_mask,
                                     [config.BATCH_SIZE * 2, 1, 1, 1])
            batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)

        if config.GAN_WITH_GUIDE:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(batch_guide, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # sn-pgan with gradient penalty
        if config.GAN == 'sn_pgan':
            # sn path gan
            pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg,
                                                       training=training,
                                                       reuse=reuse)
            pos_global, neg_global = tf.split(pos_neg, 2)

            # wgan loss
            #g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
            g_loss_global, d_loss_global = gan_sn_pgan_loss(
                pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global
            losses['d_loss'] = d_loss_global
            # gp

            # Random Interpolate between true and false

            interpolates_global = random_interpolates(
                tf.concat([batch_pos, batch_mask], axis=3),
                tf.concat([batch_complete, batch_mask], axis=3))
            dout_global = self.build_sn_pgan_discriminator(interpolates_global,
                                                           reuse=True)

            # apply penalty
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=batch_mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty_global
            #losses['d_loss'] = losses['d_loss'] + losses['gp_loss']

            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_sn_pgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_sn_pgan_loss/gp_penalty_global',
                               penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
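Why ae_loss is divided by tf.reduce_mean(1. - batch_mask): the numerator averages over all pixels, so without the division the loss would shrink as the mask grows. A small numpy sketch of the normalization (names and values hypothetical):

# numpy sketch of the ae_loss normalization used above
import numpy as np

err = np.abs(np.random.randn(1, 4, 4, 3))  # stands in for |batch_pos - x1|
mask = np.zeros((1, 4, 4, 1))
mask[0, :2] = 1.                           # top half masked
raw = np.mean(err * (1. - mask))           # mean over ALL pixels
normed = raw / np.mean(1. - mask)          # mean over valid pixels only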
Beispiel #26
    def build_inpaint_net(self,
                          x,
                          mask,
                          config=None,
                          reuse=False,
                          training=True,
                          padding='SAME',
                          name='inpaint_net',
                          exclusionmask=None):
        """Inpaint network.

        Args:
            x: incomplete image, [-1, 1]
            mask: mask region {0, 1}
        Returns:
            [-1, 1] as predicted image
        """
        multires = config.MULTIRES
        xin = x
        offset_flow = None
        ones_x = tf.ones_like(x)[:, :, :, 0:1]
        x = tf.concat([x, ones_x, ones_x * mask], axis=3)
        hasmask = False  # TODO: set to (exclusionmask is not None) once exclusion masks are supported
        if hasmask:
            exclusionmask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                    tf.float32)
            #x = tf.concat([x, exclusionmask], axis=3)
        use_gating = config.GATING

        # two stage network
        cnum = 24 if use_gating else 32
        with tf.variable_scope(name, reuse=reuse), \
                arg_scope([gen_conv, gen_deconv],
                          training=training, padding=padding):
            # stage1
            x = gen_conv(x, cnum, 5, 1, name='conv1', gating=use_gating)
            x = gen_conv(x,
                         2 * cnum,
                         3,
                         2,
                         name='conv2_downsample',
                         gating=use_gating)
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv3', gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         2,
                         name='conv4_downsample',
                         gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv5', gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv6', gating=use_gating)
            mask_s = resize_mask_like(mask, x)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=2,
                         name='conv7_atrous',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=4,
                         name='conv8_atrous',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=8,
                         name='conv9_atrous',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=16,
                         name='conv10_atrous',
                         gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv11', gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='conv12', gating=use_gating)
            x = gen_deconv(x,
                           2 * cnum,
                           name='conv13_upsample',
                           gating=use_gating)
            x = gen_conv(x, 2 * cnum, 3, 1, name='conv14', gating=use_gating)
            x = gen_deconv(x, cnum, name='conv15_upsample', gating=use_gating)
            x = gen_conv(x, cnum // 2, 3, 1, name='conv16', gating=use_gating)
            x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
            x = tf.clip_by_value(x, -1., 1.)
            x_stage1 = x
            # return x_stage1, None, None

            # stage2, paste result as input
            # x = tf.stop_gradient(x)
            x = x * mask + xin * (1. - mask)
            x.set_shape(xin.get_shape().as_list())
            # conv branch
            xnow = tf.concat([x, ones_x, ones_x * mask], axis=3)
            #if hasmask:
            #    xnow = tf.concat([xnow, exclusionmask], axis=3)
            x = gen_conv(xnow, cnum, 5, 1, name='xconv1', gating=use_gating)
            x = gen_conv(x,
                         cnum,
                         3,
                         2,
                         name='xconv2_downsample',
                         gating=use_gating)
            x = gen_conv(x, 2 * cnum, 3, 1, name='xconv3', gating=use_gating)
            x = gen_conv(x,
                         2 * cnum,
                         3,
                         2,
                         name='xconv4_downsample',
                         gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv5', gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='xconv6', gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=2,
                         name='xconv7_atrous',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=4,
                         name='xconv8_atrous',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=8,
                         name='xconv9_atrous',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         rate=16,
                         name='xconv10_atrous',
                         gating=use_gating)
            x_hallu = x
            # attention branch
            x = gen_conv(xnow, cnum, 5, 1, name='pmconv1', gating=use_gating)
            x = gen_conv(x,
                         cnum,
                         3,
                         2,
                         name='pmconv2_downsample',
                         gating=use_gating)
            x = gen_conv(x, 2 * cnum, 3, 1, name='pmconv3', gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         2,
                         name='pmconv4_downsample',
                         gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv5', gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         1,
                         name='pmconv6',
                         activation=tf.nn.relu,
                         gating=use_gating)
            flows = []
            use_attentionmask = hasmask and config.ATTENTION_MASK
            if use_attentionmask:
                ex_mask_s = resize_mask_like(exclusionmask, x)
            if multires:
                # scale down the feature map, run contextual attention, then
                # scale up and paste the inpainted region back into the
                # original feature map
                logger.info('USING MULTIRES')
                logger.info('original x shape: ' + str(x.shape))
                logger.info('original mask shape: ' + str(mask_s.shape))
                x_multi = [x]
                mask_multi = [mask_s]
                if use_attentionmask:
                    exclusion_mask_multi = [ex_mask_s]
                for i in range(config.LEVELS - 1):
                    #x = gen_conv(x, 4*cnum, 3, 2, name='pyramid_downsample_'+str(i+1))
                    x = resize(x, scale=0.5)
                    x_multi.append(x)
                    mask_multi.append(resize_mask_like(mask_s, x))
                    if use_attentionmask:
                        exclusion_mask_multi.append(
                            resize_mask_like(ex_mask_s, x))
                        logger.info('exclusionmask shape: ' +
                                    str(exclusion_mask_multi[i + 1].shape))
                    logger.info('x shape: ' + str(x_multi[i + 1].shape))
                    logger.info('mask shape: ' + str(mask_multi[i + 1].shape))
                x_multi.reverse()
                mask_multi.reverse()
                if use_attentionmask:
                    exclusion_mask_multi.reverse()
                for i in range(config.LEVELS - 1):
                    if use_attentionmask:
                        totalmask = mask_multi[i] + exclusion_mask_multi[i]
                        logger.info('total mask shape: ' + str(totalmask.shape))
                    else:
                        totalmask = tf.tile(mask_multi[i],
                                            [config.BATCH_SIZE, 1, 1, 1])
                    x, flow = contextual_attention(x,
                                                   x,
                                                   totalmask,
                                                   ksize=config.PATCH_KSIZE,
                                                   stride=config.PATCH_STRIDE,
                                                   rate=config.PATCH_RATE)
                    #x, flow = contextual_attention(x, x, mask_multi[i], ksize=3, stride=1, rate=1)
                    flows.append(flow)
                    # TODO: look into using deconv instead of just upsampling
                    x = resize(x, scale=2)
                    x = x * mask_multi[i + 1] + x_multi[i + 1] * (
                        1. - mask_multi[i + 1])
                    logger.info('upsampled x shape: ' + str(x.shape))

            x, offset_flow = contextual_attention(
                x,
                x,
                tf.tile(mask_s, [config.BATCH_SIZE, 1, 1, 1])
                if not use_attentionmask else mask_s + ex_mask_s,
                ksize=config.PATCH_KSIZE,
                stride=config.PATCH_STRIDE,
                rate=config.PATCH_RATE)
            flows.append(offset_flow)
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv9', gating=use_gating)
            x = gen_conv(x, 4 * cnum, 3, 1, name='pmconv10', gating=use_gating)
            pm = x
            x = tf.concat([x_hallu, pm], axis=3)  #join branches together

            x = gen_conv(x,
                         4 * cnum,
                         3,
                         1,
                         name='allconv11',
                         gating=use_gating)
            x = gen_conv(x,
                         4 * cnum,
                         3,
                         1,
                         name='allconv12',
                         gating=use_gating)
            x = gen_deconv(x,
                           2 * cnum,
                           name='allconv13_upsample',
                           gating=use_gating)
            x = gen_conv(x,
                         2 * cnum,
                         3,
                         1,
                         name='allconv14',
                         gating=use_gating)
            x = gen_deconv(x,
                           cnum,
                           name='allconv15_upsample',
                           gating=use_gating)
            x = gen_conv(x,
                         cnum // 2,
                         3,
                         1,
                         name='allconv16',
                         gating=use_gating)
            x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
            x_stage2 = tf.clip_by_value(x, -1., 1.)
        return x_stage1, x_stage2, flows
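The composite step between the two stages ("paste result as input") keeps original pixels outside the hole and the coarse prediction inside it. A standalone numpy sketch of that operation (shapes and names hypothetical):

# numpy sketch of the stage-1 -> stage-2 paste: x*mask + xin*(1-mask)
import numpy as np

xin = np.random.uniform(-1, 1, (1, 8, 8, 3))       # incomplete input image
x_stage1 = np.random.uniform(-1, 1, (1, 8, 8, 3))  # coarse prediction
mask = np.zeros((1, 8, 8, 1))
mask[0, 2:6, 2:6] = 1.                             # hole region
x2_input = x_stage1 * mask + xin * (1. - mask)     # input seen by stage 2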
Beispiel #27
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False,
                                exclusionmask=None,
                                mask=None):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        use_local_patch = False
        if mask is None:
            bbox = random_bbox(config)
            mask = bbox2mask(bbox, config, name='mask_c')
            if config.GAN == 'wgan_gp':
                use_local_patch = True
        else:
            #bbox = (0, 0, config.IMG_SHAPES[0], config.IMG_SHAPES[1])
            mask = tf.cast(tf.less(0.5, mask[:, :, :, 0:1]), tf.float32)
            if config.INVERT_MASK:
                mask = 1 - mask

        batch_incomplete = batch_pos * (1. - mask)

        if exclusionmask is not None:
            if config.INVERT_EXCLUSIONMASK:
                loss_mask = tf.cast(tf.less(0.5, exclusionmask[:, :, :, 0:1]),
                                    tf.float32)  #keep white parts
            else:
                loss_mask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                    tf.float32)  #keep black parts
            batch_incomplete = batch_incomplete * loss_mask
            batch_pos = batch_pos * loss_mask

        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete,
            mask,
            config,
            reuse=reuse,
            training=training,
            padding=config.PADDING,
            exclusionmask=exclusionmask)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        if exclusionmask is not None:
            batch_complete = batch_complete * loss_mask

        l1_alpha = config.COARSE_L1_ALPHA
        # local patches
        if use_local_patch:
            local_patch_batch_pos = local_patch(batch_pos, bbox)
            local_patch_batch_predicted = local_patch(batch_predicted, bbox)
            local_patch_x1 = local_patch(x1, bbox)
            local_patch_x2 = local_patch(x2, bbox)
            local_patch_batch_complete = local_patch(batch_complete, bbox)
            local_patch_mask = local_patch(mask, bbox)

            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x1) *
                spatial_discounting_mask(config) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(local_patch_batch_pos - local_patch_x2) *
                    spatial_discounting_mask(config) *
                    (loss_mask if exclusionmask is not None else 1))

            losses['ae_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) * (1. - mask) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['ae_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) * (1. - mask) *
                    (loss_mask if exclusionmask is not None else 1))
            losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        else:
            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) *
                    (loss_mask if exclusionmask is not None else 1))

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            if use_local_patch:
                scalar_summary('losses/ae_loss', losses['ae_loss'])
            img_size = [dim for dim in batch_incomplete.shape]
            img_size[2] = 5
            border = tf.zeros(tf.TensorShape(img_size))
            viz_img = [
                batch_pos, border, batch_incomplete, border, batch_complete,
                border
            ]
            if not config.PRETRAIN_COARSE_NETWORK:
                batch_complete_coarse = x1 * mask + batch_incomplete * (1. -
                                                                        mask)
                viz_img.append(batch_complete_coarse)
                viz_img.append(border)
            if offset_flow is not None:
                scale = 2 << len(offset_flow)
                for flow in offset_flow:
                    viz_img.append(
                        resize(flow,
                               scale=scale,
                               func=tf.image.resize_nearest_neighbor))
                    viz_img.append(border)
                    scale >>= 1
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            if not use_local_patch:
                raise Exception('wgan_gp requires global and local patch')
            local_patch_batch_pos_neg = tf.concat(
                [local_patch_batch_pos, local_patch_batch_complete], 0)
            # separate discriminators for local and global patches
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses[
                'g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local,
                                  batch_predicted,
                                  name='g_loss_local')
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global',
                               penalty_global)
        elif config.GAN == 'sngan':
            if use_local_patch:
                raise Exception(
                    'sngan incompatible with global and local patch')
            pos_neg = self.build_sngan_discriminator(batch_pos_neg,
                                                     name='discriminator',
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss, batch_predicted, name='g_loss')
                scalar_summary('convergence/d_loss', losses['d_loss'])
        else:
            losses['g_loss'] = 0

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            if use_local_patch:
                gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
                gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS and use_local_patch:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
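The GAN branches rely on helpers (gan_wgan_loss, gan_hinge_loss, gradients_penalty) whose bodies are not shown in this example. A hedged sketch of the standard formulations they presumably implement, not necessarily their exact code:

# standard WGAN and hinge losses plus gradient penalty (sketch only)
import tensorflow as tf

def wgan_loss_sketch(pos, neg):
    g_loss = -tf.reduce_mean(neg)                       # generator pushes fakes up
    d_loss = tf.reduce_mean(neg) - tf.reduce_mean(pos)  # critic separates real/fake
    return g_loss, d_loss

def hinge_loss_sketch(pos, neg):
    d_loss = tf.reduce_mean(tf.nn.relu(1. - pos)) + tf.reduce_mean(tf.nn.relu(1. + neg))
    g_loss = -tf.reduce_mean(neg)
    return g_loss, d_loss

def gradient_penalty_sketch(interpolates, dout):
    # assumes 4-D image-like interpolates; penalizes critic gradient norms != 1
    grad = tf.gradients(dout, [interpolates])[0]
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-10)
    return tf.reduce_mean(tf.square(slopes - 1.))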