Code example #1
File: celebaHQ.py  Project: miwaliu/IRAE_pytorch
    def __getitem__(self, idx):
        filename = self.data[idx]
        img = Image.open(os.path.join(self.root, self.img_folder,
                                      filename))  # loads in RGB mode
        if self.transform is not None:
            img = self.transform(img)

        if self.args.denoise:
            corrupted = addNoise(img,
                                 sigma=self.args.noise_level,
                                 mode=self.args.noise_mode)  # AWGN
            return img, corrupted

        elif self.args.inpainting:
            if self.args.inpainting_mode == 'center':  # center hole inpainting
                bboxes = random_bbox(img, args=self.args)
                corrupted, mask = mask_image(img, bboxes)
                return img, corrupted, mask
            elif self.args.inpainting_mode == 'irregular':  # irregular hole inpainting
                if self.args.archi == 'PartialConv':
                    while True:
                        # resample until the hole covers 10-50% of the mask
                        mask = Image.open(self.mask_paths[random.randint(
                            0, self.N_mask - 1)])
                        mask = self.mask_transform(mask.convert('RGB'))
                        percent = 100.0 * float((mask == 0).sum()) / mask.numel()
                        if 10 <= percent < 50:
                            break
                    corrupted = img * mask  # optionally + (1. - mask) to fill holes with white
                    return img, corrupted, mask
                else:
                    while True:
                        # keep a single channel as a binary mask and resample
                        # until the hole covers 10-50% of it
                        mask = Image.open(self.mask_paths[random.randint(
                            0, self.N_mask - 1)])
                        mask = self.mask_transform(mask.convert('RGB'))
                        mask = mask[0].unsqueeze(0).byte()
                        percent = 100.0 * float((mask == 0).sum()) / mask.numel()
                        if 10 <= percent < 50:
                            break
                    corrupted = img * mask  # optionally + (1. - mask) to fill holes with white
                    return img, corrupted, mask
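
The `addNoise` helper referenced above is defined elsewhere in the repository. For reference, a minimal AWGN sketch might look like the following; it assumes the input is a float tensor in [0, 1] and that `sigma` is given on the 0-255 scale (both assumptions, not confirmed by the source).

import torch

def addNoise(img, sigma=25, mode='gaussian'):
    # Hypothetical sketch, not the repository's implementation:
    # additive white Gaussian noise on a [0, 1] float tensor,
    # with sigma specified on the 0-255 scale.
    if mode == 'gaussian':
        noise = torch.randn_like(img) * (sigma / 255.0)
        return (img + noise).clamp(0.0, 1.0)
    raise ValueError('unsupported noise mode: %s' % mode)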
Code example #2
File: bird.py  Project: miwaliu/IRAE_pytorch
    def __getitem__(self, idx):
        file_name = self.file_list[self.data_idx[idx] - 1]
        img = Image.open(os.path.join(self.root, self.img_folder,
                                      file_name))  # loads in RGB mode
        if self.transform is not None:
            img = self.transform(img)
        if self.args.denoise:
            corrupted = addNoise(img,
                                 sigma=self.args.noise_level,
                                 mode=self.args.noise_mode)  # AWGN
        elif self.args.inpainting:
            bboxes = random_bbox(img)
            corrupted, mask = mask_image(img, bboxes)  # hole inpainting; mask is unused here
        else:
            # without this branch, `corrupted` would be unbound below
            raise ValueError('one of args.denoise or args.inpainting must be set')

        return img, corrupted
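
For context, a typical training loop would consume the (clean, corrupted) pairs this dataset yields. A hypothetical usage sketch follows; `dataset` and `model` are stand-ins, not names from the repository.

import torch.nn.functional as F
from torch.utils.data import DataLoader

# `dataset` is an instance of a dataset class like the one above and
# `model` is some restoration network; both names are assumptions.
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
for img, corrupted in loader:
    restored = model(corrupted)
    loss = F.l1_loss(restored, img)  # reconstruct the clean image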
Code example #3
File: flower.py  Project: miwaliu/IRAE_pytorch
    def __getitem__(self, idx):
        # zero-pad the index to match the dataset's file naming; note that
        # the authors' earlier runs omitted this formatting, which may
        # have caused bugs
        file_name = 'image_%05d.jpg' % self.data_idx[idx]
        img = Image.open(os.path.join(self.root, self.img_folder,
                                      file_name))  # loads in RGB mode
        if self.transform is not None:
            img = self.transform(img)

        if self.args.denoise:
            corrupted = addNoise(img,
                                 sigma=self.args.noise_level,
                                 mode=self.args.noise_mode)  # AWGN
        elif self.args.inpainting:
            bboxes = random_bbox(img)
            corrupted, mask = mask_image(img, bboxes)  # hole inpainting; mask is unused here
        else:
            # without this branch, `corrupted` would be unbound below
            raise ValueError('one of args.denoise or args.inpainting must be set')

        return img, corrupted
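
The `random_bbox` and `mask_image` helpers are likewise defined elsewhere in the repository. A minimal sketch of center-hole masking, assuming CHW float tensors and a (top, left, height, width) bbox convention (all assumptions):

import torch

def random_bbox(img, args=None):
    # Hypothetical sketch: a half-size box centered in the image.
    # The real helper presumably randomizes size/position from args.
    _, h, w = img.shape
    hh, ww = h // 2, w // 2
    return (h - hh) // 2, (w - ww) // 2, hh, ww

def mask_image(img, bbox):
    # Hypothetical sketch: zero out the box and return the corrupted
    # image plus a binary mask (1 = known pixel, 0 = hole).
    top, left, hh, ww = bbox
    mask = torch.ones_like(img[:1])
    mask[:, top:top + hh, left:left + ww] = 0.0
    return img * mask, mask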
Code example #4
File: model_in.py  Project: fengjiran/pycharmprojs
    def build_infer_graph(self, batch_data, cfg, bbox=None, name='val'):
        # make the mask deterministic at inference: no random delta
        cfg['max_delta_height'] = 0
        cfg['max_delta_width'] = 0

        if bbox is None:
            bbox = random_bbox(cfg)
        mask = bbox2mask(bbox, cfg)

        batch_pos = batch_data
        batch_incomplete = batch_pos * (1. - mask)
        ones_x = tf.ones_like(batch_incomplete)[:, :, :, 0:1]
        coarse_network_input = tf.concat([batch_incomplete, ones_x, mask],
                                         axis=3)

        # inpaint: coarse stage, then paste the prediction into the hole
        coarse_output = self.coarse_network(coarse_network_input, reuse=True)
        batch_complete_coarse = coarse_output * mask + batch_incomplete * (
            1. - mask)

        # refine stage
        refine_network_input = tf.concat([batch_complete_coarse, ones_x, mask],
                                         axis=3)
        refine_output = self.refine_network(refine_network_input, reuse=True)
        batch_complete_refine = refine_output * mask + batch_incomplete * (
            1. - mask)

        # global image visualization
        visual_img = [
            batch_pos, batch_incomplete, batch_complete_coarse,
            batch_complete_refine
        ]
        images_summary(tf.concat(visual_img, axis=2),
                       name + '_raw_incomplete_coarse_refine', 10)

        return batch_complete_coarse, batch_complete_refine
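
In this TensorFlow code, `random_bbox` takes the config rather than an image, and `bbox2mask` rasterizes the box into a [1, H, W, 1] tensor that is 1 inside the hole (hence the `1. - mask` multiplications). A hypothetical sketch, assuming the bbox is a tuple of Python ints and that cfg stores the image size under 'img_height'/'img_width' (both assumptions):

import numpy as np
import tensorflow as tf

def bbox2mask(bbox, cfg):
    # Hypothetical sketch: 1 inside the hole, 0 elsewhere.
    top, left, height, width = bbox
    mask = np.zeros((1, cfg['img_height'], cfg['img_width'], 1), np.float32)
    mask[:, top:top + height, left:left + width, :] = 1.
    return tf.constant(mask)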
Code example #5
File: model_in.py  Project: fengjiran/pycharmprojs
    def build_graph_with_losses(self,
                                batch_data,
                                cfg,
                                summary=True,
                                reuse=None):
        # batch_data is assumed to be pre-normalized (e.g. batch_data / 127.5 - 1)
        batch_pos = batch_data
        bbox = random_bbox(cfg)
        mask = bbox2mask(bbox, cfg)

        batch_incomplete = batch_pos * (1. - mask)
        ones_x = tf.ones_like(batch_incomplete)[:, :, :, 0:1]
        coarse_network_input = tf.concat([batch_incomplete, ones_x, mask],
                                         axis=3)

        coarse_output = self.coarse_network(coarse_network_input, reuse)
        batch_complete_coarse = coarse_output * mask + batch_pos * (1. - mask)

        refine_network_input = tf.concat([batch_complete_coarse, ones_x, mask],
                                         axis=3)
        refine_output = self.refine_network(refine_network_input, reuse)
        batch_complete_refine = refine_output * mask + batch_pos * (1. - mask)

        losses = {}

        # local patches
        local_patch_pos = local_patch(batch_pos, bbox)
        local_patch_coarse = local_patch(coarse_output, bbox)
        local_patch_refine = local_patch(refine_output, bbox)
        local_patch_mask = local_patch(mask, bbox)

        l1_alpha = cfg['coarse_l1_alpha']
        # patch l1 losses, weighted by a spatial discounting mask that
        # down-weights pixels far from the hole boundary
        losses['coarse_l1_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(local_patch_pos - local_patch_coarse) *
            spatial_discounting_mask(cfg))
        losses['coarse_ae_loss'] = l1_alpha * tf.reduce_mean(
            tf.abs(batch_pos - coarse_output) * (1. - mask))

        losses['refine_l1_loss'] = losses['coarse_l1_loss'] + \
            tf.reduce_mean(tf.abs(local_patch_pos - local_patch_refine) *
                           spatial_discounting_mask(cfg))
        losses['refine_ae_loss'] = losses['coarse_ae_loss'] + \
            tf.reduce_mean(tf.abs(batch_pos - refine_output) * (1. - mask))

        # average the AE losses over the known region only, so their scale
        # does not depend on the hole size
        losses['coarse_ae_loss'] /= tf.reduce_mean(1. - mask)
        losses['refine_ae_loss'] /= tf.reduce_mean(1. - mask)

        # wgan
        # global discriminator patch
        batch_pos_neg = tf.concat([batch_pos, batch_complete_refine], axis=0)

        # local discriminator patch
        local_patch_pos_neg = tf.concat([local_patch_pos, local_patch_refine],
                                        axis=0)

        # wgan with gradient penalty
        pos_neg_global, pos_neg_local = self.build_wgan_discriminator(
            batch_pos_neg, local_patch_pos_neg, reuse)

        pos_global, neg_global = tf.split(pos_neg_global, 2)
        pos_local, neg_local = tf.split(pos_neg_local, 2)

        # wgan loss
        g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global)
        g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local)

        losses['refine_g_loss'] = cfg['global_wgan_loss_alpha'] * g_loss_global + \
            cfg['local_wgan_loss_alpha'] * g_loss_local

        losses['refine_d_loss_global'] = d_loss_global
        losses['refine_d_loss_local'] = d_loss_local
        losses['refine_d_loss'] = (losses['refine_d_loss_global'] * 1.4 +
                                   losses['refine_d_loss_local'] * 1.4)

        # gradient penalty
        interpolates_global = random_interpolates(batch_pos,
                                                  batch_complete_refine)
        interpolates_local = random_interpolates(local_patch_pos,
                                                 local_patch_refine)
        dout_global, dout_local = self.build_wgan_discriminator(
            interpolates_global, interpolates_local, reuse=True)

        # apply a Lipschitz penalty (used here in place of the standard
        # WGAN-GP gradient penalty)
        penalty_global = lipschitz_penalty(interpolates_global, dout_global)
        penalty_local = lipschitz_penalty(interpolates_local, dout_local)

        losses['gp_loss'] = cfg['wgan_gp_lambda'] * (penalty_global +
                                                     penalty_local)
        losses['refine_d_loss'] += losses['gp_loss']

        g_vars_coarse = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          'coarse')
        g_vars_refine = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          'refine')
        g_vars = g_vars_coarse + g_vars_refine
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   'wgan_discriminator')

        if summary:
            # stage 1: reconstruction losses
            tf.summary.scalar(
                'rec_loss/coarse_rec_loss',
                losses['coarse_l1_loss'] + losses['coarse_ae_loss'])
            tf.summary.scalar(
                'rec_loss/refine_rec_loss',
                losses['refine_l1_loss'] + losses['refine_ae_loss'])

            visual_img = [
                batch_pos, batch_incomplete, batch_complete_coarse,
                batch_complete_refine
            ]
            visual_img = tf.concat(visual_img, axis=2)
            images_summary(visual_img, 'raw_incomplete_coarse_refine', 4)

            # stage 2: adversarial losses
            gradients_summary(g_loss_global,
                              refine_output,
                              name='g_loss_global')
            gradients_summary(g_loss_local, refine_output, name='g_loss_local')

            tf.summary.scalar('convergence/refine_d_loss',
                              losses['refine_d_loss'])
            tf.summary.scalar('convergence/local_d_loss', d_loss_local)
            tf.summary.scalar('convergence/global_d_loss', d_loss_global)

            tf.summary.scalar('gradient_penalty/gp_loss', losses['gp_loss'])
            tf.summary.scalar('gradient_penalty/gp_penalty_local',
                              penalty_local)
            tf.summary.scalar('gradient_penalty/gp_penalty_global',
                              penalty_global)

            # magnitude of gradients from each loss w.r.t. the predicted images
            gradients_summary(losses['coarse_l1_loss'] +
                              losses['coarse_ae_loss'],
                              coarse_output,
                              name='rec_loss_grad_to_coarse')
            gradients_summary(losses['refine_l1_loss'] +
                              losses['refine_ae_loss'] +
                              losses['refine_g_loss'],
                              refine_output,
                              name='rec_loss_grad_to_refine')
            gradients_summary(losses['coarse_l1_loss'],
                              coarse_output,
                              name='l1_loss_grad_to_coarse')
            gradients_summary(losses['refine_l1_loss'],
                              refine_output,
                              name='l1_loss_grad_to_refine')
            gradients_summary(losses['coarse_ae_loss'],
                              coarse_output,
                              name='ae_loss_grad_to_coarse')
            gradients_summary(losses['refine_ae_loss'],
                              refine_output,
                              name='ae_loss_grad_to_refine')

        return g_vars, g_vars_coarse, d_vars, losses
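
The `lipschitz_penalty` helper is not shown on this page. If it follows the one-sided WGAN-LP formulation, penalizing discriminator gradient norms above 1 at the interpolated points, it could look like the sketch below; this is an assumption about the repository's helper, not a confirmed implementation.

import tensorflow as tf

def lipschitz_penalty(interpolates, dout):
    # Hypothetical TF1-style sketch:
    # E[ max(||grad D(x_hat)|| - 1, 0)^2 ] over the interpolated batch.
    grads = tf.gradients(dout, [interpolates])[0]
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-10)
    return tf.reduce_mean(tf.square(tf.maximum(norm - 1., 0.)))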