Exemplo n.º 1
0
def get_gif_from_list_of_params(generator, flame_params, step, alpha, noise, overlay_landmarks, flame_std, flame_mean,
                                overlay_visualizer, rendered_flame_as_condition, use_posed_constant_input,
                                normal_maps_as_cond, camera_params, *, embedding_index=13, device='cuda'):
    """Render a sequence of FLAME parameters and run the generator on it.

    Args:
        generator: model forwarded to ``generate_from_flame_sequence``.
        flame_params: normalized FLAME parameters. It is de-normalized with
            ``flame_std``/``flame_mean`` and passed to ``torch.from_numpy``, so it is
            presumably a numpy array of shape (num_frames, num_params) — TODO confirm.
        step, alpha, noise: forwarded unchanged to ``generate_from_flame_sequence``.
        overlay_landmarks: not used by this function; kept for interface compatibility.
        flame_std, flame_mean: statistics used to undo the normalization of ``flame_params``.
        overlay_visualizer: renderer object; a new ``OverLayViz`` is created when ``None``.
        rendered_flame_as_condition: feed the rendered mesh images to the generator
            instead of the raw (normalized) parameters, and append them to the output.
        use_posed_constant_input: pass the ``GLOBAL_ROT`` columns of the normalized
            ``flame_params`` as ``pose``; otherwise ``pose`` is ``None``.
        normal_maps_as_cond: feed ``cat(rendered images, normal maps)`` along dim 1 to the
            generator (takes precedence over ``rendered_flame_as_condition`` for the
            generator input) and append the normal maps to the output.
        camera_params: forwarded to ``overlay_visualizer.get_rendered_mesh``.
        embedding_index: embedding index used for every frame (formerly hard-coded 13).
        device: device for the embedding indices and de-normalized parameters.

    Returns:
        The generated images, resized to the render resolution and passed through
        ``range_normalize_images``. When either conditioning flag is set, the rendered
        images and/or normal maps (mapped with ``(x + 1) / 2``) are concatenated after
        them along the last (width) dimension and the result is on the CPU. Otherwise the
        generated images are returned on whatever device the generator produced them.
    """
    if overlay_visualizer is None:
        overlay_visualizer = OverLayViz()

    num_frames = flame_params.shape[0]
    fixed_embeddings = torch.full((num_frames,), embedding_index, dtype=torch.long, device=device)

    flame_params_unnorm = torch.from_numpy(flame_params * flame_std + flame_mean).to(device)
    normal_map_img, _, _, _, rend_imgs = overlay_visualizer.get_rendered_mesh(
        flame_params=(flame_params_unnorm[:, SHAPE_IDS[0]:SHAPE_IDS[1]],
                      flame_params_unnorm[:, EXP_IDS[0]:EXP_IDS[1]],
                      flame_params_unnorm[:, POSE_IDS[0]:POSE_IDS[1]],
                      flame_params_unnorm[:, TRANS_IDS[0]:TRANS_IDS[1]]),
        camera_params=camera_params)
    # NOTE(review): if rend_imgs is in [0, 255] this maps it to roughly [-1, 1.008];
    # 127.5 would be exact, but the generator may have been trained with this scaling,
    # so it is kept as-is — confirm before changing.
    rend_imgs = rend_imgs / 127.0 - 1

    pose = flame_params[:, constants.get_idx_list('GLOBAL_ROT')] if use_posed_constant_input else None

    # Select the generator input; normal-map conditioning wins over rendered-only.
    if normal_maps_as_cond:
        gen_in = torch.cat((rend_imgs, normal_map_img), dim=1)
    elif rendered_flame_as_condition:
        gen_in = rend_imgs
    else:
        gen_in = flame_params

    fake_images = generate_from_flame_sequence(generator, gen_in, pose, step, alpha, noise,
                                               input_indices=fixed_embeddings)[-1]

    # Images are (N, C, H, W): resize the generator output to the render resolution.
    fake_images = overlay_visualizer.range_normalize_images(
        fast_image_reshape(fake_images, height_out=rend_imgs.shape[2], width_out=rend_imgs.shape[3]))

    # Side-by-side panels, shifted from [-1, 1]-style range with (x + 1) / 2.
    extra_panels = []
    if rendered_flame_as_condition:
        extra_panels.append((rend_imgs.cpu() + 1) / 2)
    if normal_maps_as_cond:
        extra_panels.append((normal_map_img.cpu() + 1) / 2)
    if extra_panels:
        fake_images = torch.cat([fake_images.cpu(), *extra_panels], dim=-1)

    return fake_images
Exemplo n.º 2
0
class VisualizationSaver:
    """Generate a grid of samples from a model and save it as a PNG.

    The sampling inputs (conditioning labels, optional pose, embedding indices)
    are split into ``gen_i`` consecutive batches of ``gen_j`` items each. Every
    batch is run through the model, and all outputs are tiled into one image.
    """

    # Optional per-sample pose conditioning, replaced by ``set_flame_params``.
    # Declared at class level so that ``save_samples`` treats "never set" as
    # "no pose" instead of raising AttributeError (``__init__`` never set it).
    pose = None

    def __init__(self, gen_i, gen_j, sampling_flame_labels, dataset, input_indices, overlay_mesh=False):
        """Store the sampling configuration.

        Args:
            gen_i: number of batches to generate; also the ``nrow`` of the saved grid.
            gen_j: number of samples per batch.
            sampling_flame_labels: sliceable conditioning inputs (each batch slice is ``.clone()``d).
            dataset: stored; not used by the methods in this class.
            input_indices: sliceable embedding indices, sliced the same way as the labels.
            overlay_mesh: stored; not used by the methods in this class.
        """
        self.gen_i = gen_i
        self.gen_j = gen_j
        self.sampling_flame_labels = sampling_flame_labels
        self.overlay_mesh = overlay_mesh
        self.overlay_visualizer = OverLayViz()
        self.dataset = dataset
        self.input_indices = input_indices
        # Kept for backward compatibility; not read by save_samples.
        self.cam_t = np.array([0., 0., 2.5])

    def set_flame_params(self, pose, sampling_flame_labels, input_indices):
        """Replace the pose (or ``None``), conditioning labels and embedding indices."""
        self.pose = pose
        self.sampling_flame_labels = sampling_flame_labels
        self.input_indices = input_indices

    def save_samples(self, i, model, step, alpha, resolution, fid, run_id):
        """Run the model on every batch and write the tiled result to disk.

        Args:
            i: iteration counter; the file name uses ``i + 1`` zero-padded to 6 digits.
            model: callable invoked as ``model(labels, pose, step=..., alpha=...,
                input_indices=...)``; the last element of its output is used.
            step, alpha: forwarded to ``model``.
            resolution: only used in the output file name (images are resized to 256x256).
            fid: formatted with two decimals into the file name.
            run_id: sub-directory under ``cnst.output_root + 'sample/'``.
        """
        images = []

        with torch.no_grad():
            for img_idx in range(self.gen_i):
                batch = slice(img_idx * self.gen_j, (img_idx + 1) * self.gen_j)
                flame_param_this_batch = self.sampling_flame_labels[batch]
                pose_this_batch = None if self.pose is None else self.pose[batch]
                idx_this_batch = self.input_indices[batch]

                img_tensor = model(flame_param_this_batch.clone(), pose_this_batch, step=step, alpha=alpha,
                                   input_indices=idx_this_batch)[-1]

                img_tensor = self.overlay_visualizer.range_normalize_images(
                    dataset_loaders.fast_image_reshape(img_tensor, height_out=256, width_out=256,
                                                       non_diff_allowed=True))

                images.append(img_tensor.detach().cpu())

        # save_image(normalize=True, range=(0, 1)) amounts to clamping to [0, 1].
        # The `range` kwarg was deprecated in favour of `value_range` and later
        # removed, so the clamp is done explicitly, which works on every version.
        grid = torch.cat(images, 0).clamp(0, 1)
        out_path = f'{cnst.output_root}sample/{str(run_id)}/{str(i + 1).zfill(6)}_res{resolution}x{resolution}_fid_{fid:.2f}.png'
        torchvision.utils.save_image(grid, out_path, nrow=self.gen_i)