# Example 1
    def predict(self):
        """Run one inference pass of the attention GAN.

        Estimates the real condition with the discriminator, translates the
        input image to the desired condition, cycles it back, and returns
        visualization images plus the estimated condition.

        Returns:
            (imgs, data): OrderedDict of numpy images and OrderedDict with
            the discriminator-estimated real condition as strings.
        """
        # inference only: no autograd graph needed (replaces the deprecated
        # Variable(..., volatile=True), consistent with the other test-time
        # forward passes in this file)
        with torch.no_grad():
            real_img = Variable(self._input_real_img)
            # estimate the real condition from the discriminator instead of
            # reading it from the input
            # real_cond = Variable(self._input_real_cond)
            _, real_cond = self._D.forward(real_img)
            real_cond = real_cond.unsqueeze(0)
            desired_cond = Variable(self._input_desired_cond)

            # generate fake image: attention mask blends input and RGB output
            fake_imgs, fake_img_mask = self._G.forward(real_img, desired_cond)
            fake_img_mask = self._do_if_necessary_saturate_mask(
                fake_img_mask, saturate=self._opt.do_saturate_mask)
            fake_imgs_masked = fake_img_mask * real_img + (1 - fake_img_mask) * fake_imgs

            # cycle back: reconstruct the original image under the real condition
            rec_real_img_rgb, rec_real_img_mask = self._G.forward(fake_imgs_masked, real_cond)
            rec_real_img_mask = self._do_if_necessary_saturate_mask(rec_real_img_mask,
                                                                    saturate=self._opt.do_saturate_mask)
            rec_real_imgs = rec_real_img_mask * fake_imgs_masked + (1 - rec_real_img_mask) * rec_real_img_rgb

        # mask is visualized as-is; max-normalization was disabled, so the
        # dead mask-max computation has been removed
        fake_img_mask_norm = fake_img_mask

        # convert tensors to numpy images
        im_real_img = util.tensor2im(real_img.data)
        im_fake_imgs = util.tensor2im(fake_imgs.data)
        im_fake_img_mask_norm = util.tensor2maskim(fake_img_mask_norm.data)
        im_fake_imgs_masked = util.tensor2im(fake_imgs_masked.data)
        im_rec_imgs = util.tensor2im(rec_real_img_rgb.data)
        im_rec_img_mask_norm = util.tensor2maskim(rec_real_img_mask.data)
        im_rec_imgs_masked = util.tensor2im(rec_real_imgs.data)
        im_concat_img = np.concatenate([im_real_img, im_fake_imgs_masked, im_fake_img_mask_norm, im_fake_imgs,
                                        im_rec_imgs, im_rec_img_mask_norm, im_rec_imgs_masked],
                                       1)

        imgs = OrderedDict([('real_img', im_real_img),
                            ('fake_imgs', im_fake_imgs),
                            ('fake_img_mask', im_fake_img_mask_norm),
                            ('fake_imgs_masked', im_fake_imgs_masked),
                            ('concat', im_concat_img)
                            ])

        data = OrderedDict([('real_cond', real_cond.data[0,...].cpu().numpy().astype('str'))
                            ])

        return imgs, data
# Example 2
    def _forward_G(self, keep_data_for_visuals):
        """Compute the full generator loss for one batch.

        Builds the attention-masked fake image, cycles it back to the real
        condition, and accumulates adversarial, conditional, cycle, mask-size
        and mask-smoothness losses (each also stored on self).
        """
        # G(Ic1, c2): RGB output plus attention mask, blended with the input
        gen_rgb, gen_mask = self._G.forward(self._real_img, self._desired_cond)
        gen_mask = self._do_if_necessary_saturate_mask(gen_mask, saturate=self._opt.do_saturate_mask)
        gen_masked = gen_mask * self._real_img + (1 - gen_mask) * gen_rgb

        # adversarial + condition losses on the masked fake: D(G(Ic1, c2)*M)
        d_prob, d_cond = self._D.forward(gen_masked)
        self._loss_g_masked_fake = self._compute_loss_D(d_prob, True) * self._opt.lambda_D_prob
        self._loss_g_masked_cond = self._criterion_D_cond(d_cond, self._desired_cond) / self._B * self._opt.lambda_D_cond

        # cycle: G(G(Ic1, c2), c1) should recover the real image
        cyc_rgb, cyc_mask = self._G.forward(gen_masked, self._real_cond)
        cyc_mask = self._do_if_necessary_saturate_mask(cyc_mask, saturate=self._opt.do_saturate_mask)
        cyc_masked = cyc_mask * gen_masked + (1 - cyc_mask) * cyc_rgb

        # cycle-consistency loss against the real image
        self._loss_g_cyc = self._criterion_cycle(cyc_masked, self._real_img) * self._opt.lambda_cyc

        # mask regularizers: small mean activation and spatial smoothness
        self._loss_g_mask_1 = torch.mean(gen_mask) * self._opt.lambda_mask
        self._loss_g_mask_2 = torch.mean(cyc_mask) * self._opt.lambda_mask
        self._loss_g_mask_1_smooth = self._compute_loss_smooth(gen_mask) * self._opt.lambda_mask_smooth
        self._loss_g_mask_2_smooth = self._compute_loss_smooth(cyc_mask) * self._opt.lambda_mask_smooth

        # optionally snapshot tensors as numpy images for visualization
        if keep_data_for_visuals:
            self._vis_real_img = util.tensor2im(self._input_real_img)
            self._vis_fake_img_unmasked = util.tensor2im(gen_rgb.data)
            self._vis_fake_img = util.tensor2im(gen_masked.data)
            self._vis_fake_img_mask = util.tensor2maskim(gen_mask.data)
            self._vis_real_cond = self._input_real_cond.cpu()[0, ...].numpy()
            self._vis_desired_cond = self._input_desired_cond.cpu()[0, ...].numpy()
            self._vis_batch_real_img = util.tensor2im(self._input_real_img, idx=-1)
            self._vis_batch_fake_img_mask = util.tensor2maskim(gen_mask.data, idx=-1)
            self._vis_batch_fake_img = util.tensor2im(gen_masked.data, idx=-1)
            self._vis_rec_img_unmasked = util.tensor2im(cyc_rgb.data)
            self._vis_rec_real_img = util.tensor2im(cyc_masked.data)
            self._vis_rec_real_img_mask = util.tensor2maskim(cyc_mask.data)
            self._vis_batch_rec_real_img = util.tensor2im(cyc_masked.data, idx=-1)

        # total generator objective
        return (self._loss_g_masked_fake + self._loss_g_masked_cond
                + self._loss_g_cyc
                + self._loss_g_mask_1 + self._loss_g_mask_2
                + self._loss_g_mask_1_smooth + self._loss_g_mask_2_smooth)
    def visual_imgs(self, fake_bg, fake_src_imgs, fake_tsf_imgs, fake_masks):
        """Cache numpy renderings of the current outputs for visualization."""
        mid = fake_masks.shape[0] // 2  # middle sample of the batch

        # single-sample views
        self._vis_input = util.tensor2im(self._real_src)
        self._vis_tsf = util.tensor2im(self._input_G_tsf[0, 0:3])
        self._vis_fake_bg = util.tensor2im(fake_bg)
        self._vis_fake_src = util.tensor2im(fake_src_imgs)
        self._vis_fake_tsf = util.tensor2im(fake_tsf_imgs)
        self._vis_mask = util.tensor2maskim(fake_masks[mid])

        # whole-batch grids
        self._vis_batch_real = util.tensor2im(self._real_tsf, idx=-1)
        self._vis_batch_fake = util.tensor2im(fake_tsf_imgs, idx=-1)
    def visual_imgs(self, fake_bg, fake_imgs, fake_color, fake_masks):
        """Store numpy visualizations of the generator outputs on self."""
        self._vis_real_img = util.tensor2im(self._input_real_imgs)
        self._vis_tsf = util.tensor2im(self._input_tsf)

        half = fake_imgs.shape[0] // 2  # middle element of the batch
        self._vis_fake_bg = util.tensor2im(fake_bg.detach())
        self._vis_fake_color = util.tensor2im(fake_color.detach())
        self._vis_fake_img = util.tensor2im(fake_imgs[half].detach())
        self._vis_fake_mask = util.tensor2maskim(fake_masks[half].detach())

        # whole-batch grids
        self._vis_batch_real_img = util.tensor2im(self._input_real_imgs, idx=-1)
        self._vis_batch_fake_img = util.tensor2im(fake_imgs.detach(), idx=-1)
# Example 5
    def forward(self, keep_data_for_visuals=False):
        """Test-time forward pass: inpaint the occluded input image.

        Runs the generator on the occluded image and blends its RGB output
        with the predicted mask. Only acts in evaluation mode; results are
        stored as _vis_* attributes when keep_data_for_visuals is True.
        """
        if not self._is_train:
            im_occ = self._input_img_occ

            # inference only: no autograd graph needed (saves memory and is
            # consistent with the other test-time forward passes in this file)
            with torch.no_grad():
                fake_img, fake_img_mask = self._G.forward(im_occ)
                # mask keeps the visible regions of the input
                fake_img_synthesis = fake_img_mask * im_occ + (
                    1 - fake_img_mask) * fake_img

            if keep_data_for_visuals:
                self._vis_batch_occ_img = util.tensor2im(im_occ, idx=-1)
                self._vis_batch_fake_img = util.tensor2im(fake_img.data,
                                                          idx=-1)
                self._vis_batch_fake_img_mask = util.tensor2maskim(
                    fake_img_mask.data, idx=-1)
                self._vis_batch_fake_synthesis = util.tensor2im(
                    fake_img_synthesis.data, idx=-1)
                self._vis_batch_none_occ_img = util.tensor2im(
                    self._input_img_none_occ, idx=-1)
# Example 6
    def forward_one(self, rest_img: torch.Tensor, pose_vec: torch.Tensor):
        """Generate one posed image from a rest image and a pose vector.

        Test-time only: copies the inputs into the model's input buffers,
        runs the generator under no_grad, and returns a dict of numpy
        images (result / mask / change / concat). Returns None in training
        mode.
        """
        if self._is_train:
            print("only for test")
            return None

        # copy the caller's tensors into the persistent input buffers
        self._input_rest_img.resize_(rest_img.size()).copy_(rest_img)
        self._input_pose_vec.resize_(pose_vec.size()).copy_(pose_vec)
        if len(self._gpu_ids) > 0:
            # move the buffers to the first configured GPU
            self._input_rest_img = self._input_rest_img.cuda(self._gpu_ids[0])
            self._input_pose_vec = self._input_pose_vec.cuda(self._gpu_ids[0])

        with torch.no_grad():
            src = Variable(self._input_rest_img)
            cond = Variable(self._input_pose_vec)

            # attention-based generation: the mask blends source and RGB output
            gen_rgb, gen_mask = self._G.forward(src, cond)
            gen_mask = self._do_if_necessary_saturate_mask(
                gen_mask, saturate=self._opt.do_saturate_mask)
            gen_blend = gen_mask * src + (1 - gen_mask) * gen_rgb

        # convert to numpy images for the caller
        im_change = util.tensor2im(gen_rgb.data)
        im_mask = util.tensor2maskim(gen_mask.data)
        im_result = util.tensor2im(gen_blend.data)
        im_concat = np.concatenate([im_result, im_mask, im_change], 1)
        return {
            "result": im_result,
            "mask": im_mask,
            "change": im_change,
            "concat": im_concat,
        }
# Example 7
    def forward(self, keep_data_for_visuals=False, return_estimates=False):
        """Test-time forward pass of the attention GAN.

        Generates a fake image for the desired condition and cycles it back
        to the real condition. Optionally returns estimate images/metadata
        and/or keeps visualization tensors on self.

        Returns:
            (imgs, data): OrderedDicts of numpy images and metadata when
            return_estimates is True, otherwise (None, None). Returns
            nothing in training mode.
        """
        if not self._is_train:
            # inference only: run the whole generation under no_grad so no
            # autograd graph is built (originally only the Variable wrapping
            # was guarded, so generation still tracked gradients)
            with torch.no_grad():
                real_img = Variable(self._input_real_img)
                real_cond = Variable(self._input_real_cond)
                desired_cond = Variable(self._input_desired_cond)

                # generate fake image: attention mask blends input and RGB
                fake_imgs, fake_img_mask = self._G.forward(real_img, desired_cond)
                fake_img_mask = self._do_if_necessary_saturate_mask(
                    fake_img_mask, saturate=self._opt.do_saturate_mask)
                fake_imgs_masked = fake_img_mask * real_img + (
                    1 - fake_img_mask) * fake_imgs

                # cycle back to the real condition
                rec_real_img_rgb, rec_real_img_mask = self._G.forward(
                    fake_imgs_masked, real_cond)
                rec_real_img_mask = self._do_if_necessary_saturate_mask(
                    rec_real_img_mask, saturate=self._opt.do_saturate_mask)
                rec_real_imgs = rec_real_img_mask * fake_imgs_masked + (
                    1 - rec_real_img_mask) * rec_real_img_rgb

            imgs = None
            data = None
            if return_estimates:
                # mask is visualized as-is; max-normalization was disabled,
                # so the dead mask-max computation has been removed
                fake_img_mask_norm = fake_img_mask

                # convert tensors to numpy images
                im_real_img = util.tensor2im(real_img.data)
                im_fake_imgs = util.tensor2im(fake_imgs.data)
                im_fake_img_mask_norm = util.tensor2maskim(
                    fake_img_mask_norm.data)
                im_fake_imgs_masked = util.tensor2im(fake_imgs_masked.data)
                im_rec_imgs = util.tensor2im(rec_real_img_rgb.data)
                im_rec_img_mask_norm = util.tensor2maskim(
                    rec_real_img_mask.data)
                im_rec_imgs_masked = util.tensor2im(rec_real_imgs.data)
                im_concat_img = np.concatenate([
                    im_real_img, im_fake_imgs_masked, im_fake_img_mask_norm,
                    im_fake_imgs, im_rec_imgs, im_rec_img_mask_norm,
                    im_rec_imgs_masked
                ], 1)

                # whole-batch grids
                im_real_img_batch = util.tensor2im(real_img.data,
                                                   idx=-1,
                                                   nrows=1)
                im_fake_imgs_batch = util.tensor2im(fake_imgs.data,
                                                    idx=-1,
                                                    nrows=1)
                im_fake_img_mask_norm_batch = util.tensor2maskim(
                    fake_img_mask_norm.data, idx=-1, nrows=1)
                im_fake_imgs_masked_batch = util.tensor2im(
                    fake_imgs_masked.data, idx=-1, nrows=1)
                im_concat_img_batch = np.concatenate([
                    im_real_img_batch, im_fake_imgs_masked_batch,
                    im_fake_img_mask_norm_batch, im_fake_imgs_batch
                ], 1)

                imgs = OrderedDict([
                    ('real_img', im_real_img),
                    ('fake_imgs', im_fake_imgs),
                    ('fake_img_mask', im_fake_img_mask_norm),
                    ('fake_imgs_masked', im_fake_imgs_masked),
                    ('concat', im_concat_img),
                    ('real_img_batch', im_real_img_batch),
                    ('fake_imgs_batch', im_fake_imgs_batch),
                    ('fake_img_mask_batch', im_fake_img_mask_norm_batch),
                    ('fake_imgs_masked_batch', im_fake_imgs_masked_batch),
                    ('concat_batch', im_concat_img_batch),
                ])

                data = OrderedDict([
                    ('real_path', self._input_real_img_path),
                    ('desired_cond',
                     desired_cond.data[0, ...].cpu().numpy().astype('str'))
                ])

            # keep data for visualization
            if keep_data_for_visuals:
                self._vis_real_img = util.tensor2im(self._input_real_img)
                self._vis_fake_img_unmasked = util.tensor2im(fake_imgs.data)
                self._vis_fake_img = util.tensor2im(fake_imgs_masked.data)
                self._vis_fake_img_mask = util.tensor2maskim(
                    fake_img_mask.data)
                self._vis_real_cond = self._input_real_cond.cpu()[0,
                                                                  ...].numpy()
                self._vis_desired_cond = self._input_desired_cond.cpu()[
                    0, ...].numpy()
                self._vis_batch_real_img = util.tensor2im(self._input_real_img,
                                                          idx=-1)
                self._vis_batch_fake_img_mask = util.tensor2maskim(
                    fake_img_mask.data, idx=-1)
                self._vis_batch_fake_img = util.tensor2im(
                    fake_imgs_masked.data, idx=-1)

            return imgs, data
# Example 8
    def _forward_G(self, keep_data_for_visuals, has_GT, has_attr):
        """Compute the generator loss for the occlusion-inpainting model.

        Args:
            keep_data_for_visuals: store batch visualizations on self.
            has_GT: ground-truth (non-occluded) image available; enables the
                style/perceptual/hole/valid losses against it.
            has_attr: attribute labels available; enables the attribute loss
                from the discriminator head.

        Returns:
            The combined generator loss (scalar tensor).

        Raises:
            NotImplementedError: for the unsupported combination
                has_GT=True, has_attr=False.
        """
        # G(I_occ): RGB output plus blending mask; the mask keeps the visible
        # (non-occluded) regions of the input
        fake_img, fake_img_mask = self._G.forward(self._img_occ)
        fake_img_synthesis = fake_img_mask * self._img_occ + (
            1 - fake_img_mask) * fake_img

        if has_GT:
            # VGG features for style (gram matrix) and perceptual losses
            fake_img_synthesis_feature = self._vgg(fake_img_synthesis)
            fake_img_feature = self._vgg(fake_img)
            gt_img_feature = self._vgg(self._img_none_occ)

            style = 0
            perceptual = 0

            for i in range(3):  # first three VGG feature levels
                style += self._compute_loss_l1(
                    self._compute_loss_gram_matrix(fake_img_feature[i]),
                    self._compute_loss_gram_matrix(gt_img_feature[i]))
                style += self._compute_loss_l1(
                    self._compute_loss_gram_matrix(
                        fake_img_synthesis_feature[i]),
                    self._compute_loss_gram_matrix(gt_img_feature[i]))

                perceptual += self._compute_loss_l1(fake_img_feature[i],
                                                    gt_img_feature[i])
                perceptual += self._compute_loss_l1(
                    fake_img_synthesis_feature[i], gt_img_feature[i])

            self._loss_g_style = style * self._opt.lambda_g_style
            self._loss_g_perceptual = perceptual * self._opt.lambda_g_perceptual

            # L1 on the hole (masked-out) region against the GT image;
            # detach the target so the gradient only flows through fake_img
            target = (1 - fake_img_mask) * self._img_none_occ
            target = target.detach()
            self._loss_g_hole = self._compute_loss_l1(
                (1 - fake_img_mask) * fake_img,
                target) * self._opt.lambda_g_hole

            # L1 on the valid (kept) region against the GT image
            # NOTE(review): "_loss_g_vaild" is a typo kept for compatibility
            # with external readers of this attribute
            target = fake_img_mask * self._img_none_occ
            target = target.detach()
            self._loss_g_vaild = self._compute_loss_l1(
                fake_img_mask * fake_img, target) * self._opt.lambda_g_valid

        # adversarial / attribute heads on the blended synthesis
        d_fake_img_synthesis_prob, d_fake_img_attr = self._D.forward(
            fake_img_synthesis)

        if has_attr:
            self._loss_g_attr = self._compute_loss_attr(
                d_fake_img_attr,
                self._occ_attr) / self._B * self._opt.lambda_D_attr

        self._loss_g_synthesis_fake = self._compute_loss_D(
            d_fake_img_synthesis_prob, True) * self._opt.lambda_D_prob
        # negative squared mean: encourages the mask toward keeping input
        self._loss_g_mask = -torch.mean(fake_img_mask).pow(
            2) * self._opt.lambda_mask
        self._loss_g_mask_smooth = self._compute_loss_smooth(
            fake_img_mask) * self._opt.lambda_mask_smooth
        self._loss_g_synth_smooth = self._compute_loss_smooth(
            fake_img_synthesis) * self._opt.lambda_g_syhth_smooth

        if keep_data_for_visuals:
            self._vis_batch_occ_img = util.tensor2im(self._input_img_occ,
                                                     idx=-1)
            self._vis_batch_fake_img = util.tensor2im(fake_img.data, idx=-1)
            self._vis_batch_fake_img_mask = util.tensor2maskim(
                fake_img_mask.data, idx=-1)
            self._vis_batch_fake_synthesis = util.tensor2im(
                fake_img_synthesis.data, idx=-1)
            self._vis_batch_none_occ_img = util.tensor2im(
                self._input_img_none_occ, idx=-1)

        # combine losses according to the available supervision
        if has_GT and has_attr:
            return self._loss_g_synthesis_fake + self._loss_g_mask + \
                self._loss_g_mask_smooth + self._loss_g_synth_smooth + \
                self._loss_g_vaild + self._loss_g_hole + \
                self._loss_g_perceptual + self._loss_g_style + \
                self._loss_g_attr
        elif not has_GT and has_attr:
            return self._loss_g_synthesis_fake + self._loss_g_mask + \
                self._loss_g_mask_smooth + self._loss_g_synth_smooth + \
                self._loss_g_attr
        elif not has_GT and not has_attr:
            return self._loss_g_synthesis_fake + self._loss_g_mask + \
                self._loss_g_mask_smooth + self._loss_g_synth_smooth
        else:
            # bug fix: this branch is reached for has_GT=True, has_attr=False;
            # the original message reported the wrong combination, and the
            # unreachable `return None` after the raise has been removed
            raise NotImplementedError(
                'Not existing has_GT = True and has_attr = False')
# Example 9
    def forward(self, keep_data_for_visuals=False, return_estimates=False):
        """Test-time forward pass: pose a rest image with the attention GAN.

        Generates a posed image from the stored rest image and pose vector.
        Unlike the condition-translation variants, there is no cycle branch
        here (it is left commented out below).

        Args:
            keep_data_for_visuals: kept for interface compatibility; the
                visualization branch is currently disabled (pass-only).
            return_estimates: when True, build and return the numpy image
                dict and metadata; otherwise both results are None.

        Returns:
            (imgs, data): OrderedDicts of numpy images and metadata, or
            (None, None). Returns None (bare) in training mode.
        """
        # evaluation-only entry point
        if self._is_train:
            print("only for test")
            return

        with torch.no_grad():
            # convert tensor to variables (no autograd graph at test time)
            rest_img = Variable(self._input_rest_img)
            pose = Variable(self._input_pose_vec)
            expressive_img = Variable(self._input_expressive_img)

            # generate fake images: attention mask blends input and RGB output
            fake_imgs, fake_img_mask = self._G.forward(rest_img, pose)
            fake_img_mask = self._do_if_necessary_saturate_mask(
                fake_img_mask, saturate=self._opt.do_saturate_mask)
            fake_imgs_masked = fake_img_mask * rest_img + (
                1 - fake_img_mask) * fake_imgs

        # cycle branch intentionally disabled for this variant
        # rec_real_img_rgb, rec_real_img_mask = self._G.forward(fake_imgs_masked, real_cond)
        # rec_real_img_mask = self._do_if_necessary_saturate_mask(rec_real_img_mask, saturate=self._opt.do_saturate_mask)
        # rec_real_imgs = rec_real_img_mask * fake_imgs_masked + (1 - rec_real_img_mask) * rec_real_img_rgb

        imgs = None
        data = None
        if return_estimates:
            # mask used as-is for visualization (max-normalization disabled)
            # fake_img_mask_max = fake_imgs_masked.view(fake_img_mask.size(0), -1).max(-1)[0]
            # fake_img_mask_max = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(fake_img_mask_max, -1), -1), -1)
            # fake_img_mask_norm = fake_img_mask / fake_img_mask_max
            fake_img_mask_norm = fake_img_mask

            # convert tensors to numpy images
            im_real_img = util.tensor2im(rest_img.data)
            im_expressive_img = util.tensor2im(expressive_img.data)
            im_fake_imgs = util.tensor2im(fake_imgs.data)
            im_fake_img_mask_norm = util.tensor2maskim(fake_img_mask_norm.data)
            im_fake_imgs_masked = util.tensor2im(fake_imgs_masked.data)
            # im_rec_imgs = util.tensor2im(rec_real_img_rgb.data)
            # im_rec_img_mask_norm = util.tensor2maskim(rec_real_img_mask.data)
            # im_rec_imgs_masked = util.tensor2im(rec_real_imgs.data)
            # side-by-side comparison strip (axis=1 → horizontal)
            im_concat_img = np.concatenate(
                [
                    im_real_img,
                    im_expressive_img,
                    im_fake_imgs_masked,
                    im_fake_img_mask_norm,
                    im_fake_imgs,
                    # im_rec_imgs, im_rec_img_mask_norm, im_rec_imgs_masked
                ],
                1)

            # whole-batch grids (idx=-1 selects all samples)
            im_real_img_batch = util.tensor2im(rest_img.data, idx=-1, nrows=1)
            im_fake_imgs_batch = util.tensor2im(fake_imgs.data,
                                                idx=-1,
                                                nrows=1)
            im_fake_img_mask_norm_batch = util.tensor2maskim(
                fake_img_mask_norm.data, idx=-1, nrows=1)
            im_fake_imgs_masked_batch = util.tensor2im(fake_imgs_masked.data,
                                                       idx=-1,
                                                       nrows=1)
            im_concat_img_batch = np.concatenate([
                im_real_img_batch, im_fake_imgs_masked_batch,
                im_fake_img_mask_norm_batch, im_fake_imgs_batch
            ], 1)

            imgs = OrderedDict([
                ('real_img', im_real_img),
                ('fake_imgs', im_fake_imgs),
                ('fake_img_mask', im_fake_img_mask_norm),
                ('fake_imgs_masked', im_fake_imgs_masked),
                ('concat', im_concat_img),
                ('real_img_batch', im_real_img_batch),
                ('fake_imgs_batch', im_fake_imgs_batch),
                ('fake_img_mask_batch', im_fake_img_mask_norm_batch),
                ('fake_imgs_masked_batch', im_fake_imgs_masked_batch),
                ('concat_batch', im_concat_img_batch),
            ])

            data = OrderedDict([
                ('rest_name', self._input_rest_name),
                ('pose', pose.data[0, ...].cpu().numpy().astype('str'))
            ])

        # visualization branch currently disabled (kept for reference)
        if keep_data_for_visuals:
            # self._vis_real_img = util.tensor2im(self._input_rest_img)
            # self._vis_fake_img_unmasked = util.tensor2im(fake_imgs.data)
            # self._vis_fake_img = util.tensor2im(fake_imgs_masked.data)
            # self._vis_fake_img_mask = util.tensor2maskim(fake_img_mask.data)
            # self._vis_real_cond = self._input_pose_vec.cpu()[0, ...].numpy()
            # self._vis_desired_cond = self._input_desired_cond.cpu()[0, ...].numpy()
            # self._vis_batch_real_img = util.tensor2im(self._input_real_img, idx=-1)
            # self._vis_batch_fake_img_mask = util.tensor2maskim(fake_img_mask.data, idx=-1)
            # self._vis_batch_fake_img = util.tensor2im(fake_imgs_masked.data, idx=-1)
            pass

        return imgs, data