def build_infer_graph(self, batch_data, config, name='val'):

        config.MAX_DELTA_HEIGHT = 0
        config.MAX_DELTA_WIDTH = 0

        mask = bbox2mask(config, name=name + 'mask_c')
        batch_pos = batch_data / 127.5 - 1.
        edges = None
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=True,
                                                     training=False,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow,
                       scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
        return batch_complete
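
All of the snippets in this listing share one mask convention: mask is a float tensor with 1.0 inside the hole and 0.0 elsewhere, so batch_pos * (1. - mask) blanks out the region to be inpainted. A minimal NumPy sketch of what a bbox2mask-style helper produces, assuming a (top, left, height, width) bbox layout (the helper name and layout are inferred from how random_bbox and bbox2mask are used throughout this listing, not taken from the repository):

import numpy as np

def bbox_to_mask(bbox, img_height, img_width):
    # bbox = (top, left, height, width); returns a (1, H, W, 1) float mask
    # with 1.0 inside the rectangle (the hole) and 0.0 outside it.
    top, left, h, w = bbox
    mask = np.zeros((1, img_height, img_width, 1), dtype=np.float32)
    mask[:, top:top + h, left:left + w, :] = 1.
    return mask

# usage: blank out the hole, as in batch_incomplete = batch_pos * (1. - mask)
mask = bbox_to_mask((64, 64, 128, 128), 256, 256)
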
Example #2
 def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
     """
     """
     config.MAX_DELTA_HEIGHT = 0
     config.MAX_DELTA_WIDTH = 0
     if bbox is None:
         bbox = random_bbox(config)
     mask = bbox2mask(bbox, config, name=name+'mask_c')
     batch_pos = batch_data / 127.5 - 1.
     edges = None
     batch_incomplete = batch_pos*(1.-mask)
     # inpaint
     x2, offset_flow = self.build_inpaint_net(
         batch_incomplete, mask, config, reuse=True,
         training=False, padding=config.PADDING)
     batch_predicted = x2
     logger.info('Set batch_predicted to x2.')
     # apply mask and reconstruct
     batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
     # global image visualization
     viz_img = [batch_pos, batch_incomplete, batch_complete]
     if offset_flow is not None:
         viz_img.append(
             resize(offset_flow, scale=4,
                    func=tf.image.resize_nearest_neighbor))
     images_summary(
         tf.concat(viz_img, axis=2),
         name+'_raw_incomplete_complete', config.VIZ_MAX_OUT)
     return batch_complete
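
The resize call above upscales the attention offset_flow map, which build_inpaint_net produces at reduced resolution, so that it can be concatenated with the full-size images for the summary. A minimal sketch of such a helper, assuming NHWC tensors and the TF1-style resize ops these snippets use (the repository's resize has more options):

import tensorflow as tf

def resize(x, scale=2, func=tf.image.resize_nearest_neighbor):
    # Upscale an NHWC tensor by an integer factor using the given resize op.
    shape = tf.shape(x)
    return func(x, [shape[1] * scale, shape[2] * scale])
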
Example #3
    def train_step(iter_idx):
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            batch_pos = next(train_iter)
            bbox = random_bbox(FLAGS)
            regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
            irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
            mask = tf.cast(
                tf.logical_or(
                    tf.cast(irregular_mask, tf.bool),
                    tf.cast(regular_mask, tf.bool),
                ), tf.float32)
            # mask is 0-1
            batch_incomplete = batch_pos * (1. - mask)
            xin = batch_incomplete
            x1, x2, offset_flow = G(xin, mask, training=True)
            batch_predicted = x2

            losses = {}
            batch_complete = batch_predicted * \
                mask + batch_incomplete * (1. - mask)
            losses['ae_loss'] = FLAGS['l1_loss_alpha'] * \
                tf.reduce_mean(input_tensor=tf.abs(batch_pos - x1))
            losses['ae_loss'] += FLAGS['l1_loss_alpha'] * \
                tf.reduce_mean(input_tensor=tf.abs(batch_pos - x2))

            batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [FLAGS['batch_size'] * 2, 1, 1, 1])
            ],
                                      axis=3)

            # SNGAN
            pos_neg = D(batch_pos_neg, training=True)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
            losses['g_loss'] = FLAGS['gan_loss_alpha'] * losses['g_loss']
            losses['g_loss'] += losses['ae_loss']
            # return losses
        gradients_of_generator = gen_tape.gradient(losses['g_loss'],
                                                   G.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(
            losses['d_loss'], D.trainable_variables)
        g_optimizer.apply_gradients(
            zip(gradients_of_generator, G.trainable_variables))
        d_optimizer.apply_gradients(
            zip(gradients_of_discriminator, D.trainable_variables))

        # tensorboard
        if iter_idx > 0 and iter_idx % FLAGS['viz_max_out'] == 0:
            with train_summary_writer.as_default():
                tf.summary.scalar('g_loss', losses['g_loss'], step=iter_idx)
                tf.summary.scalar('d_loss', losses['d_loss'], step=iter_idx)
                img = tf.reshape(batch_complete[0], (-1, 256, 256, 3))
                tf.summary.image("train", img, step=iter_idx)
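
gan_hinge_loss in the training step above is the hinge GAN objective commonly paired with SN-PatchGAN style discriminators: the discriminator is pushed to score real samples above +1 and completed samples below -1, while the generator raises the discriminator score on its completions. A minimal sketch (the exact reductions in the project's helper may differ):

import tensorflow as tf

def gan_hinge_loss(pos, neg):
    # pos: discriminator scores on real images, neg: scores on completed images.
    d_loss = tf.reduce_mean(tf.nn.relu(1. - pos)) + tf.reduce_mean(tf.nn.relu(1. + neg))
    g_loss = -tf.reduce_mean(neg)
    return g_loss, d_loss
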
Example #4
 def build_infer_graph(self,
                       batch_data,
                       config,
                       bbox=None,
                       name='val',
                       exclusionmask=None):
     """
     """
     config.MAX_DELTA_HEIGHT = 0
     config.MAX_DELTA_WIDTH = 0
     if bbox is None:
         bbox = random_bbox(config)
     mask = bbox2mask(bbox, config, name=name + 'mask_c')
     batch_pos = batch_data / 127.5 - 1.
     edges = None
     batch_incomplete = batch_pos * (1. - mask)
     # inpaint
     x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                  mask,
                                                  config,
                                                  reuse=True,
                                                  training=False,
                                                  padding=config.PADDING)
     if config.PRETRAIN_COARSE_NETWORK:
         batch_predicted = x1
         logger.info('Set batch_predicted to x1.')
     else:
         batch_predicted = x2
         logger.info('Set batch_predicted to x2.')
     # apply mask and reconstruct
     batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                   mask)
     # global image visualization
     img_size = [dim for dim in batch_incomplete.shape]
     img_size[2] = 5
     border = tf.zeros(tf.TensorShape(img_size))
     viz_img = [
         border, batch_pos, border, batch_incomplete, border,
         batch_complete, border
     ]
     if not config.PRETRAIN_COARSE_NETWORK:
         batch_complete_coarse = x1 * mask + batch_incomplete * (1. - mask)
         viz_img.append(batch_complete_coarse)
     if offset_flow is not None:
         scale = 2 << len(offset_flow)
         for flow in offset_flow:
             viz_img.append(
                 resize(flow,
                        scale=scale,
                        func=tf.image.resize_nearest_neighbor))
             viz_img.append(border)
             scale >>= 1
     images_summary(tf.concat(viz_img, axis=2),
                    name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
     return batch_complete
Example #5
    def build_infer_graph(self, batch_data, config, bbox=None, name='val'):
        """
        """
        config.MAX_DELTA_HEIGHT = 0
        config.MAX_DELTA_WIDTH = 0
        if bbox is None:
            bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name=name + 'mask_c')
        batch_pos = batch_data / 127.5 - 1.

        edges = None
        batch_incomplete = batch_pos * (1. - mask)
        # inpaint
        if not config.ADD_GRADIENT_BRANCH:
            x1, x2, offset_flow = self.build_inpaint_net(
                batch_incomplete,
                mask,
                config,
                reuse=True,
                training=False,
                padding=config.PADDING)
        else:
            x1, x2, fake_gb, gb, offset_flow = self.build_inpaint_net(
                batch_incomplete,
                mask,
                config,
                reuse=True,
                training=False,
                padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        # apply mask and reconstruct
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        # global image visualization
        viz_img = [batch_pos, batch_incomplete, batch_complete]

        if config.ADD_GRADIENT_BRANCH:
            gb_gt = self.get_grad.get_gradient_tf(batch_pos)
            batch_complete_gb = fake_gb * mask + gb_gt * (1. - mask)
            viz_img.extend([gb_gt, batch_complete_gb])

        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow,
                       scale=4,
                       func=tf.image.resize_nearest_neighbor))
        images_summary(tf.concat(viz_img, axis=2),
                       name + '_raw_incomplete_complete', config.VIZ_MAX_OUT)
        return batch_complete
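
Example #5 adds a gradient branch: the network also predicts a gradient map (fake_gb), which is combined with a ground-truth gradient computed from the input by self.get_grad.get_gradient_tf. That helper is not shown in this listing; one plausible sketch of such an image-gradient op, using Sobel filters (an assumption for illustration, not the project's own implementation):

import tensorflow as tf

def get_gradient_tf(images):
    # tf.image.sobel_edges returns [N, H, W, C, 2] holding (dy, dx) per channel;
    # reduce the last axis to a per-channel gradient magnitude map.
    sobel = tf.image.sobel_edges(images)
    return tf.sqrt(tf.reduce_sum(tf.square(sobel), axis=-1) + 1e-8)
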
def _create_mask(config, edge=None):
    # num_class = 1  # only for person.
    size = config.IMG_SHAPES[:2]
    if config.FREE_FORM:
        mask = free_form_mask(np, 1, size)
    else:
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, 1, config, np)
    if edge is not None:
        edge = mask * (edge[None, :1] / 255.)
        mask = np.concatenate([mask, edge], axis=1)

    return mask[0]
    def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
        """
        """
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        # combine a random rectangular mask with a random brush-stroke mask
        regular_mask = bbox2mask(bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_pos = batch_data / 127.5 - 1.
        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=True,
            training=False, padding=FLAGS.padding)
        batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # global image visualization
        if FLAGS.guided:
            viz_img = [
                batch_pos,
                batch_incomplete + edge,
                batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.image.resize_bilinear))
        images_summary(
            tf.concat(viz_img, axis=2),
            name+'_raw_incomplete_complete', FLAGS.viz_max_out)
        return batch_complete
        def evaluation(trainer):
            it = trainer.updater.get_iterator('test')
            batch_data = it.next()
            batch_data = self.xp.array(batch_data)

            # generate mask, 1 represents masked point
            bbox = (config.HEIGHT // 2, config.WIDTH // 2, config.HEIGHT,
                    config.WIDTH)
            config.MAX_DELTA_HEIGHT = 0
            config.MAX_DELTA_WIDTH = 0
            if bbox is None:
                bbox = random_bbox(config)
            mask = bbox2mask(bbox, batch_data.shape[0], config, self.xp)
            batch_pos = batch_data / 127.5 - 1.
            batch_incomplete = batch_pos * (1. - mask)
            # inpaint
            x1, x2, offset_flow = self.inpaintnet(batch_incomplete, mask,
                                                  config)
            if config.PRETRAIN_COARSE_NETWORK:
                batch_predicted = x1
            else:
                batch_predicted = x2
            # apply mask and reconstruct
            batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                          mask)
            # visualization
            viz_img = [batch_pos, batch_incomplete + mask, batch_complete.data]
            if offset_flow is not None:
                viz_img.append(F.unpooling_2d(offset_flow, 4).data)
            batch_w = len(viz_img)
            batch_h = viz_img[0].shape[0]
            viz_img = self.xp.concatenate(viz_img, axis=0)
            viz_img = batch_postprocess_images(viz_img, batch_w, batch_h)
            viz_img = cuda.to_cpu(viz_img)
            Image.fromarray(viz_img).save(test_image_folder + "/iter_" +
                                          str(trainer.updater.iteration) +
                                          ".jpg")
Example #9
    def build_graph_with_losses(
            self, FLAGS, batch_data, training=True, summary=False,
            reuse=False):
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(FLAGS)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=reuse, training=training,
            padding=FLAGS.padding)
        batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches
        losses['ae_loss'] = FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x1))
        losses['ae_loss'] += FLAGS.l1_loss_alpha * tf.reduce_mean(tf.abs(batch_pos - x2))
        if summary:
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            if FLAGS.guided:
                viz_img = [
                    batch_pos,
                    batch_incomplete + edge,
                    batch_complete]
            else:
                viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_bilinear))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', FLAGS.viz_max_out)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        if FLAGS.gan_with_mask:
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [FLAGS.batch_size*2, 1, 1, 1])], axis=3)
        if FLAGS.guided:
            # conditional GANs
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(edge, [2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if FLAGS.gan == 'sngan':
            pos_neg = self.build_gan_discriminator(batch_pos_neg, training=training, reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
        else:
            raise NotImplementedError('{} not implemented.'.format(FLAGS.gan))
        if summary:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            # gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        losses['g_loss'] = FLAGS.gan_loss_alpha * losses['g_loss']
        if FLAGS.ae_loss:
            losses['g_loss'] += losses['ae_loss']
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
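
build_graph_with_losses returns the variable collections and the loss dictionary rather than training ops, leaving optimizer wiring to the caller. A minimal TF1-style sketch of that wiring (the optimizer type, learning rate and beta values are assumptions, not the repository's settings):

import tensorflow as tf

# g_vars, d_vars, losses come from build_graph_with_losses(...)
g_optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.5, beta2=0.999)
d_optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.5, beta2=0.999)
g_train_op = g_optimizer.minimize(losses['g_loss'], var_list=g_vars)
d_train_op = d_optimizer.minimize(losses['d_loss'], var_list=d_vars)
# a typical loop runs d_train_op (once or several times) and then g_train_op per iteration
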
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False):
        batch_pos = batch_data / 127.5 - 1.

        mask = bbox2mask(config, name='mask_c')
        batch_incomplete = batch_pos * (1. - mask)
        x1, x2, offset_flow = self.build_inpaint_net(batch_incomplete,
                                                     mask,
                                                     config,
                                                     reuse=reuse,
                                                     training=training,
                                                     padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        local_patch_batch_pos = local_patch(batch_pos, mask)
        local_patch_batch_predicted = local_patch(batch_predicted, mask)
        local_patch_x1 = local_patch(x1, mask)
        local_patch_x2 = local_patch(x2, mask)
        local_patch_batch_complete = local_patch(batch_complete, mask)
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_batch_pos -
                   local_patch_x1))  #*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(
                tf.abs(local_patch_batch_pos -
                       local_patch_x2))  #*spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(
                tf.abs(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow,
                           scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        local_patch_batch_pos_neg = tf.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)

        if config.GAN == 'snpatch_gan':
            pos_neg = self.build_SNGAN_discriminator(local_patch_batch_pos_neg,
                                                     training=training,
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            sn_gloss, sn_dloss = self.gan_hinge_loss(pos,
                                                     neg,
                                                     name="gan/hinge_loss")
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * sn_gloss
            losses['d_loss'] = sn_dloss
            # interpolate between real and completed local patches for the gradient penalty
            interpolates = random_interpolates(local_patch_batch_pos,
                                               local_patch_batch_complete)
            dout = self.build_SNGAN_discriminator(interpolates, reuse=True)
            penalty = gradients_penalty(interpolates, dout, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(sn_gloss,
                                  batch_predicted,
                                  name='g_loss_local')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', sn_dloss)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
    def build_graph_with_losses(self, batch_data, config, training=True,
                                summary=False, reuse=False):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, config, name='mask_c')
        batch_incomplete = batch_pos*(1.-mask)
        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete, mask, config, reuse=reuse, training=training,
            padding=config.PADDING)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # local patches
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_batch_predicted = local_patch(batch_predicted, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        l1_alpha = config.COARSE_L1_ALPHA
        losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2)*spatial_discounting_mask(config))
        losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-mask))
        losses['ae_loss'] /= tf.reduce_mean(1.-mask)
        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            scalar_summary('losses/ae_loss', losses['ae_loss'])
            viz_img = [batch_pos, batch_incomplete, batch_complete]
            if offset_flow is not None:
                viz_img.append(
                    resize(offset_flow, scale=4,
                           func=tf.image.resize_nearest_neighbor))
            images_summary(
                tf.concat(viz_img, axis=2),
                'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate gan
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(local_patch_batch_pos_neg, batch_pos_neg, training=training, reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global, name='gan/global_gan')
            losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos, batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local, dout_local, mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global, dout_global, mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
                gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global', penalty_global)

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
            gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net')
        d_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        return g_vars, d_vars, losses
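
The wgan_gp branch above relies on three small helpers: gan_wgan_loss, random_interpolates and gradients_penalty. Minimal sketches of what they compute, assuming NHWC tensors (the repository's versions may differ in reductions and extra arguments):

import tensorflow as tf

def gan_wgan_loss(pos, neg, name='gan_wgan_loss'):
    # Critic maximizes E[D(real)] - E[D(fake)]; generator maximizes E[D(fake)].
    d_loss = tf.reduce_mean(neg) - tf.reduce_mean(pos)
    g_loss = -tf.reduce_mean(neg)
    return g_loss, d_loss

def random_interpolates(x, y):
    # One random interpolation coefficient per example, broadcast over H, W, C.
    alpha = tf.random_uniform([tf.shape(x)[0], 1, 1, 1], 0., 1.)
    return x + alpha * (y - x)

def gradients_penalty(interpolates, dout, mask=None, norm=1.):
    # Gradient penalty (||grad D|| - 1)^2, optionally weighted by the hole mask.
    grad = tf.gradients(dout, interpolates)[0]
    if mask is None:
        mask = tf.ones_like(grad)
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grad) * mask, axis=[1, 2, 3]) + 1e-10)
    return tf.reduce_mean(tf.square(slopes - norm))
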
Example #12
    def build_graph_with_losses(self,
                                batch_data,
                                config,
                                training=True,
                                summary=False,
                                reuse=False,
                                exclusionmask=None,
                                mask=None):
        batch_pos = batch_data / 127.5 - 1.
        # generate mask, 1 represents masked point
        use_local_patch = False
        if mask is None:
            bbox = random_bbox(config)
            mask = bbox2mask(bbox, config, name='mask_c')
            if config.GAN == 'wgan_gp':
                use_local_patch = True
        else:
            #bbox = (0, 0, config.IMG_SHAPES[0], config.IMG_SHAPES[1])
            mask = tf.cast(tf.less(0.5, mask[:, :, :, 0:1]), tf.float32)
            if config.INVERT_MASK:
                mask = 1 - mask

        batch_incomplete = batch_pos * (1. - mask)

        if exclusionmask is not None:
            if config.INVERT_EXCLUSIONMASK:
                loss_mask = tf.cast(tf.less(0.5, exclusionmask[:, :, :, 0:1]),
                                    tf.float32)  #keep white parts
            else:
                loss_mask = tf.cast(tf.less(exclusionmask[:, :, :, 0:1], 0.5),
                                    tf.float32)  #keep black parts
            batch_incomplete = batch_incomplete * loss_mask
            batch_pos = batch_pos * loss_mask

        x1, x2, offset_flow = self.build_inpaint_net(
            batch_incomplete,
            mask,
            config,
            reuse=reuse,
            training=training,
            padding=config.PADDING,
            exclusionmask=exclusionmask)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
            logger.info('Set batch_predicted to x1.')
        else:
            batch_predicted = x2
            logger.info('Set batch_predicted to x2.')
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * mask + batch_incomplete * (1. -
                                                                      mask)
        if exclusionmask is not None:
            batch_complete = batch_complete * loss_mask

        l1_alpha = config.COARSE_L1_ALPHA
        # local patches
        if use_local_patch:
            local_patch_batch_pos = local_patch(batch_pos, bbox)
            local_patch_batch_predicted = local_patch(batch_predicted, bbox)
            local_patch_x1 = local_patch(x1, bbox)
            local_patch_x2 = local_patch(x2, bbox)
            local_patch_batch_complete = local_patch(batch_complete, bbox)
            local_patch_mask = local_patch(mask, bbox)

            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(local_patch_batch_pos - local_patch_x1) *
                spatial_discounting_mask(config) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(local_patch_batch_pos - local_patch_x2) *
                    spatial_discounting_mask(config) *
                    (loss_mask if exclusionmask is not None else 1))

            losses['ae_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) * (1. - mask) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['ae_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) * (1. - mask) *
                    (loss_mask if exclusionmask is not None else 1))
            losses['ae_loss'] /= tf.reduce_mean(1. - mask)
        else:
            losses['l1_loss'] = l1_alpha * tf.reduce_mean(
                tf.abs(batch_pos - x1) *
                (loss_mask if exclusionmask is not None else 1))
            if not config.PRETRAIN_COARSE_NETWORK:
                losses['l1_loss'] += tf.reduce_mean(
                    tf.abs(batch_pos - x2) *
                    (loss_mask if exclusionmask is not None else 1))

        if summary:
            scalar_summary('losses/l1_loss', losses['l1_loss'])
            if use_local_patch:
                scalar_summary('losses/ae_loss', losses['ae_loss'])
            img_size = [dim for dim in batch_incomplete.shape]
            img_size[2] = 5
            border = tf.zeros(tf.TensorShape(img_size))
            viz_img = [
                batch_pos, border, batch_incomplete, border, batch_complete,
                border
            ]
            if not config.PRETRAIN_COARSE_NETWORK:
                batch_complete_coarse = x1 * mask + batch_incomplete * (1. -
                                                                        mask)
                viz_img.append(batch_complete_coarse)
                viz_img.append(border)
            if offset_flow is not None:
                scale = 2 << len(offset_flow)
                for flow in offset_flow:
                    viz_img.append(
                        resize(flow,
                               scale=scale,
                               func=tf.image.resize_nearest_neighbor))
                    viz_img.append(border)
                    scale >>= 1
            images_summary(tf.concat(viz_img, axis=2),
                           'raw_incomplete_predicted_complete',
                           config.VIZ_MAX_OUT)

        # gan
        batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        if config.GAN_WITH_MASK:
            batch_pos_neg = tf.concat([
                batch_pos_neg,
                tf.tile(mask, [config.BATCH_SIZE * 2, 1, 1, 1])
            ],
                                      axis=3)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            if not use_local_patch:
                raise Exception('wgan_gp requires global and local patch')
            local_patch_batch_pos_neg = tf.concat(
                [local_patch_batch_pos, local_patch_batch_complete], 0)
            # separate gan
            pos_neg_local, pos_neg_global = self.build_wgan_discriminator(
                local_patch_batch_pos_neg,
                batch_pos_neg,
                training=training,
                reuse=reuse)
            pos_local, neg_local = tf.split(pos_neg_local, 2)
            pos_global, neg_global = tf.split(pos_neg_global, 2)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local,
                                                       neg_local,
                                                       name='gan/local_gan')
            g_loss_global, d_loss_global = gan_wgan_loss(pos_global,
                                                         neg_global,
                                                         name='gan/global_gan')
            losses['g_loss'] = (config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global +
                                g_loss_local)
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.build_wgan_discriminator(
                interpolates_local, interpolates_global, reuse=True)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss_local,
                                  batch_predicted,
                                  name='g_loss_local')
                gradients_summary(g_loss_global,
                                  batch_predicted,
                                  name='g_loss_global')
                scalar_summary('convergence/d_loss', losses['d_loss'])
                scalar_summary('convergence/local_d_loss', d_loss_local)
                scalar_summary('convergence/global_d_loss', d_loss_global)
                scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
                scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
                scalar_summary('gan_wgan_loss/gp_penalty_global',
                               penalty_global)
        elif config.GAN == 'sngan':
            if use_local_patch:
                raise Exception(
                    'sngan incompatible with global and local patch')
            pos_neg = self.build_sngan_discriminator(batch_pos_neg,
                                                     name='discriminator',
                                                     reuse=reuse)
            pos, neg = tf.split(pos_neg, 2)
            g_loss, d_loss = gan_hinge_loss(pos, neg)
            losses['g_loss'] = g_loss
            losses['d_loss'] = d_loss
            if summary and not config.PRETRAIN_COARSE_NETWORK:
                gradients_summary(g_loss, batch_predicted, name='g_loss')
                scalar_summary('convergence/d_loss', losses['d_loss'])
        else:
            losses['g_loss'] = 0

        if summary and not config.PRETRAIN_COARSE_NETWORK:
            # summary the magnitude of gradients from different losses w.r.t. predicted image
            gradients_summary(losses['g_loss'], batch_predicted, name='g_loss')
            gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1')
            gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2')
            gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1')
            gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
            if use_local_patch:
                gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
                gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
        logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA)
        if config.AE_LOSS and use_local_patch:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
            logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA)
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'inpaint_net')
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'discriminator')
        return g_vars, d_vars, losses
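
Two more helpers recur in the loss graphs above: local_patch, which crops the rectangular hole region out of a full tensor, and spatial_discounting_mask, which down-weights the reconstruction error toward the center of the hole. Minimal sketches under the (top, left, height, width) bbox convention used above (the gamma value is an assumption; the repository's versions read these settings from config):

import numpy as np
import tensorflow as tf

def local_patch(x, bbox):
    # Crop the hole region (top, left, height, width) from an NHWC tensor.
    return tf.image.crop_to_bounding_box(x, bbox[0], bbox[1], bbox[2], bbox[3])

def spatial_discounting_mask(height, width, gamma=0.9):
    # Weights are ~1 at the hole border and decay toward its center, so pixels
    # far from known content contribute less to the l1 loss.
    values = np.ones((height, width), dtype=np.float32)
    for i in range(height):
        for j in range(width):
            values[i, j] = max(gamma ** min(i, height - 1 - i),
                               gamma ** min(j, width - 1 - j))
    return tf.constant(values.reshape(1, height, width, 1))
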
    def get_loss(self, batch_data):
        config = self.config
        batch_pos = batch_data / 127.5 - 1
        bbox = random_bbox(config)
        mask = bbox2mask(bbox, batch_data.shape[0], config, self.xp)
        batch_incomplete = batch_pos * (1 - mask)
        x1, x2, offset_flow = self.inpaintnet(batch_incomplete, mask, config)
        if config.PRETRAIN_COARSE_NETWORK:
            batch_predicted = x1
        else:
            batch_predicted = x2
        losses = {}
        # apply mask and complete image
        batch_complete = batch_predicted * mask + batch_incomplete * (1 - mask)
        # local patches
        local_patch_batch_pos = local_patch(batch_pos, bbox)
        local_patch_x1 = local_patch(x1, bbox)
        local_patch_x2 = local_patch(x2, bbox)
        local_patch_batch_complete = local_patch(batch_complete, bbox)
        local_patch_mask = local_patch(mask, bbox)
        l1_alpha = config.COARSE_L1_ALPHA
        losses["l1_loss"] = l1_alpha * F.mean(
            F.absolute(local_patch_batch_pos - local_patch_x1) *
            spatial_discounting_mask(config, self.xp))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['l1_loss'] += F.mean(
                F.absolute(local_patch_batch_pos - local_patch_x2) *
                spatial_discounting_mask(config, self.xp))
        losses['ae_loss'] = l1_alpha * F.mean(
            F.absolute(batch_pos - x1) * (1. - mask))
        if not config.PRETRAIN_COARSE_NETWORK:
            losses['ae_loss'] += F.mean(
                F.absolute(batch_pos - x2) * (1. - mask))
        losses['ae_loss'] /= F.mean(1. - mask)

        # gan
        batch_pos_neg = F.concat([batch_pos, batch_complete], axis=0)
        # local deterministic patch
        local_patch_batch_pos_neg = F.concat(
            [local_patch_batch_pos, local_patch_batch_complete], 0)
        if config.GAN_WITH_MASK:
            batch_pos_neg = F.concat([batch_pos_neg, mask], axis=1)
        # wgan with gradient penalty
        if config.GAN == 'wgan_gp':
            # separate gan
            pos_neg_local, pos_neg_global = self.discriminator(
                local_patch_batch_pos_neg, batch_pos_neg)
            pos_local, neg_local = F.split_axis(pos_neg_local, 2, axis=0)
            pos_global, neg_global = F.split_axis(pos_neg_global, 2, axis=0)
            # wgan loss
            g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local)
            g_loss_global, d_loss_global = gan_wgan_loss(
                pos_global, neg_global)
            losses['g_loss'] = (config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global +
                                g_loss_local)
            losses['d_loss'] = d_loss_global + d_loss_local
            # gp
            interpolates_local = random_interpolates(
                local_patch_batch_pos, local_patch_batch_complete)
            interpolates_global = random_interpolates(batch_pos,
                                                      batch_complete)
            dout_local, dout_global = self.discriminator(
                interpolates_local, interpolates_global)
            # apply penalty
            penalty_local = gradients_penalty(interpolates_local,
                                              dout_local,
                                              mask=local_patch_mask)
            penalty_global = gradients_penalty(interpolates_global,
                                               dout_global,
                                               mask=mask)
            losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local +
                                                         penalty_global)
            losses['d_loss'] = losses['d_loss'] + losses['gp_loss']

        if config.PRETRAIN_COARSE_NETWORK:
            losses['g_loss'] = 0
        else:
            losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
        losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
        if config.AE_LOSS:
            losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
        return losses
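
The Chainer variant above works on NCHW arrays (hence F.concat(..., axis=1) for the mask channel), so its local_patch slices the spatial axes last. A minimal NumPy/CuPy-style sketch, assuming the same (top, left, height, width) bbox convention:

def local_patch_nchw(x, bbox):
    # Crop the hole region from an NCHW array (works for numpy or cupy arrays).
    t, l, h, w = bbox
    return x[:, :, t:t + h, l:l + w]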