# Example #1
    def build_graph_with_losses(
            self, FLAGS, batch_data, training=True, summary=False,
            reuse=False):
        """Build the training graph: inpainting net, SN-GAN discriminator,
        reconstruction (L1) loss and hinge GAN losses.

        Args:
            FLAGS: options object (guided, edge_threshold, padding, gan,
                gan_with_mask, loss alphas, batch_size, viz_max_out, ...).
            batch_data: image batch in [0, 255]; when FLAGS.guided it is a
                tuple of (images, edge maps).
            training: bool, build sub-networks in training mode.
            summary: bool, additionally emit scalar/image/gradient summaries.
            reuse: bool, reuse existing variable scopes.

        Returns:
            (g_vars, d_vars, losses): generator trainable variables,
            discriminator trainable variables, and a dict containing
            'ae_loss', 'g_loss', 'd_loss'.
        """
        if FLAGS.guided:
            batch_data, edge = batch_data
            # Keep one channel of the guidance map and binarize it
            # (input assumed to be in [0, 255]).
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        # Normalize images from [0, 255] to [-1, 1].
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        # Union of the rectangular mask and the free-form brush-stroke mask.
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        # Zero out the masked region of the input.
        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            # Only the edges inside the hole guide the generator.
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # x1: coarse result, x2: refined result.
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=reuse, training=training,
            padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches
        # L1 reconstruction loss over both stages, weighted by l1_loss_alpha.
        losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
        losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            if FLAGS.guided:
                viz_img = [
                    batch_pos,
                    batch_incomplete + edge,
                    batch_complete]
            else:
                viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                # Upsample the attention-flow visualization to image size.
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_bilinear))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

        # gan
        # Stack real (pos) and completed (neg) batches along the batch dim.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            # Give the discriminator the mask as an extra input channel.
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])], axis=3)
        if FLAGS.guided:
            # conditional GANs
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if FLAGS.gan == 'sngan':
            pos_neg = self.build_gan_discriminator(batch_pos_neg, training=training, reuse=reuse)
            # First half of the batch is real, second half is fake.
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
        else:
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        if summary:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        # Total generator objective: weighted GAN loss (+ reconstruction).
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
    def build_graph_with_losses(self, batch_data, config, training=True,
                                summary=False, reuse=False):
        """Build the training graph with WGAN-GP losses using a global
        critic on the full image and a local critic on the hole patch.

        Args:
            batch_data: image batch in [0, 255].
            config: config object (PADDING, GAN, BATCH_SIZE, loss alphas,
                PRETRAIN_COARSE_NETWORK, GAN_WITH_MASK, VIZ_MAX_OUT, ...).
            training: bool, build sub-networks in training mode.
            summary: bool, additionally emit scalar/image/gradient summaries.
            reuse: bool, reuse existing variable scopes.

        Returns:
            (g_vars, d_vars, losses): generator variables, discriminator
            variables, and a dict with 'l1_loss', 'ae_loss', 'g_loss',
            'd_loss' and 'gp_loss'.
        """
        # Normalize images from [0, 255] to [-1, 1].
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        batch_incomplete = batch_pos*(1.-mask)
        # x1: coarse result, x2: refined result.
        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete, mask, config, reuse=reuse, training=training,
            padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            # While pretraining, only the coarse stage output is used.
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches
        # Crop the bbox region of every relevant tensor for the local critic.
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        l1_alpha = config.COARSE_L1_ALPHA
        # Spatially-discounted L1 loss inside the hole (coarse stage).
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2)*spatial_discounting_mask(config))
        # Plain L1 loss outside the hole, normalized by the unmasked area.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-mask))
        losses['ae_loss'] /= tf.reduce_mean(1.-mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                # Upsample the attention-flow visualization to image size.
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)

        # gan
        # Stack real (pos) and completed (neg) batches along the batch dim.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            # Give the critic the mask as an extra input channel.
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # seperate gan
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(local_patch_batch_pos_neg, batch_pos_neg, training=training, reuse=reuse)
            # First half of each batch is real, second half is fake.
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            # Random interpolation between real and completed samples for
            # the WGAN gradient penalty.
            interpolates_local = random_interpolates(local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos, batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local, dout_local, mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global, dout_global, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
                gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global', penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            # No adversarial term while pretraining the coarse network.
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        """Build the training graph with an SN-PatchGAN hinge loss plus a
        WGAN-style gradient penalty.

        Args:
            batch_data: image batch in [0, 255].
            config: config object (PADDING, GAN, BATCH_SIZE, loss alphas,
                PRETRAIN_COARSE_NETWORK, GAN_WITH_MASK, VIZ_MAX_OUT, ...).
            training: bool, build sub-networks in training mode.
            summary: bool, additionally emit scalar/image/gradient summaries.
            reuse: bool, reuse existing variable scopes.

        Returns:
            (g_vars, d_vars, losses): generator variables, discriminator
            variables, and a dict with 'l1_loss', 'ae_loss', 'g_loss',
            'd_loss' and 'gp_loss'.
        """
        # Normalize images from [0, 255] to [-1, 1].
        batch_pos = batch_data / 127.5 - 1.

        mask = bbox2mask(config, name='mask_c')
        batch_incomplete = batch_pos * (1. - mask)
        # x1: coarse result, x2: refined result.
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            # While pretraining, only the coarse stage output is used.
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # Paste the prediction into the hole of the original image.
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        # Mask-cropped patches for the patch discriminator and losses.
        local_patch_batch_pos = local_patch(batch_pos, mask)
        local_patch_batch_predicted = local_patch(batch_predicted, mask)
        local_patch_x1 = local_patch(x1, mask)
        local_patch_x2 = local_patch(x2, mask)
        local_patch_batch_complete = local_patch(batch_complete, mask)
        l1_alpha = config.COARSE_L1_ALPHA
        # L1 reconstruction loss inside the hole (coarse + refine stages).
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos -
                   local_patch_x1))  #*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos -
                       local_patch_x2))  #*spatial_discounting_mask(config))
        # Plain L1 loss outside the hole, normalized by the unmasked area.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                # Upsample the attention-flow visualization to image size.
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # Stack real (pos) and completed (neg) batches along the batch dim.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            # Give the discriminator the mask as an extra input channel.
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)

        if config.GAN == 'snpatch_gan':
            pos_neg = self.build_SNGAN_discriminator(local_patch_batch_pos_neg,
                                                     training=training,
                                                     reuse=reuse)
            # First half of the batch is real, second half is fake.
            pos, neg = tf.split(pos_neg, 2)
            sn_gloss, sn_dloss = self.gan_hinge_loss(pos,
                                                     neg,
                                                     name="gan/hinge_loss")
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * sn_gloss
            losses['d_loss'] = sn_dloss
            # BUGFIX: the original called random_interpolates(a1, a2) with
            # undefined names `a1`/`a2` (NameError at graph-build time).
            # Interpolate between the same real/completed local patches the
            # discriminator is trained on, mirroring the wgan_gp variant.
            interpolates = random_interpolates(local_patch_batch_pos,
                                               local_patch_batch_complete)
            dout = self.build_SNGAN_discriminator(interpolates, reuse=True)
            # Gradient penalty on the interpolated samples.
            penalty = gradients_penalty(interpolates, dout, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(sn_gloss,
                                  batch_predicted,
                                  name='g_loss_local')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', sn_dloss)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # Magnitude of gradients from each loss w.r.t. predictions.
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            # No adversarial term while pretraining the coarse network.
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
# Example #4
    def build_graph_with_losses(self,
                                batch_data,
                                batch_mask,
                                batch_guide,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        """Build the training graph with SN-PatchGAN ('sn_pgan') losses,
        supporting externally supplied free-form masks.

        Args:
            batch_data: image batch in [0, 255].
            batch_mask: mask batch (1 = hole), or None to sample a random
                free-form mask from config.
            batch_guide: guide batch; NOTE(review): it is overwritten with
                an all-ones tensor below, so the passed value is unused.
            config: config object (PADDING, GAN, BATCH_SIZE, MASKFROMFILE,
                loss alphas, GAN_WITH_MASK, GAN_WITH_GUIDE, ...).
            training: bool, build sub-networks in training mode.
            summary: bool, additionally emit scalar/image/gradient summaries.
            reuse: bool, reuse existing variable scopes.

        Returns:
            (g_vars, d_vars, losses): generator variables, discriminator
            variables, and a dict with 'l1_loss', 'ae_loss', 'g_loss',
            'd_loss' and 'gp_loss'.
        """
        # Normalize images from [0, 255] to [-1, 1].
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point[]
        #print(batch_data, batch_mask)
        if batch_mask is None:
            batch_mask = random_ff_mask(config)
        else:
            pass
            #batch_mask = tf.reshape(batch_mask[0], [1, *batch_mask.get_shape().as_list()[1:]])
        #print(batch_mask.shape)
        #rint()
        batch_incomplete = batch_pos * (1. - batch_mask)
        # Collapse the mask to a single channel.
        ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
        batch_mask = ones_x * batch_mask
        batch_guide = ones_x
        # batch_mask is passed twice; presumably the second argument is a
        # guide input in this variant's signature — TODO confirm against
        # build_inpaint_net.
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     batch_mask,
                                                     batch_mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            # While pretraining, only the coarse stage output is used.
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * batch_mask + batch_incomplete * (
            1. - batch_mask)

        # local patches
        # Mask-weighted patches for the reconstruction losses.
        local_patch_batch_pos = mask_patch(batch_pos, batch_mask)
        local_patch_batch_predicted = mask_patch(batch_predicted, batch_mask)
        local_patch_x1 = mask_patch(x1, batch_mask)
        local_patch_x2 = mask_patch(x2, batch_mask)
        local_patch_batch_complete = mask_patch(batch_complete, batch_mask)
        #local_patch_mask = mask_patch(mask, bbox)

        # local patch l1 loss hole+out same as partial convolution
        l1_alpha = config.COARSE_L1_ALPHA
        # L1 reconstruction loss inside the hole (coarse + refine stages).
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos -
                   local_patch_x1))  # *spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos -
                       local_patch_x2))  # *spatial_discounting_mask(config))
        # Plain L1 loss outside the hole, normalized by the unmasked area.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - batch_mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - batch_mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                # Upsample the attention-flow visualization to image size.
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        # Stack real (pos) and completed (neg) batches along the batch dim.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if config.MASKFROMFILE:
            # File-provided masks already have a full batch dimension.
            batch_mask_all = tf.tile(batch_mask, [2, 1, 1, 1])
            #batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
        else:
            # Random mask has batch dim 1; tile it to match the batches.
            batch_mask_all = tf.tile(batch_mask,
                                     [config.BATCH_SIZE * 2, 1, 1, 1])
            batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
        if config.GAN_WITH_MASK:
            # Give the discriminator the mask as an extra input channel.
            batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)

        if config.GAN_WITH_GUIDE:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(batch_guide, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        #batch_pos_, batch_complete_ = tf.split(axis, value, num_split, name=None)
        # sn-pgan with gradient penalty
        if config.GAN == 'sn_pgan':
            # sn path gan
            pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg,
                                                       training=training,
                                                       reuse=reuse)
            # First half of the batch is real, second half is fake.
            pos_global, neg_global = tf.split(pos_neg, 2)

            # wgan loss
            #g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
            g_loss_global, d_loss_global = gan_sn_pgan_loss(
                pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global
            losses['d_loss'] = d_loss_global
            # gp

            # Random Interpolate between true and false
            # The mask is concatenated so the critic sees it during the
            # gradient-penalty pass as well.
            interpolates_global = random_interpolates(
                tf.concat([batch_pos, batch_mask], axis=3),
                tf.concat([batch_complete, batch_mask], axis=3))
            dout_global = self.build_sn_pgan_discriminator(interpolates_global,
                                                           reuse=True)

            # apply penalty
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=batch_mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty_global
            # NOTE(review): gp_loss is computed and summarized but NOT added
            # to d_loss (the line below is commented out) — confirm intended.
            #losses['d_loss'] = losses['d_loss'] + losses['gp_loss']

            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_sn_pgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_sn_pgan_loss/gp_penalty_global',
                               penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            # No adversarial term while pretraining the coarse network.
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
# Example #5
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False,
                                exclusionmask=None,
                                mask=None):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        use_local_patch = False
        if mask is None:
            bbox = random_bbox(config)
            mask = bbox2mask(bbox, config, name='mask_c')
            if config.GAN == 'wgan_gp':
                use_local_patch = True
        else:
            #bbox = (0, 0, config.IMG_SHAPES[0], config.IMG_SHAPES[1])
            mask = tf.cast(tf.less(0.5, mask[:, :, :, 0:1]), tf.float32)
            if config.INVERT_MASK:
                mask = 1 - mask

        batch_incomplete = batch_pos * (1. - mask)

        if exclusionmask is not None:
            if config.INVERT_EXCLUSIONMASK:
                loss_mask = tf.cast(tf.less(0.5, exclusionmask[:, :, :, 0:1]),
                                    tf.float32)  #keep white parts
            else:
                loss_mask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                    tf.float32)  #keep black parts
            batch_incomplete = batch_incomplete * loss_mask
            batch_pos = batch_pos * loss_mask

        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete,
            mask,
            config,
            reuse=reuse,
            training=training,
            padding=config.PADDING,
            exclusionmask=exclusionmask)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        if exclusionmask is not None:
            batch_complete = batch_complete * loss_mask

        l1_alpha = config.COARSE_L1_ALPHA
        # local patches
        if use_local_patch:
            local_patch_batch_pos = local_patch(batch_pos, bbox)
            local_patch_batch_predicted = local_patch(batch_predicted, bbox)
            local_patch_x1 = local_patch(x1, bbox)
            local_patch_x2 = local_patch(x2, bbox)
            local_patch_batch_complete = local_patch(batch_complete, bbox)
            local_patch_mask = local_patch(mask, bbox)

            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x1) *
                spatial_discounting_mask(config) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(local_patch_batch_pos - local_patch_x2) *
                    spatial_discounting_mask(config) *
                    (loss_mask if exclusionmask is not None else 1))

            losses['ae_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) * (1. - mask) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['ae_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) * (1. - mask) *
                    (loss_mask if exclusionmask is not None else 1))
            losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        else:
            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) *
                    (loss_mask if exclusionmask is not None else 1))

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            if use_local_patch:
                scalar_summary('losses/ae_loss', losses['ae_loss'])
            img_size = [dim for dim in batch_incomplete.shape]
            img_size[2] = 5
            border = tf.zeros(tf.TensorShape(img_size))
            viz_img = [
                batch_pos, border, batch_incomplete, border, batch_complete,
                border
            ]
            if not config.PRETRAIN_COARSE_NETWORK:
                batch_complete_coarse = x1 * mask + batch_incomplete * (1. -
                                                                        mask)
                viz_img.append(batch_complete_coarse)
                viz_img.append(border)
            if offset_flow is not None:
                scale = 2 << len(offset_flow)
                for flow in offset_flow:
                    viz_img.append(
                        resize(flow,
                               scale=scale,
                               func=tf.image.resize_nearest_neighbor))
                    viz_img.append(border)
                    scale >>= 1
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            if not use_local_patch:
                raise Exception('wgan_gp requires global and local patch')
            local_patch_batch_pos_neg = tf.concat(
                [local_patch_batch_pos, local_patch_batch_complete], 0)
            # seperate gan
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses[
                'g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local,
                                  batch_predicted,
                                  name='g_loss_local')
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global',
                               penalty_global)
        elif config.GAN == 'sngan':
            if use_local_patch:
                raise Exception(
                    'sngan incompatible with global and local patch')
            pos_neg = self.build_sngan_discriminator(batch_pos_neg,
                                                     name='discriminator',
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss, batch_predicted, name='g_loss')
                scalar_summary('convergence/d_loss', losses['d_loss'])
        else:
            losses['g_loss'] = 0

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            if use_local_patch:
                gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
                gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS and use_local_patch:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
예제 #6
0
    def build_graph_with_losses_down(self,
                                     FLAGS,
                                     batch_data,
                                     mask,
                                     downsample_rate,
                                     training=True,
                                     summary=False,
                                     reuse=False):
        """Build the training graph for the downsampled-input variant.

        Normalizes each image to [-1, 1] by per-image min/max, inpaints the
        masked region, and attaches reconstruction (L1) and SN-GAN hinge
        losses.

        Args:
            FLAGS: config namespace; fields used here: guided,
                edge_threshold, padding, gan_with_mask, batch_size, gan,
                gan_loss_alpha, ae_loss, l1_loss_alpha, viz_max_out.
            batch_data: image batch (NHWC); if FLAGS.guided, a
                (images, edge) pair.
            mask: hole mask where 1 marks masked pixels; stored on self.
            downsample_rate: downsampling factor; stored on self.
            training: build graph in training mode.
            summary: emit TensorBoard summaries.
            reuse: reuse variable scopes.

        Returns:
            Tuple (g_vars, d_vars, losses): trainable variables of
            'inpaint_net' and 'discriminator', and a dict with keys
            'ae_loss', 'g_loss', 'd_loss'.
        """
        self.mask = mask
        self.downsample_rate = downsample_rate
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        # Per-image min-max normalization to [-1, 1]. Renamed from
        # max/min to avoid shadowing the builtins.
        batch_max = tf.reduce_max(batch_data, 1)
        batch_max = tf.reduce_max(batch_max, 1)
        batch_max = batch_max[..., np.newaxis, np.newaxis]
        batch_min = tf.reduce_min(batch_data, 1)
        batch_min = tf.reduce_min(batch_min, 1)
        batch_min = batch_min[..., np.newaxis, np.newaxis]
        # NOTE(review): divides by (max - min); a constant image would
        # yield inf/nan here — confirm inputs are never constant.
        batch_pos = (batch_data - batch_min) * 2 / (batch_max - batch_min) - 1.
        batch_incomplete = batch_pos * (1. - self.mask)
        if FLAGS.guided:
            # Edge guidance is kept only for visualization below; the
            # network input xin is built unconditionally from the mask.
            edge = edge * self.mask
        # NOTE(review): the hard-coded 16 presumably matches the batch
        # size — confirm against FLAGS.batch_size.
        mask = np.repeat(mask, 16, axis=0)
        mask = tf.convert_to_tensor(mask, dtype=tf.float32)
        xin = tf.concat([batch_incomplete, mask], axis=3)
        x2, gen_mask = self.build_inpaint_net(xin,
                                              reuse=reuse,
                                              training=training,
                                              padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # Paste the prediction into the hole region only.
        batch_complete = batch_predicted * self.mask + batch_incomplete * (
            1. - self.mask)
        # L1 reconstruction ("ae") loss on the network output.
        losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x2))
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            if FLAGS.guided:
                viz_img = [batch_pos, batch_incomplete + edge, batch_complete]
            else:
                viz_img = [batch_pos, batch_incomplete, batch_complete]
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           FLAGS.viz_max_out)

        # GAN: real and completed images share one discriminator pass.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(self.mask, [FLAGS.batch_size * 2, 1, 1, 1])
            ],
                                      axis=3)
        if FLAGS.guided:
            # Conditional GAN: discriminator also sees the edge map.
            batch_pos_neg = tf.concat(
                [batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
        if FLAGS.gan == 'sngan':
            pos_neg = self.build_gan_discriminator(batch_pos_neg,
                                                   training=training,
                                                   reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
        else:
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        if summary:
            # Magnitude of gradients from each loss w.r.t. the prediction.
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
예제 #7
0
    def build_graph_with_losses(self,
                                batch_data,
                                batch_mask,
                                batch_guide,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        """Build the SN-PatchGAN training graph; return vars and losses.

        Args:
            batch_data: image batch (NHWC) in [0, 255], scaled to [-1, 1].
            batch_mask: unused — a random free-form mask is generated
                internally and overwrites this argument.
            batch_guide: unused — replaced internally by an all-ones map.
            config: config object; fields used: PRETRAIN_COARSE_NETWORK,
                PADDING, COARSE_L1_ALPHA, GAN, GAN_WITH_MASK,
                GAN_WITH_GUIDE, BATCH_SIZE, GAN_LOSS_ALPHA, AE_LOSS_ALPHA,
                VIZ_MAX_OUT.
            training: build graph in training mode.
            summary: emit TensorBoard summaries.
            reuse: reuse variable scopes.

        Returns:
            Tuple (g_vars, d_vars, losses): trainable variables of
            'inpaint_net' and 'discriminator', and a dict with keys
            'ae_loss', 'g_loss', 'd_loss'.

        Raises:
            NotImplementedError: if config.GAN is not 'sn_pgan'.
        """
        batch_pos = batch_data / 127.5 - 1.
        # NOTE(review): the caller-provided batch_mask is discarded here.
        batch_mask = random_ff_mask(parts=8)
        batch_incomplete = batch_pos * (1. - batch_mask)
        ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
        batch_mask = ones_x * batch_mask
        batch_guide = ones_x
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     batch_mask,
                                                     batch_guide,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * batch_mask + batch_incomplete * (
            1. - batch_mask)

        # Local patch is removed in the gated convolution. It's now AE + GAN loss
        # (https://github.com/JiahuiYu/generative_inpainting/issues/62)
        losses['ae_loss'] = config.COARSE_L1_ALPHA * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - batch_mask))
        losses['ae_loss'] += config.COARSE_L1_ALPHA * tf.reduce_mean(
            tf.abs(batch_pos - x2) * (1. - batch_mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)

        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [
                batch_pos, batch_incomplete, batch_predicted, batch_complete
            ]  # I have included the predicted image as well to see the reconstructed image.
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # GAN: stack real and completed batches for one discriminator pass.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        batch_mask_all = tf.tile(batch_mask, [config.BATCH_SIZE * 2, 1, 1, 1])
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)

        if config.GAN_WITH_GUIDE:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(batch_guide, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # sn-pgan
        if config.GAN == 'sn_pgan':
            # sn path gan
            pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg,
                                                       training=training,
                                                       reuse=reuse)
            pos_global, neg_global = tf.split(pos_neg, 2)

            # SNPGAN Loss

            g_loss_global, d_loss_global = gan_sn_pgan_loss(
                pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = g_loss_global
            losses['d_loss'] = d_loss_global
        else:
            # Fail fast: otherwise losses['g_loss'] is never set and the
            # code below raises a confusing KeyError.
            raise NotImplementedError(
                '{} not implemented.'.format(config.GAN))

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['d_loss'], batch_predicted, name='d_loss')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_x2')

        losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
예제 #8
0
    def build_graph_with_losses(
            self, FLAGS, batch_data, training=True, summary=False,
            reuse=False):
        """Build the training graph and collect generator/discriminator
        losses.

        `batch_data` is an (images, mask) pair; the mask (0..255) is
        binarized so that 1 marks a hole. Returns (g_vars, d_vars, losses)
        where losses holds 'ae_loss', 'g_loss' and 'd_loss'.
        """
        batch_data, mask = batch_data
        # Binarize the mask; shape (batch_size, 256, 256, 1), 1 = hole.
        mask = tf.cast(mask / 255 > 0.5, tf.float32)
        # Scale images from [0, 255] to [-1, 1].
        batch_pos = batch_data / 127.5 - 1.

        batch_incomplete = batch_pos * (1. - mask)
        xin = batch_incomplete
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=reuse, training=training,
            padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # Paste the refined prediction into the hole region only.
        batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)
        # L1 reconstruction loss on both the coarse (x1) and refined (x2)
        # outputs.
        coarse_l1 = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
        fine_l1 = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
        losses['ae_loss'] = coarse_l1 + fine_l1
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            images_summary(
                tf.concat([batch_pos, batch_incomplete, batch_complete],
                          axis=2),
                'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

        # GAN: one discriminator pass over real and completed batches.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            batch_pos_neg = tf.concat(
                [batch_pos_neg, tf.tile(mask, [2, 1, 1, 1])], axis=3)
        if FLAGS.gan != 'sngan':
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        pos_neg = self.build_gan_discriminator(
            batch_pos_neg, training=training, reuse=reuse)
        pos, neg = tf.split(pos_neg, 2)
        losses['g_loss'], losses['d_loss'] = gan_hinge_loss(pos, neg)
        if summary:
            # Magnitude of gradients from each loss w.r.t. the refined
            # output.
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses