Example #1
def infer_images(
        dataloader, generator, truncation_level: float,
        verbose=False,
        return_condition=False) -> Tuple[np.ndarray, ...]:
    """Run the generator over every batch of the dataloader with truncated
    latents, returning (real_images, fake_images[, conditions])."""
    imshape = (generator.current_imsize, generator.current_imsize, 3)
    real_images = np.empty(
        (dataloader.num_images(), *imshape), dtype=np.float32)
    fake_images = np.empty_like(real_images)
    if return_condition:
        conditions = np.empty_like(fake_images)
    batch_size = dataloader.batch_size
    generator.eval()
    dl_iter = iter(dataloader)
    if verbose:
        import tqdm
        dl_iter = tqdm.tqdm(dl_iter)
    with torch.no_grad():
        for idx, batch in enumerate(dl_iter):
            real_data = batch["img"]
            z = truncated_z(real_data, generator.z_shape, truncation_level)
            fake_data = generator(**batch, z=z)
            start = idx * batch_size
            end = start + len(real_data)
            real_data = torch_utils.image_to_numpy(real_data, denormalize=True)
            fake_data = torch_utils.image_to_numpy(fake_data, denormalize=True)
            real_images[start:end] = real_data
            fake_images[start:end] = fake_data
            if return_condition:
                conditions[start:end] = torch_utils.image_to_numpy(
                    batch["condition"], denormalize=True)
    generator.train()
    if return_condition:
        return real_images, fake_images, conditions
    return real_images, fake_images
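A minimal usage sketch (hypothetical; the variable names and truncation level are illustrative, assuming a deep_privacy-style dataloader and a trained generator already exist):

# Illustrative call only; dataloader_val and generator are assumed to be
# constructed elsewhere in the pipeline.
real_images, fake_images = infer_images(
    dataloader_val, generator, truncation_level=0.5, verbose=True)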
Example #2
    def save_debug_images(self, face_info, generated_faces):
        # Save each (input, generated) face pair side by side for debugging.
        for face_idx, info in face_info.items():
            torch_input = info["torch_input"].squeeze(0)
            generated_face = generated_faces[face_idx]
            torch_input = torch_utils.image_to_numpy(torch_input,
                                                     to_uint8=True,
                                                     denormalize=True)
            generated_face = torch_utils.image_to_numpy(generated_face,
                                                        to_uint8=True,
                                                        denormalize=True)
            to_save = np.concatenate((torch_input, generated_face), axis=1)
            filepath = os.path.join(self.debug_directory,
                                    f"face_{face_idx}.jpg")
            cv2.imwrite(filepath, to_save[:, :, ::-1])  # RGB -> BGR for OpenCV
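The method expects face_info to map a face index to a dict holding the pre-processed generator input. A minimal sketch of that structure (the tensor shape is an assumption for illustration):

# Hypothetical structure: a normalized NCHW tensor with a leading batch
# dimension of 1, as implied by .squeeze(0) above.
face_info = {
    0: {"torch_input": torch.zeros(1, 3, 128, 128)},
}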
Example #3
def inpaint_images(images: np.ndarray, masks: np.ndarray, generator):
    fakes = torch.zeros(
        (images.shape[0], images.shape[-1], images.shape[1], images.shape[2]),
        dtype=torch.float32)
    masks = pre_process_masks(masks)
    inputs = [im * mask for im, mask in zip(images, masks)]
    images = [
        torch_utils.image_to_torch(im, cuda=False, normalize_img=True)
        for im in images
    ]
    masks = [torch_utils.mask_to_torch(mask, cuda=False) for mask in masks]
    with torch.no_grad():
        for idx, (im, mask) in enumerate(
                tqdm.tqdm(zip(images, masks), total=len(images))):
            im = torch_utils.to_cuda(im)
            mask = torch_utils.to_cuda(mask)
            assert im.shape[0] == mask.shape[0]
            assert im.shape[2:] == mask.shape[2:],\
                f"im shape: {im.shape}, mask shape: {mask.shape}"
            z = truncated_z(im, generator.z_shape, 0)
            condition = mask * im
            fake = generator(condition, mask, z)
            fakes[idx:(idx + 1)] = fake.cpu()
    fakes = torch_utils.image_to_numpy(fakes, denormalize=True) * 255
    return fakes, inputs
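The returned fakes are float arrays scaled to [0, 255] by the final multiplication, so they need an astype(np.uint8) before saving. A hypothetical usage sketch:

# Illustrative only; images are HWC numpy arrays and masks match the shapes
# the assertions above check for.
fakes, inputs = inpaint_images(images, masks, generator)
cv2.imwrite("inpainted_0.png", fakes[0].astype(np.uint8)[:, :, ::-1])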
Example #4
def post_process(im, generated_face, expanded_bbox, original_bbox, image_mask):
    generated_face = denormalize_img(generated_face)
    generated_face = torch_utils.image_to_numpy(
        generated_face[0], to_uint8=True)
    orig_imsize = expanded_bbox[2] - expanded_bbox[0]
    generated_face = cv2.resize(generated_face, (orig_imsize, orig_imsize))
    im = replace_face(im, generated_face, image_mask,
                      original_bbox, expanded_bbox)
    return im
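From the arithmetic above, expanded_bbox is read as (x0, y0, x1, y1) and is assumed square, since a single side length is taken from expanded_bbox[2] - expanded_bbox[0]. A hypothetical call with illustrative coordinates:

# Illustrative values only; im, generated_face and image_mask come from the
# surrounding face-replacement pipeline.
im = post_process(im, generated_face,
                  expanded_bbox=(10, 20, 138, 148),
                  original_bbox=(30, 40, 118, 128),
                  image_mask=image_mask)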
Example #5
    def validate_model(self):
        real_scores = []
        fake_scores = []
        wasserstein_distances = []
        epsilon_penalties = []
        self.running_average_generator.eval()
        self.discriminator.eval()
        real_images = torch.zeros((len(self.dataloader_val)*self.batch_size,
                                   3,
                                   self.current_imsize,
                                   self.current_imsize))
        fake_images = torch.zeros_like(real_images)
        with torch.no_grad():
            self.dataloader_val.update_next_transition_variable(
                self.transition_variable)
            for idx, (real_data, condition, landmarks) in enumerate(
                    tqdm.tqdm(self.dataloader_val, desc="Validating model!")):
                fake_data = self.running_average_generator(condition,
                                                           landmarks)
                real_score = self.discriminator(
                    real_data, condition, landmarks)
                fake_score = self.discriminator(fake_data, condition,
                                                landmarks)
                wasserstein_distance = (real_score - fake_score).squeeze()
                epsilon_penalty = (real_score**2).squeeze()
                real_scores.append(real_score.mean().item())
                fake_scores.append(fake_score.mean().item())
                wasserstein_distances.append(
                    wasserstein_distance.mean().item())
                epsilon_penalties.append(
                    epsilon_penalty.mean().detach().item())

                start_idx = idx*self.batch_size
                end_idx = (idx+1)*self.batch_size
                real_images[start_idx:end_idx] = real_data.cpu().float()
                fake_images[start_idx:end_idx] = fake_data.cpu().float()
                del real_data, fake_data, real_score, fake_score, wasserstein_distance, epsilon_penalty
        real_images = torch_utils.image_to_numpy(real_images, to_uint8=False,
                                                 denormalize=True)
        fake_images_numpy = torch_utils.image_to_numpy(
            fake_images, to_uint8=False, denormalize=True)
        fid_name = "{}_{}_{}".format(self.dataset,
                                     self.full_validation,
                                     self.current_imsize)
        if self.current_imsize >= 64:
            fid_val = fid.calculate_fid(real_images,
                                        fake_images_numpy,
                                        False, 8, fid_name)
            self.logger.log_variable("stats/fid", np.mean(fid_val), True)
        self.logger.log_variable('discriminator/wasserstein-distance',
                                 np.mean(wasserstein_distances), True)
        self.logger.log_variable("discriminator/real-score",
                                 np.mean(real_scores), True)
        self.logger.log_variable("discriminator/fake-score",
                                 np.mean(fake_scores), True)
        self.logger.log_variable("discriminator/epsilon-penalty",
                                 np.mean(epsilon_penalties), True)
        self.logger.save_images("fakes", fake_images[:64],
                                log_to_validation=True)
        self.discriminator.train()
        self.generator.train()
Example #6
            # Fragment: this runs inside nested loops over input images and
            # crop percentages; orig, percentages, bounding_boxes, pose, z,
            # to_save and ims_to_save are defined earlier in the script.
            im = orig.copy()

            p = percentages[i]
            bbox = bounding_boxes[idx].clone().float()
            width = bbox[2] - bbox[0]
            height = bbox[3] - bbox[1]
            bbox[0] = bbox[0] - p * width
            bbox[2] = bbox[2] + p * width
            bbox[1] = bbox[1] - p * height
            bbox[3] = bbox[3] + p * height
            bbox = bbox.long()

            im = cut_bounding_box(im, bbox, generator.transition_value)
            orig_to_save = im.copy()

            im = torch_utils.image_to_torch(im, cuda=True, normalize_img=True)
            im = generator(im, pose, z.clone())
            im = torch_utils.image_to_numpy(im.squeeze(),
                                            to_uint8=True,
                                            denormalize=True)

            im = np.concatenate((orig_to_save.astype(np.uint8), im), axis=0)
            to_save = np.concatenate((to_save, im), axis=1)
        ims_to_save.append(to_save)
    savepath = os.path.join(savedir, "result_image.jpg")

    ims_to_save = np.concatenate(ims_to_save, axis=0)
    plt.imsave(savepath, ims_to_save)

    print("Results saved to:", savedir)
Example #7
generator, imsize, save_path, pose_size = read_args()

batch_size = 128
dataloader_train, dataloader_val = load_dataset("fdf", batch_size, 128, True,
                                                pose_size, True)
dataloader_val.update_next_transition_variable(1.0)
fake_images = np.zeros((len(dataloader_val)*batch_size, imsize, imsize, 3),
                       dtype=np.uint8)
real_images = np.zeros((len(dataloader_val)*batch_size, imsize, imsize, 3),
                       dtype=np.uint8)
z = generator.generate_latent_variable(batch_size, "cuda", torch.float32).zero_()
with torch.no_grad():
    for idx, (real_data, condition, landmarks) in enumerate(tqdm.tqdm(dataloader_val)):

        fake_data = generator(condition, landmarks, z.clone())
        fake_data = torch_utils.image_to_numpy(fake_data, to_uint8=True, denormalize=True)
        real_data = torch_utils.image_to_numpy(real_data, to_uint8=True, denormalize=True)
        start_idx = idx * batch_size
        end_idx = (idx+1) * batch_size

        real_images[start_idx:end_idx] = real_data
        fake_images[start_idx:end_idx] = fake_data

generator.cpu()
del generator

if os.path.isdir(save_path):
    shutil.rmtree(save_path)

os.makedirs(os.path.join(save_path, "real"))
os.makedirs(os.path.join(save_path, "fake"))
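The snippet ends right after creating the output directories; a plausible continuation (an assumption, not part of the original example) would write the collected pairs to disk:

# Hypothetical continuation: persist the uint8 RGB arrays collected above,
# converting to BGR for cv2.imwrite.
for i, (real, fake) in enumerate(zip(real_images, fake_images)):
    cv2.imwrite(os.path.join(save_path, "real", f"{i}.png"), real[:, :, ::-1])
    cv2.imwrite(os.path.join(save_path, "fake", f"{i}.png"), fake[:, :, ::-1])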
Example #8
from deep_privacy import torch_utils
from deep_privacy.data_tools import data_utils

start_imsize = 8
batch_size = 32

dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize, False, 14,
                                True)

dl = dl_val
dl.update_next_transition_variable(1.0)
next(iter(dl))

for im, condition, landmark in dl:
    im = data_utils.denormalize_img(im)
    im = torch_utils.image_to_numpy(im, to_uint8=True)
    to_save1 = im
    break
to_save1 = np.concatenate(to_save1, axis=1)
dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize * 2, False, 14,
                                True)
dl = dl_val
dl.update_next_transition_variable(0.0)
next(iter(dl))

for im, condition, landmark in dl:
    im = torch.nn.functional.avg_pool2d(im, 2)
    im = data_utils.denormalize_img(im)
    im = torch_utils.image_to_numpy(im, to_uint8=True)
    to_save2 = im
    break
Example #9
    def anonymize_images(self,
                         images: np.ndarray,
                         image_annotations: typing.List[ImageAnnotation]
                         ) -> typing.List[np.ndarray]:
        anonymized_images = []
        for im_idx, image_annotation in enumerate(image_annotations):
            # pre-process
            imsize = self.inference_imsize
            condition = torch.zeros(
                (len(image_annotation), 3, imsize, imsize),
                dtype=torch.float32)
            mask = torch.zeros((len(image_annotation), 1, imsize, imsize))
            landmarks = torch.empty(
                (len(image_annotation), self.pose_size), dtype=torch.float32)
            for face_idx in range(len(image_annotation)):
                face, mask_ = image_annotation.get_face(face_idx, imsize)
                condition[face_idx] = torch_utils.image_to_torch(
                    face, cuda=False, normalize_img=True
                )
                mask[face_idx, 0] = torch.from_numpy(mask_).float()
                kp = image_annotation.aligned_keypoint(face_idx)
                landmarks[face_idx] = kp[:, :self.pose_size]
            img = condition
            condition = condition * mask
            z = infer.truncated_z(
                condition, self.cfg.models.generator.z_shape,
                self.truncation_level)
            batches = dict(
                condition=condition,
                mask=mask,
                landmarks=landmarks,
                z=z,
                img=img
            )
            # Inference
            anonymized_faces = np.zeros((
                len(image_annotation), imsize, imsize, 3), dtype=np.float32)
            for idx, batch in enumerate(
                    batched_iterator(batches, self.batch_size)):
                face = self._get_face(batch)
                face = torch_utils.image_to_numpy(
                    face, to_uint8=False, denormalize=True)
                start = idx * self.batch_size
                anonymized_faces[start:start + len(face)] = face
            anonymized_image = image_annotation.stitch_faces(anonymized_faces)
            anonymized_images.append(anonymized_image)
            if self.save_debug:
                num_faces = len(batches["condition"])
                for face_idx in range(num_faces):
                    orig_face = torch_utils.image_to_numpy(
                        batches["img"][face_idx], denormalize=True, to_uint8=True)
                    condition = torch_utils.image_to_numpy(
                        batches["condition"][face_idx],
                        denormalize=True, to_uint8=True)
                    fake_face = anonymized_faces[face_idx]
                    fake_face = (fake_face * 255).astype(np.uint8)
                    to_save = np.concatenate(
                        (orig_face, condition, fake_face), axis=1)
                    filepath = self.debug_directory.joinpath(
                        f"im{im_idx}_face{face_idx}.png")
                    cv2.imwrite(str(filepath), to_save[:, :, ::-1])

        return anonymized_images
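The inference loop relies on a batched_iterator helper that the snippet does not show. A minimal sketch of the behavior the loop assumes (a hypothetical re-implementation, not the repository's code):

def batched_iterator(batch, batch_size):
    # Assumed behavior: slice every tensor in the dict into consecutive
    # chunks of batch_size along the first axis and yield them as dicts.
    num_items = len(batch["condition"])
    for start in range(0, num_items, batch_size):
        yield {key: value[start:start + batch_size]
               for key, value in batch.items()}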