Code Example #1
File: inpaint_model.py  Project: pasawaya/VLOGDataset
    def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
        """Build the inference graph: mask the input batch, run the inpaint
        network, and composite the prediction back into the masked region.
        """
        config.MAX_DELTA_HEIGHT = 0
        config.MAX_DELTA_WIDTH = 0
        if bbox is None:
            bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name=name + 'mask_c')
        batch_pos = batch_data / 127.5 - 1.
        edges = None
        batch_incomplete = batch_pos * (1. - mask)
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=True,
                                                     training=False,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
        else:
            batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = (batch_predicted * mask +
                          batch_incomplete * (1. - mask))
        # global image visualization (viz_img is assembled but not written to
        # a summary in this variant)
        viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow,
                       scale=4,
                       func=tf.image.resize_nearest_neighbor))
        return batch_complete
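
The core of this infer graph is a normalize, mask, inpaint, composite pipeline. A minimal self-contained sketch of the normalization and compositing arithmetic, using NumPy arrays and a zero-filled stand-in for the generator output (shapes and values here are hypothetical, not part of the project):

import numpy as np

# Hypothetical NHWC batch in [0, 255] and a rectangular hole mask (1 = hole).
batch = np.random.randint(0, 256, size=(1, 8, 8, 3)).astype(np.float32)
mask = np.zeros((1, 8, 8, 1), dtype=np.float32)
mask[:, 2:6, 2:6, :] = 1.

batch_pos = batch / 127.5 - 1.               # scale [0, 255] into [-1, 1]
batch_incomplete = batch_pos * (1. - mask)   # zero out the hole
batch_predicted = np.zeros_like(batch_pos)   # stand-in for x1/x2 from the net
batch_complete = batch_predicted * mask + batch_incomplete * (1. - mask)

# Only the hole is replaced; known pixels pass through unchanged.
assert np.allclose(batch_complete * (1. - mask), batch_pos * (1. - mask))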
Code Example #2
    def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
        """
        """
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        if bbox is None:
            bbox = random_bbox(FLAGS)  # fall back to a random box, as in example #1
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ), tf.float32)

        batch_pos = batch_data / 127.5 - 1.
        batch_incomplete = batch_pos * (1. - mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(xin,
                                                     mask,
                                                     reuse=True,
                                                     training=False,
                                                     padding=FLAGS.padding)
        batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = (batch_predicted * mask +
                          batch_incomplete * (1. - mask))
        # global image visualization
        if FLAGS.guided:
            viz_img = [batch_pos, batch_incomplete + edge, batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4, func=tf.image.resize_bilinear))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', FLAGS.viz_max_out)
        return batch_complete
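
This variant unions a rectangular bbox mask with an irregular brush-stroke mask before masking the batch. A minimal sketch of that union, with hypothetical stand-in arrays in place of bbox2mask() and brush_stroke_mask():

import numpy as np
import tensorflow as tf

# Stand-ins for the bbox mask and the brush-stroke mask (1 marks the hole).
regular_mask = np.zeros((1, 8, 8, 1), np.float32)
regular_mask[:, 2:5, 2:5, :] = 1.        # rectangular hole
irregular_mask = np.zeros((1, 8, 8, 1), np.float32)
irregular_mask[:, 6, 1:7, :] = 1.        # thin "brush stroke"

# Same union as above: a pixel is masked if either mask marks it.
mask = tf.cast(
    tf.logical_or(
        tf.cast(irregular_mask, tf.bool),
        tf.cast(regular_mask, tf.bool),
    ), tf.float32)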
Code Example #3
File: inpaint_model.py  Project: pasawaya/VLOGDataset
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                reuse=False):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
        else:
            batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = (batch_predicted * mask +
                          batch_incomplete * (1. - mask))
        # local patches
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x1) *
            spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x2) *
                spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat(
                [batch_pos_neg,
                 tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])],
                axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate local and global discriminators
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses['g_loss'] = (config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global +
                                g_loss_local)
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']

        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
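
The wgan_gp branch relies on the project's gan_wgan_loss and gradients_penalty helpers, whose definitions are not shown here. A standard Wasserstein loss with gradient penalty looks roughly like the sketch below (TF1-style graph code; the *_sketch names are hypothetical, and the real helpers may differ in details):

import tensorflow as tf

def gan_wgan_loss_sketch(pos, neg):
    # Critic scores: pos on real patches, neg on completed patches.
    d_loss = tf.reduce_mean(neg) - tf.reduce_mean(pos)  # critic minimizes this
    g_loss = -tf.reduce_mean(neg)                       # generator raises fake scores
    return g_loss, d_loss

def gradients_penalty_sketch(interpolates, dout, mask=None):
    # Penalize the critic's gradient norm at random interpolates between real
    # and completed images, pushing the norm toward 1 (the WGAN-GP penalty).
    grad = tf.gradients(dout, [interpolates])[0]
    if mask is not None:
        grad = grad * mask  # concentrate the penalty on the masked region
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]) + 1e-10)
    return tf.reduce_mean(tf.square(slopes - 1.))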
Code Example #4
File: inpaint_model.py  Project: XingToMax/vilbs
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):

        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = (batch_predicted * mask +
                          batch_incomplete * (1. - mask))
        # local patches
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos - local_patch_x1) *
            spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x2) *
                spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat(
                [batch_pos_neg,
                 tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])],
                axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate local and global discriminators
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses['g_loss'] = (config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global +
                                g_loss_local)
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local,
                                  batch_predicted,
                                  name='g_loss_local')
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global',
                               penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
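
Both l1 terms above are weighted by spatial_discounting_mask(config), which gives full weight to pixels near the border of the local patch and discounts pixels toward its center, where the network has the least surrounding context. A hedged NumPy sketch of such a weight map; the function name, parameters, and exact decay rule are assumptions rather than the project's implementation:

import numpy as np

def spatial_discounting_sketch(height, width, gamma=0.9):
    # Weight 1.0 at the patch border, decaying as gamma**distance toward the
    # center; the result broadcasts against NHWC local patches.
    w = np.ones((height, width), dtype=np.float32)
    for i in range(height):
        for j in range(width):
            w[i, j] = max(gamma ** min(i, height - 1 - i),
                          gamma ** min(j, width - 1 - j))
    return w.reshape(1, height, width, 1)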
Code Example #5
    def build_graph_with_losses(self,
                                FLAGS,
                                batch_data,
                                training=True,
                                summary=False,
                                reuse=False):
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ), tf.float32)

        batch_incomplete = batch_pos * (1. - mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        x1, x2, offset_flow = self.build_inpaint_net(xin,
                                                     mask,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = (batch_predicted * mask +
                          batch_incomplete * (1. - mask))
        # l1 reconstruction (ae) loss on both the coarse (x1) and fine (x2) outputs
        losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1))
        losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x2))
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            if FLAGS.guided:
                viz_img = [batch_pos, batch_incomplete + edge, batch_complete]
            else:
                viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_bilinear))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           FLAGS.viz_max_out)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            batch_pos_neg = tf.concat(
                [batch_pos_neg,
                 tf.tile(mask, [FLAGS.batch_size * 2, 1, 1, 1])],
                axis=3)
        if FLAGS.guided:
            # conditional GANs
            batch_pos_neg = tf.concat(
                [batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
        # sn-gan discriminator with hinge loss
        if FLAGS.gan == 'sngan':
            pos_neg = self.build_gan_discriminator(batch_pos_neg,
                                                   training=training,
                                                   reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
        else:
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        if summary:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
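
Here the adversarial term comes from gan_hinge_loss applied to the SN-GAN discriminator outputs. The standard hinge formulation, which that helper presumably follows, is sketched below (the *_sketch name is hypothetical):

import tensorflow as tf

def gan_hinge_loss_sketch(pos, neg):
    # Discriminator: push real scores above +1 and fake scores below -1.
    d_loss = (tf.reduce_mean(tf.nn.relu(1. - pos)) +
              tf.reduce_mean(tf.nn.relu(1. + neg)))
    # Generator: raise the discriminator's score on completed images.
    g_loss = -tf.reduce_mean(neg)
    return g_loss, d_loss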