def build_graph_with_losses(self, data, config, reuse=True, summary=False):
        """Construct the progressive-GAN training graph and its WGAN-GP losses.

        Args:
            data: dataset object exposing ``data_pipeline(batch_size)``.
            config: training config; reads BATCH_SIZE, LAST_RESOLUTION,
                CURRENT_RESOLUTION, LOSS['iwass_lambda'] and VIZ_MAX_OUT.
            reuse: forwarded to the generator and to the first (fake-batch)
                discriminator call; every later discriminator call reuses.
            summary: when True, add image summaries of real and fake batches.

        Returns:
            Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
            ``G_paper`` and ``D_paper`` scopes, and a dict with keys
            ``'g_loss'``, ``'d_loss'`` and ``'ri_loss'`` (gradient penalty).
        """
        batch = config.BATCH_SIZE
        resolutions = (config.LAST_RESOLUTION, config.CURRENT_RESOLUTION)

        # Rescale pixels with x / 127.5 - 1 (presumably [0, 255] -> [-1, 1]).
        real = data.data_pipeline(batch) / 127.5 - 1.
        noise = tf.random_uniform([batch, 1, 1, 512], -1, 1, name='z')
        generated = self.G_paper(noise, *resolutions, reuse=reuse)

        if summary:
            for tensor, tag in ((real, 'real_images'),
                                (generated, 'fake_images')):
                images_summary(tensor, tag, config.VIZ_MAX_OUT)

        # Score the fake batch first (honours `reuse`), then the real batch
        # with variable sharing forced on.
        score_fake = self.D_paper(generated, *resolutions, reuse=reuse)
        score_real = self.D_paper(real, *resolutions, reuse=True)
        g_loss, d_loss = gan_wgan_loss(score_real, score_fake)

        # Gradient penalty evaluated on random real/fake interpolates.
        mixed = random_interpolates(real, generated)
        mixed_score = self.D_paper(mixed, *resolutions, reuse=True)
        penalty = gradients_penalty(mixed, mixed_score)
        d_loss = d_loss + config.LOSS['iwass_lambda'] * penalty

        losses = {'g_loss': g_loss, 'd_loss': d_loss, 'ri_loss': penalty}
        trainable = tf.GraphKeys.TRAINABLE_VARIABLES
        g_vars = tf.get_collection(trainable, 'G_paper')
        d_vars = tf.get_collection(trainable, 'D_paper')
        return g_vars, d_vars, losses
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        """Build the inpainting training graph with an SN patch-GAN loss.

        Args:
            batch_data: ground-truth image batch; rescaled below with
                ``/ 127.5 - 1`` (presumably from [0, 255] to [-1, 1]).
            config: training config; reads PADDING, PRETRAIN_COARSE_NETWORK,
                COARSE_L1_ALPHA, GAN, the loss weights and VIZ_MAX_OUT.
            training: forwarded to the inpainting network and the first
                discriminator call.
            summary: when True, add scalar, image and gradient summaries.
            reuse: forwarded to the inpainting network and the first
                discriminator call; the gradient-penalty call always reuses.

        Returns:
            Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
            ``inpaint_net`` and ``discriminator`` scopes and a dict of loss
            tensors (``l1_loss``, ``ae_loss``, ``g_loss``; plus ``d_loss`` and
            ``gp_loss`` when ``config.GAN == 'snpatch_gan'``).

        Note:
            If ``config.GAN != 'snpatch_gan'`` and the coarse network is not
            being pretrained, ``losses['g_loss']`` is read before it is set
            and a ``KeyError`` is raised. The discriminator here only sees
            local patches, so ``config.GAN_WITH_MASK`` is not consulted.
        """
        batch_pos = batch_data / 127.5 - 1.

        # Mask value 1 marks a pixel to be filled in.
        mask = bbox2mask(config, name='mask_c')
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # Keep known pixels from the input; take predictions inside the hole.
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        local_patch_batch_pos = local_patch(batch_pos, mask)
        local_patch_x1 = local_patch(x1, mask)
        local_patch_x2 = local_patch(x2, mask)
        local_patch_batch_complete = local_patch(batch_complete, mask)

        # Reconstruction losses; the coarse (x1) terms are scaled by
        # COARSE_L1_ALPHA, the refined (x2) terms only when not pretraining.
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x1))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x2))
        # L1 over known pixels, normalized by the fraction of known pixels.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # Real and fake local patches stacked along the batch axis.
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)

        if config.GAN == 'snpatch_gan':
            pos_neg = self.build_SNGAN_discriminator(local_patch_batch_pos_neg,
                                                     training=training,
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            sn_gloss, sn_dloss = self.gan_hinge_loss(pos,
                                                     neg,
                                                     name="gan/hinge_loss")
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * sn_gloss
            losses['d_loss'] = sn_dloss
            # Gradient penalty on interpolates between the same real/fake
            # local patches the discriminator scored above.
            interpolates = random_interpolates(local_patch_batch_pos,
                                               local_patch_batch_complete)
            dout = self.build_SNGAN_discriminator(interpolates, reuse=True)
            # NOTE(review): the full-size mask is applied to gradients of the
            # local patches; this assumes local_patch() keeps spatial size.
            penalty = gradients_penalty(interpolates, dout, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(sn_gloss,
                                  batch_predicted,
                                  name='g_loss_local')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', sn_dloss)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')

        # Final generator objective: weighted GAN term (dropped while
        # pretraining the coarse network) + weighted L1 (+ optional AE).
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
    def build_graph_with_losses(self, batch_data, config, training=True,
                                summary=False, reuse=False):
        """Build the inpainting training graph with local+global WGAN-GP losses.

        Args:
            batch_data: ground-truth image batch; rescaled below with
                ``/ 127.5 - 1`` (presumably from [0, 255] to [-1, 1]).
            config: training config; reads PADDING, PRETRAIN_COARSE_NETWORK,
                COARSE_L1_ALPHA, GAN, GAN_WITH_MASK, BATCH_SIZE, the loss
                weights and VIZ_MAX_OUT (plus whatever random_bbox and
                bbox2mask read).
            training: forwarded to the inpainting network and the first
                discriminator call.
            summary: when True, add scalar, image and gradient summaries.
            reuse: forwarded to the inpainting network and the first
                discriminator call; the gradient-penalty call always reuses.

        Returns:
            Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
            ``inpaint_net`` and ``discriminator`` scopes, and a dict of loss
            tensors (``l1_loss``, ``ae_loss``, ``g_loss``; plus ``d_loss`` and
            ``gp_loss`` when ``config.GAN == 'wgan_gp'``).

        Note:
            If ``config.GAN != 'wgan_gp'`` and the coarse network is not being
            pretrained, ``losses['g_loss']`` is read before it is set and a
            ``KeyError`` is raised.
        """
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        # Zero out the hole region of the input.
        batch_incomplete = batch_pos*(1.-mask)
        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete, mask, config, reuse=reuse, training=training,
            padding=config.PADDING)
        # Coarse output (x1) while pretraining, refined output (x2) otherwise.
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches: crops at the sampled bbox
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        # L1 on the local patch weighted by spatial_discounting_mask; the
        # coarse term is scaled by COARSE_L1_ALPHA, the refined term is
        # added only when not pretraining.
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2)*spatial_discounting_mask(config))
        # L1 over known (unmasked) pixels, normalized by their mean fraction.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-mask))
        losses['ae_loss'] /= tf.reduce_mean(1.-mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            # Side-by-side visualization (concatenated along width).
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)

        # gan: real and completed images stacked along the batch axis
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)
        # Optionally append the mask as an extra channel for the global critic.
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate local and global critics
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(local_patch_batch_pos_neg, batch_pos_neg, training=training, reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp: interpolate between real and completed inputs
            # NOTE(review): when GAN_WITH_MASK is set, batch_pos_neg above has
            # an extra mask channel but interpolates_global does not; confirm
            # build_wgan_discriminator tolerates the channel difference.
            interpolates_local = random_interpolates(local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos, batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local, dout_local, mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global, dout_global, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
                gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global', penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        # Final generator objective: weighted GAN term (dropped while
        # pretraining the coarse network) + weighted L1 (+ optional AE).
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                reuse=False):
        """Build the inpainting training graph with local+global WGAN-GP losses.

        Same as the summary-enabled variant, but emits no summaries or logs.

        Args:
            batch_data: ground-truth image batch; rescaled below with
                ``/ 127.5 - 1`` (presumably from [0, 255] to [-1, 1]).
            config: training config; reads PADDING, PRETRAIN_COARSE_NETWORK,
                COARSE_L1_ALPHA, GAN, GAN_WITH_MASK, BATCH_SIZE and the loss
                weights (plus whatever random_bbox and bbox2mask read).
            training: forwarded to the inpainting network and the first
                discriminator call.
            reuse: forwarded to the inpainting network and the first
                discriminator call; the gradient-penalty call always reuses.

        Returns:
            Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
            ``inpaint_net`` and ``discriminator`` scopes, and a dict of loss
            tensors (``l1_loss``, ``ae_loss``, ``g_loss``; plus ``d_loss`` and
            ``gp_loss`` when ``config.GAN == 'wgan_gp'``).

        Note:
            If ``config.GAN != 'wgan_gp'`` and the coarse network is not being
            pretrained, ``losses['g_loss']`` is read before it is set and a
            ``KeyError`` is raised.
        """
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        # Zero out the hole region of the input.
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        # Coarse output (x1) while pretraining, refined output (x2) otherwise.
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
        else:
            batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        # local patches: crops at the sampled bbox
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        # L1 on the local patch weighted by spatial_discounting_mask; the
        # coarse term is scaled by COARSE_L1_ALPHA, the refined term is
        # added only when not pretraining.
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x1) *
            spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x2) *
                spatial_discounting_mask(config))
        # L1 over known (unmasked) pixels, normalized by their mean fraction.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)

        # gan: real and completed images stacked along the batch axis
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        # Optionally append the mask as an extra channel for the global critic.
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate local and global critics
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses[
                'g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp: interpolate between real and completed inputs
            # NOTE(review): when GAN_WITH_MASK is set, batch_pos_neg above has
            # an extra mask channel but interpolates_global does not; confirm
            # build_wgan_discriminator tolerates the channel difference.
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']

        # Final generator objective: weighted GAN term (dropped while
        # pretraining the coarse network) + weighted L1 (+ optional AE).
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
# Example #5
0
    def build_graph_with_losses(self,
                                batch_data,
                                batch_mask,
                                batch_guide,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        """Build the free-form inpainting graph with an SN-PatchGAN loss.

        Args:
            batch_data: ground-truth image batch; rescaled below with
                ``/ 127.5 - 1`` (presumably from [0, 255] to [-1, 1]).
            batch_mask: mask tensor where 1 marks a pixel to fill, or None to
                draw one with ``random_ff_mask(config)``.
            batch_guide: currently unused; it is overwritten below with a
                ones tensor derived from the mask.
            config: training config; reads PADDING, PRETRAIN_COARSE_NETWORK,
                COARSE_L1_ALPHA, MASKFROMFILE, BATCH_SIZE, GAN_WITH_MASK,
                GAN_WITH_GUIDE, GAN, the loss weights and VIZ_MAX_OUT.
            training: forwarded to the inpainting network and the first
                discriminator call.
            summary: when True, add scalar, image and gradient summaries.
            reuse: forwarded to the inpainting network and the first
                discriminator call; the gradient-penalty call always reuses.

        Returns:
            Tuple ``(g_vars, d_vars, losses)``: trainable variables under the
            ``inpaint_net`` and ``discriminator`` scopes and a dict of loss
            tensors (``l1_loss``, ``ae_loss``, ``g_loss``; plus ``d_loss`` and
            ``gp_loss`` when ``config.GAN == 'sn_pgan'``).

        Note:
            If ``config.GAN != 'sn_pgan'`` and the coarse network is not being
            pretrained, ``losses['g_loss']`` is read before it is set and a
            ``KeyError`` is raised.
        """
        batch_pos = batch_data / 127.5 - 1.
        if batch_mask is None:
            batch_mask = random_ff_mask(config)
        batch_incomplete = batch_pos * (1. - batch_mask)
        # Ones with the mask's leading/spatial dims and a single channel.
        ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
        batch_mask = ones_x * batch_mask
        # NOTE(review): the caller's batch_guide is discarded here; confirm
        # this is intended.
        batch_guide = ones_x
        # NOTE(review): batch_mask is passed both as the mask and as the third
        # argument of build_inpaint_net; confirm against its signature.
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     batch_mask,
                                                     batch_mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # Keep known pixels from the input; take predictions inside the hole.
        batch_complete = batch_predicted * batch_mask + batch_incomplete * (
            1. - batch_mask)

        # Mask-restricted patches used by the L1 loss.
        local_patch_batch_pos = mask_patch(batch_pos, batch_mask)
        local_patch_x1 = mask_patch(x1, batch_mask)
        local_patch_x2 = mask_patch(x2, batch_mask)

        # Reconstruction losses; the coarse (x1) terms are scaled by
        # COARSE_L1_ALPHA, the refined (x2) terms only when not pretraining.
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x1))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x2))
        # L1 over known pixels, normalized by the fraction of known pixels.
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - batch_mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - batch_mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - batch_mask)

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # Real and completed images stacked along the batch axis.
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # Masks read from file already carry the batch dimension; generated
        # masks are tiled up to the batch. The guide shares the mask's
        # leading dimension, so both are tiled with the same multiples to
        # line up with batch_pos_neg.
        if config.MASKFROMFILE:
            pair_multiples = [2, 1, 1, 1]
        else:
            pair_multiples = [config.BATCH_SIZE * 2, 1, 1, 1]
        batch_mask_all = tf.tile(batch_mask, pair_multiples)
        if not config.MASKFROMFILE:
            batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)
        if config.GAN_WITH_GUIDE:
            batch_pos_neg = tf.concat(
                [batch_pos_neg, tf.tile(batch_guide, pair_multiples)], axis=3)

        if config.GAN == 'sn_pgan':
            pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg,
                                                       training=training,
                                                       reuse=reuse)
            pos_global, neg_global = tf.split(pos_neg, 2)
            g_loss_global, d_loss_global = gan_sn_pgan_loss(
                pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global
            losses['d_loss'] = d_loss_global

            # Interpolate real/fake inputs in exactly the channel layout the
            # discriminator saw above, so its reused variables match in shape
            # whatever GAN_WITH_MASK / GAN_WITH_GUIDE are set to.
            real_input, fake_input = tf.split(batch_pos_neg, 2)
            interpolates_global = random_interpolates(real_input, fake_input)
            dout_global = self.build_sn_pgan_discriminator(interpolates_global,
                                                           reuse=True)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=batch_mask)
            # gp_loss is reported in the summaries but not added to d_loss.
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty_global

            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_sn_pgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_sn_pgan_loss/gp_penalty_global',
                               penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')

        # Final generator objective: weighted GAN term (dropped while
        # pretraining the coarse network) + weighted L1 (+ optional AE).
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
# Example #6
0
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False,
                                exclusionmask=None,
                                mask=None):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        use_local_patch = False
        if mask is None:
            bbox = random_bbox(config)
            mask = bbox2mask(bbox, config, name='mask_c')
            if config.GAN == 'wgan_gp':
                use_local_patch = True
        else:
            #bbox = (0, 0, config.IMG_SHAPES[0], config.IMG_SHAPES[1])
            mask = tf.cast(tf.less(0.5, mask[:, :, :, 0:1]), tf.float32)
            if config.INVERT_MASK:
                mask = 1 - mask

        batch_incomplete = batch_pos * (1. - mask)

        if exclusionmask is not None:
            if config.INVERT_EXCLUSIONMASK:
                loss_mask = tf.cast(tf.less(0.5, exclusionmask[:, :, :, 0:1]),
                                    tf.float32)  #keep white parts
            else:
                loss_mask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                    tf.float32)  #keep black parts
            batch_incomplete = batch_incomplete * loss_mask
            batch_pos = batch_pos * loss_mask

        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete,
            mask,
            config,
            reuse=reuse,
            training=training,
            padding=config.PADDING,
            exclusionmask=exclusionmask)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        if exclusionmask is not None:
            batch_complete = batch_complete * loss_mask

        l1_alpha = config.COARSE_L1_ALPHA
        # local patches
        if use_local_patch:
            local_patch_batch_pos = local_patch(batch_pos, bbox)
            local_patch_batch_predicted = local_patch(batch_predicted, bbox)
            local_patch_x1 = local_patch(x1, bbox)
            local_patch_x2 = local_patch(x2, bbox)
            local_patch_batch_complete = local_patch(batch_complete, bbox)
            local_patch_mask = local_patch(mask, bbox)

            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x1) *
                spatial_discounting_mask(config) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(local_patch_batch_pos - local_patch_x2) *
                    spatial_discounting_mask(config) *
                    (loss_mask if exclusionmask is not None else 1))

            losses['ae_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) * (1. - mask) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['ae_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) * (1. - mask) *
                    (loss_mask if exclusionmask is not None else 1))
            losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        else:
            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) *
                    (loss_mask if exclusionmask is not None else 1))

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            if use_local_patch:
                scalar_summary('losses/ae_loss', losses['ae_loss'])
            img_size = [dim for dim in batch_incomplete.shape]
            img_size[2] = 5
            border = tf.zeros(tf.TensorShape(img_size))
            viz_img = [
                batch_pos, border, batch_incomplete, border, batch_complete,
                border
            ]
            if not config.PRETRAIN_COARSE_NETWORK:
                batch_complete_coarse = x1 * mask + batch_incomplete * (1. -
                                                                        mask)
                viz_img.append(batch_complete_coarse)
                viz_img.append(border)
            if offset_flow is not None:
                scale = 2 << len(offset_flow)
                for flow in offset_flow:
                    viz_img.append(
                        resize(flow,
                               scale=scale,
                               func=tf.image.resize_nearest_neighbor))
                    viz_img.append(border)
                    scale >>= 1
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            if not use_local_patch:
                raise Exception('wgan_gp requires global and local patch')
            local_patch_batch_pos_neg = tf.concat(
                [local_patch_batch_pos, local_patch_batch_complete], 0)
            # seperate gan
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses[
                'g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local,
                                  batch_predicted,
                                  name='g_loss_local')
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global',
                               penalty_global)
        elif config.GAN == 'sngan':
            if use_local_patch:
                raise Exception(
                    'sngan incompatible with global and local patch')
            pos_neg = self.build_sngan_discriminator(batch_pos_neg,
                                                     name='discriminator',
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss, batch_predicted, name='g_loss')
                scalar_summary('convergence/d_loss', losses['d_loss'])
        else:
            losses['g_loss'] = 0

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            if use_local_patch:
                gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
                gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS and use_local_patch:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses