Beispiel #1
0
 def maybe_save_fake_data(self, real_data, condition, landmarks):
     """Periodically log generated, real and condition images.

     Nothing happens until ``self.global_step`` reaches
     ``self.next_image_save_point``. At that point the next save point is
     scheduled ``self.num_ims_per_save_image`` steps ahead, and up to 64
     samples of each kind are written via ``self.logger.save_images``:
     generated images ("fakes", ``log_to_writer`` left at the logger's
     default), real images ("reals") and the first three channels of the
     condition ("condition"), the latter two with ``log_to_writer=False``.

     Args:
         real_data: batch of real images; denormalized before saving.
         condition: conditioning input for the generator; channels ``:3``
             are saved as an image.
         landmarks: landmark input forwarded to the generator.
     """
     if self.global_step < self.next_image_save_point:
         return
     self.next_image_save_point = self.global_step + self.num_ims_per_save_image
     # Sample in eval mode without autograd, then restore whatever mode the
     # generator was in: leaving it in eval mode would silently affect every
     # training step that follows.
     was_training = self.generator.training
     self.generator.eval()
     try:
         with torch.no_grad():
             fake_data_sample = denormalize_img(
                 self.generator(condition, landmarks))
     finally:
         self.generator.train(was_training)
     self.logger.save_images("fakes", fake_data_sample[:64])
     # Save input images
     to_save = denormalize_img(real_data)
     self.logger.save_images("reals", to_save[:64], log_to_writer=False)
     to_save = denormalize_img(condition[:64, :3])
     self.logger.save_images("condition", to_save, log_to_writer=False)
Beispiel #2
0
    def save_transition_image(self, before):
        """Dump debug images and discriminator outputs around a resolution transition.

        Meant to be called once with ``before=True`` and later once with
        ``before=False``. Each call:

        - sets the validation loader's transition variable to
          ``self.transition_variable``;
        - takes the first 8 samples of one validation batch;
        - runs the generator (with ``self.static_z[:8]``) and the
          discriminator on them;
        - saves the stacked real, condition and generated images to
          ``.debug/<tag>.torch`` and through ``self.logger``.

        On the ``before=False`` call, the stored "before" grid is upsampled
        2x and compared with the new "after" grid. An image of before / after
        / max-normalized absolute difference is logged. The summed absolute
        change of the discriminator outputs on real and fake data is logged
        as two scalars.

        Args:
            before: True for the call preceding the transition, False for the
                one following it. The False call reads ``self.d_out_*_before``
                and ``.debug/before.torch``, so it requires an earlier True call.

        Raises:
            ValueError: if the validation batch has fewer than 8 samples.
        """
        import os

        num_samples = 8
        debug_dir = ".debug"

        self.dataloader_val.update_next_transition_variable(
            self.transition_variable)
        real_image, condition, landmark = next(iter(self.dataloader_val))
        if real_image.shape[0] < num_samples:
            raise ValueError(
                f"Expected a validation batch of at least {num_samples} "
                f"images, got {real_image.shape[0]}")
        real_data = real_image[:num_samples]
        condition = condition[:num_samples]
        landmark = landmark[:num_samples]
        # Pure evaluation: no autograd graph is needed. Without no_grad, the
        # discriminator outputs stored on ``self`` below would keep their
        # whole forward graph alive until the "after" call.
        with torch.no_grad():
            fake_data = self.generator(
                condition, landmark, self.static_z[:num_samples])
            d_out_real = self.discriminator(real_data, condition, landmark)
            d_out_fake = self.discriminator(fake_data, condition, landmark)
        if before:
            self.d_out_real_before = d_out_real
            self.d_out_fake_before = d_out_fake
        fake_data = denormalize_img(fake_data)
        real_data = denormalize_img(real_data)
        condition = denormalize_img(condition)

        # Stacked along the batch dimension: reals, then conditions, then fakes.
        to_save = torch.cat((real_data, condition, fake_data))
        tag = "before" if before else "after"
        # torch.save does not create missing directories.
        os.makedirs(debug_dir, exist_ok=True)
        torch.save(to_save, f"{debug_dir}/{tag}.torch")
        imsize = self.current_imsize if before else self.current_imsize // 2
        imname = "transition/{}_{}_".format(tag, imsize)
        self.logger.save_images(imname, to_save, log_to_writer=False)

        if before:
            return

        # Lay each grid out as one wide image: concatenate the per-sample
        # tensors along their last axis, then add a leading axis of size 1 so
        # ``interpolate`` gets a 4-D input. The "after" grid is the tensor
        # just saved, so it is reused instead of being re-read from disk.
        im_before = torch.cat(
            list(torch.load(f"{debug_dir}/before.torch")), dim=2)[None]
        im_after = torch.cat(list(to_save), dim=2)[None]
        # Upsample 2x so the "before" grid matches the "after" grid's shape
        # for the pixel-wise comparison below.
        im_before = torch.nn.functional.interpolate(im_before, scale_factor=2)

        diff = (im_after - im_before).abs()
        # Normalize to [0, 1]. The clamp avoids 0/0 = NaN when both grids
        # are identical.
        diff = diff / diff.max().clamp(min=1e-8)

        # before / after / diff stacked along dim 1 of the per-image tensors.
        to_save = torch.cat((im_before[0], im_after[0], diff[0]), dim=1)
        self.logger.save_images(
            f"transition/to_imsize{self.current_imsize}_",
            to_save,
            log_to_writer=False)

        diff_real = (d_out_real - self.d_out_real_before).abs().sum()
        diff_fake = (d_out_fake - self.d_out_fake_before).abs().sum()
        self.logger.log_variable("transition/discriminator_diff_real",
                                 diff_real)
        self.logger.log_variable("transition/discriminator_diff_fake",
                                 diff_fake)
Beispiel #3
0
def post_process(im, generated_face, expanded_bbox, original_bbox, image_mask):
    """Paste a generated face back into ``im``.

    Only ``generated_face[0]`` is used. It is denormalized and converted to a
    uint8 numpy image. It is then resized to a square whose side is
    ``expanded_bbox[2] - expanded_bbox[0]``, and handed to ``replace_face``
    together with the mask and both boxes.

    Args:
        im: image the face is written into.
        generated_face: generator output batch; only index 0 is used.
        expanded_bbox: box whose entries 0 and 2 define the side length of
            the resized face (presumably x0 and x1 — verify against caller).
        original_bbox: forwarded unchanged to ``replace_face``.
        image_mask: forwarded unchanged to ``replace_face``.

    Returns:
        Whatever ``replace_face`` returns.
    """
    generated_face = denormalize_img(generated_face)
    generated_face = torch_utils.image_to_numpy(
        generated_face[0], to_uint8=True)
    # cv2.resize needs plain Python ints in ``dsize``. The bbox entries may be
    # numpy/tensor scalars or floats, which OpenCV's argument parser can reject.
    orig_imsize = int(expanded_bbox[2] - expanded_bbox[0])
    generated_face = cv2.resize(generated_face, (orig_imsize, orig_imsize))
    return replace_face(im, generated_face, image_mask,
                        original_bbox, expanded_bbox)
Beispiel #4
0
def image_to_numpy(images, to_uint8=False, denormalize=False):
    """Convert channels-first torch image(s) to a channels-last numpy array.

    Args:
        images: tensor of shape (C, H, W) or (N, C, H, W). At most the first
            three channels are kept; inputs with fewer channels keep all of
            theirs.
        to_uint8: if True, scale by 255 and cast to ``np.uint8``. Values are
            clipped to [0, 1] first, so out-of-range inputs saturate at 0/255
            instead of wrapping around in the integer cast.
        denormalize: if True, apply ``data_utils.denormalize_img`` first.

    Returns:
        A C-contiguous array of shape (H, W, C') for a single image or
        (N, H, W, C') for a batch, where C' = min(C, 3).
    """
    single_image = len(images.shape) == 3
    if single_image:
        images = images[None]
    if denormalize:
        images = data_utils.denormalize_img(images)
    images = images.detach().cpu().numpy()
    # Channels-first -> channels-last. Kept contiguous because OpenCV and
    # other consumers of these arrays may reject strided views.
    images = np.ascontiguousarray(np.moveaxis(images[:, :3], 1, -1))
    if to_uint8:
        images = (np.clip(images, 0, 1) * 255).astype(np.uint8)
    if single_image:
        return images[0]
    return images
Beispiel #5
0
from deep_privacy.data_tools.dataloaders import load_dataset
from deep_privacy import torch_utils
from deep_privacy.data_tools import data_utils

# Needed below for np.concatenate and avg_pool2d; not imported above.
import numpy as np
import torch

start_imsize = 8
batch_size = 32

# Batch 1: validation images at the starting resolution, transition
# variable 1.0. The trailing positional arguments presumably are
# full_validation, pose_size and load_fraction — verify against load_dataset.
dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize, False, 14,
                                True)

dl = dl_val
dl.update_next_transition_variable(1.0)
next(iter(dl))

for im, condition, landmark in dl:
    im = data_utils.denormalize_img(im)
    im = torch_utils.image_to_numpy(im, to_uint8=True)
    to_save1 = im
    break
# Tile the batch side by side along axis 1 of each image.
to_save1 = np.concatenate(to_save1, axis=1)

# Batch 2: validation images at twice the resolution, transition variable
# 0.0, average-pooled 2x back down to the starting resolution.
dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize * 2, False, 14,
                                True)
dl = dl_val
dl.update_next_transition_variable(0.0)
next(iter(dl))

for im, condition, landmark in dl:
    im = torch.nn.functional.avg_pool2d(im, 2)
    im = data_utils.denormalize_img(im)
    im = torch_utils.image_to_numpy(im, to_uint8=True)
    to_save2 = im
    # Only the first batch is needed (as in the first loop); without this the
    # loop walks the whole validation set and keeps just the last batch.
    break
Beispiel #6
0
                           full_validation=False,
                           pose_size=14,
                           load_fraction=True)
config = config_parser.load_config("models/minibatch_std/config.yml")
ckpt = torch.load("models/minibatch_std/transition_checkpoints/imsize64.ckpt")
discriminator, generator = init_model(config.models.pose_size,
                                      config.models.start_channel_size,
                                      config.models.image_channels,
                                      config.models.discriminator.structure)
generator.load_state_dict(ckpt["G"])
generator.cuda()
print(generator.network.current_imsize)
dl_train.update_next_transition_variable(1.0)
ims, conditions, landmarks = next(iter(dl_train))

# Inference only: don't build an autograd graph for the sampled images.
with torch.no_grad():
    fakes = denormalize_img(generator(conditions, landmarks))
os.makedirs(".debug", exist_ok=True)
torchvision.utils.save_image(fakes, ".debug/test.jpg")

# Extend the generator by one resolution level; the new layers must be moved
# to the GPU as well, and the transition starts at 0.0.
generator.extend()
generator.cuda()
generator.transition_value = 0.0
dl_train, _ = load_dataset("yfcc100m128",
                           batch_size=64,
                           imsize=128,
                           full_validation=False,
                           pose_size=14,
                           load_fraction=True)
dl_train.update_next_transition_variable(0.0)
ims, conditions, landmarks = next(iter(dl_train))