Example #1
    def _calc_detail_info(self, theta):
        # theta packs [cam (3) | pose (72) | shape] along dim 1.
        cam = theta[:, 0:3].contiguous()
        pose = theta[:, 3:75].contiguous()
        shape = theta[:, 75:].contiguous()
        verts, j3d, Rs = self.smpl(beta=shape, theta=pose, get_skin=True)
        j2d = util.batch_orth_proj(j3d, cam)

        return (theta, verts, j2d, j3d, Rs)
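
Every example here funnels through util.batch_orth_proj. For reference, a minimal sketch of a weak-perspective (scaled orthographic) projection consistent with these call sites, assuming cam packs [scale, tx, ty]; the repo's actual util implementation may differ in details:

import torch

def batch_orth_proj(X, camera):
    # X: [bz, n, 3] points; camera: [bz, 3], assumed layout [scale, tx, ty].
    # Translate x and y, keep z for depth tests, then apply the isotropic scale.
    camera = camera.clone().view(-1, 1, 3)
    X_trans = X[:, :, :2] + camera[:, :, 1:]
    X_trans = torch.cat([X_trans, X[:, :, 2:]], 2)
    return camera[:, :, 0:1] * X_trans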
Example #2
    def forward(self, cam, vertices, images, cam_new):
        # Project the full vertex set with the source camera; the projected
        # xy positions will serve as sampling UVs for the final warp.
        full_vertices, N_bd = get_full_verts(vertices)
        t_vertices = util.batch_orth_proj(full_vertices, cam)
        # Flip y (and z) into the image/rasterizer convention.
        t_vertices[..., 1:] = -t_vertices[..., 1:]
        # The rasterizer clips to roughly near=0, far=100, so shift the mesh
        # to positive z while meshing, then mostly undo the shift.
        t_vertices[..., 2] = t_vertices[..., 2] + 10
        t_vertices = image_meshing(t_vertices, N_bd)
        t_vertices[..., :2] = torch.clamp(t_vertices[..., :2], -1, 1)
        t_vertices[..., 2] = t_vertices[..., 2] - 9
        batch_size = vertices.shape[0]
        # Attributes: projected xy as UVs; the third channel is a constant 1
        # so the rasterizer can interpolate them per face.
        uvcoords = t_vertices.clone()
        uvcoords = torch.cat(
            [uvcoords[:, :, :2], uvcoords[:, :, 0:1] * 0. + 1.],
            -1)  # [bz, ntv, 3]
        face_vertices = util.face_vertices(
            uvcoords, self.faces.expand(batch_size, -1, -1))
        # render: rasterize with the new camera, carrying the source-camera
        # UVs along as per-face attributes.
        attributes = face_vertices.detach()
        full_vertices, N_bd = get_full_verts(vertices)
        transformed_vertices = util.batch_orth_proj(full_vertices, cam_new)
        transformed_vertices[..., 1:] = -transformed_vertices[..., 1:]
        transformed_vertices[..., 2] = transformed_vertices[..., 2] + 10
        transformed_vertices = image_meshing(transformed_vertices, N_bd)
        transformed_vertices[..., :2] = torch.clamp(
            transformed_vertices[..., :2], -1, 1)
        rendering = self.rasterizer(transformed_vertices,
                                    self.faces.expand(batch_size, -1, -1),
                                    attributes)

        alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()

        # albedo
        uvcoords_images = rendering[:, :3, :, :]
        grid = (uvcoords_images).permute(0, 2, 3, 1)[:, :, :, :2]

        results = F.grid_sample(images, grid, align_corners=False)
        return {'rotate_images': results}
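
The final warp relies on torch.nn.functional.grid_sample, which samples images at normalized coordinates in [-1, 1] (x, y order, y pointing down). A self-contained sanity check of that convention, independent of the renderer: with an identity grid the output reproduces the input. (align_corners=True makes the identity exact at pixel centers; the snippet above uses align_corners=False, which shifts sampling by half a pixel at the borders.)

import torch
import torch.nn.functional as F

images = torch.rand(2, 3, 8, 8)
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 8),
                        torch.linspace(-1, 1, 8), indexing='ij')
grid = torch.stack([xs, ys], dim=-1).expand(2, -1, -1, -1)  # [bz, H, W, 2]
out = F.grid_sample(images, grid, align_corners=True)
assert torch.allclose(out, images, atol=1e-5)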
Example #3
    def _calc_detail_info(self, theta):
        # theta packs [cam (3) | pose (72 = 24 joints x 3) | shape] along dim 1.
        cam = theta[:, 0:3].contiguous()
        pose = theta[:, 3:75].contiguous()
        shape = theta[:, 75:].contiguous()

        _, verts, j3d, Rs = self.smpl(betas=shape,
                                      body_pose=pose,
                                      get_skin=True)
        # Example shapes: verts [8, 6890, 3], j3d [8, 45, 3], Rs [8, 72].

        j2d = util.batch_orth_proj(j3d, cam)

        return (theta, verts, j2d, j3d, Rs)
Example #4
    def render_tex_and_normal(self, shapecode, expcode, posecode, texcode,
                              lightcode, cam):
        verts, _, _ = self.flame(shape_params=shapecode,
                                 expression_params=expcode,
                                 pose_params=posecode)
        trans_verts = util.batch_orth_proj(verts, cam)
        trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]

        albedos = self.flametex(texcode)
        rendering_results = self.render(verts,
                                        trans_verts,
                                        albedos,
                                        lights=lightcode)
        textured_images, normals = rendering_results[
            'images'], rendering_results['normals']
        normal_images = self.render.render_normal(trans_verts, normals)
        return textured_images, normal_images
Example #5
    def render_tex_and_normal(self, shapecode, expcode, posecode, texcode, lightcode, cam, constant_albedo=None):
        verts, _, _ = self.flame(shape_params=shapecode, expression_params=expcode, pose_params=posecode)
        trans_verts = util.batch_orth_proj(verts, cam)
        trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]

        if constant_albedo is None:
            albedos = self.flametex(texcode)
        else:
            # Broadcast a single gray value into a [bz, 3, 256, 256] albedo map.
            albedos = torch.tensor(
                [constant_albedo] * 3,
                dtype=torch.float32)[None, :, None, None].cuda()
            albedos = albedos.repeat(texcode.shape[0], 1, 256, 256)

        rendering_results = self.render(verts, trans_verts, albedos, lights=lightcode)
        textured_images, normals = rendering_results['images'], rendering_results['normals']
        normal_images = self.render.render_normal(trans_verts, normals)
        return textured_images, normal_images
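
A quick standalone check of the constant-albedo broadcasting (pure torch, no FLAME dependency), confirming the [3] -> [1, 3, 1, 1] -> [bz, 3, 256, 256] shape path assumed above:

import torch

bz, gray = 4, 0.7
albedos = torch.tensor([gray] * 3, dtype=torch.float32)[None, :, None, None]
albedos = albedos.repeat(bz, 1, 256, 256)  # one value per RGB channel, tiled over UV space
assert albedos.shape == (bz, 3, 256, 256)
assert torch.all(albedos == gray)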
Example #6
def train_parallel(rank, world_size, args, loader, dfr, flame, flametex,
                   optim):
    config = get_config()
    loader = sample_data(loader)  # endless batch generator over the dataset
    pbar = range(args.iter)

    print("Rank:", get_rank(), ", rank =", rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    print(f"{rank + 1}/{world_size} process initialized.")

    # Wrap each model for distributed data parallelism on this rank's GPU.
    kwargs_ddp = {'device_ids': [rank]}
    dfr = DDP(dfr, **kwargs_ddp)
    flame = DDP(flame, **kwargs_ddp)
    flametex = DDP(flametex, **kwargs_ddp)

    if rank == 0:
        pbar = tqdm(pbar, initial=0, dynamic_ncols=True, smoothing=0.01)


    bz = args.batch_size
    # Per-batch texture, camera, and lighting parameters trained alongside dfr.
    tex = nn.Parameter(torch.zeros(bz, config.tex_params).float().cuda(rank))
    cam = torch.zeros(bz, config.camera_params)
    cam[:, 0] = 5.0  # initial weak-perspective scale
    cam = nn.Parameter(cam.float().cuda(rank))
    lights = nn.Parameter(torch.zeros(bz, 9, 3).float().cuda(rank))

    for idx in pbar:
        # sample_data wraps the loader in an endless generator, so draw one
        # batch per step instead of iterating the loader itself.
        example = next(loader)
        latents = example['latents'].cuda(rank)
        landmarks_gt = example['landmarks_gt'].cuda(rank)
        images = example['images'].cuda(rank)
        image_masks = example['image_masks'].cuda(rank)

        shape, expression, pose = dfr(latents.view(args.batch_size, -1))
        vertices, landmarks2d, landmarks3d = flame(
            shape_params=shape,
            expression_params=expression,
            pose_params=pose)

        # Weak-perspective projection; flip y (and z) into image convention.
        trans_vertices = util.batch_orth_proj(vertices, cam)
        trans_vertices[..., 1:] = -trans_vertices[..., 1:]
        landmarks2d = util.batch_orth_proj(landmarks2d, cam)
        landmarks2d[..., 1:] = -landmarks2d[..., 1:]
        landmarks3d = util.batch_orth_proj(landmarks3d, cam)
        landmarks3d[..., 1:] = -landmarks3d[..., 1:]

        losses = {}
        losses['landmark'] = util.l2_distance(
            landmarks2d[:, :, :2],
            landmarks_gt[:, :, :2]) * 1  # config.w_lmks

        all_loss = 0.
        for key in losses.keys():
            all_loss = all_loss + losses[key]
        losses['all_loss'] = all_loss

        optim.zero_grad()
        all_loss.backward()
        optim.step()

        if get_rank() == 0:
            pbar.set_description(
                f"total: {losses['all_loss']:.4f}; landmark: {losses['landmark']:.4f};"
            )
Example #7
def train(args, config, loader, dfr, flame, flametex, render, tex_mean,
          device):
    from datetime import datetime
    now = datetime.now()
    dt_string = now.strftime("%d-%m-%Y_%H.%M.%S")  # dd-mm-YYYY_H.M.S
    savefolder = os.path.sep.join(['./test_results', f'{dt_string}'])
    os.makedirs(savefolder, exist_ok=True)

    # Only the dfr parameters are optimized here; SGD at lr=0.01 produced NaNs,
    # so Adam with a small weight decay is used instead.
    optim = torch.optim.Adam(
        dfr.parameters(),
        lr=1e-4,
        weight_decay=0.00001  # config.e_wd
    )

    # The loader is iterated directly below (one full pass per outer step);
    # wrapping it with sample_data would make the inner loops endless.

    losses_to_plot = {
        'all_loss': [],
        'landmark_2d': [],
        'landmark_3d': [],
        'shape_reg': [],
        'expression_reg': [],
        'pose_reg': [],
        'photometric_texture': [],
        'texture_reg': [],
    }

    loss_mse = nn.MSELoss()

    idx_rigid_stop = args.iter_rigid
    modulo_save_imgs = args.iter_save_img
    modulo_save_model = args.iter_save_chkpt

    pbar = tqdm(range(0, idx_rigid_stop), dynamic_ncols=True, smoothing=0.01)
    k = 0  # keep k defined for the final checkpoint when idx_rigid_stop == 0
    for k in pbar:
        for example in loader:
            latents = example['latents'].to(device)
            landmarks_2d_gt = example['landmarks_2d_gt'].to(device)
            images = example['images'].to(device)
            image_masks = example['image_masks'].to(device)

            shape, expression, pose, tex, cam, lights = dfr(
                latents.view(args.batch_size, -1))
            vertices, landmarks2d, landmarks3d = flame(
                shape_params=shape,
                expression_params=expression,
                pose_params=pose)

            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses = {}
            losses['landmark_2d'] = util.l2_distance(
                landmarks2d[:, 17:, :2],
                landmarks_2d_gt[:, 17:, :2]) * config.w_lmks


            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
                losses_to_plot[key].append(
                    losses[key].item())  # Store for plotting later.

            losses['all_loss'] = all_loss
            losses_to_plot['all_loss'].append(losses['all_loss'].item())

            optim.zero_grad()
            all_loss.backward()
            optim.step()

            pbar.set_description((
                f"total: {losses['all_loss']:.4f}; landmark_2d: {losses['landmark_2d']:.4f}; "
            ))

            if (k % modulo_save_imgs == 0):
                try:
                    grids = {}
                    grids['images'] = torchvision.utils.make_grid(
                        images.detach().cpu())
                    grids['landmarks_2d_gt'] = torchvision.utils.make_grid(
                        util.tensor_vis_landmarks(images, landmarks_2d_gt))
                    grids['landmarks2d'] = torchvision.utils.make_grid(
                        util.tensor_vis_landmarks(images, landmarks2d))
                    grids['landmarks3d'] = torchvision.utils.make_grid(
                        util.tensor_vis_landmarks(images, landmarks3d))

                    # Stack the grids vertically; convert to BGR uint8 for OpenCV.
                    grid = torch.cat(list(grids.values()), 1)
                    grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                                  255)[:, :, [2, 1, 0]]
                    grid_image = np.clip(grid_image, 0, 255).astype(np.uint8)
                    cv2.imwrite(
                        '{}/{}.jpg'.format(savefolder,
                                           str(k).zfill(6)), grid_image)
                except Exception:
                    print("Error saving images... continuing")
                    continue

            if k % modulo_save_model == 0:
                save_checkpoint(path=savefolder,
                                epoch=k + 1,
                                losses=losses_to_plot,
                                model=dfr)

    # Save final epoch for rigid fitting.
    #
    if k > 0:
        save_checkpoint(path=savefolder,
                        epoch=k + 1,
                        losses=losses_to_plot,
                        model=dfr)

    # Second stage training. Adding in photometric loss.
    #
    pbar = tqdm(range(idx_rigid_stop, args.iter),
                dynamic_ncols=True,
                smoothing=0.01)
    for k in pbar:
        for example in loader:
            latents = example['latents'].to(device)
            landmarks_2d_gt = example['landmarks_2d_gt'].to(device)
            landmarks_3d_gt = example['landmarks_3d_gt'].to(device)
            images = example['images'].to(device)
            image_masks = example['image_masks'].to(device)

            shape, expression, pose, tex, cam, lights = dfr(
                latents.view(args.batch_size, -1))
            vertices, landmarks2d, landmarks3d = flame(
                shape_params=shape,
                expression_params=expression,
                pose_params=pose)

            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses = {}

            losses['landmark_2d'] = util.l2_distance(
                landmarks2d[:, :, :2], landmarks_2d_gt[:, :, :2]) * 2.0

            losses['landmark_3d'] = util.l2_distance(
                landmarks3d[:, :, :2], landmarks_3d_gt[:, :, :2]) * 1.0
            losses['shape_reg'] = (torch.sum(shape**2) /
                                   2) * config.w_shape_reg  # *1e-4
            losses['expression_reg'] = (torch.sum(expression**2) /
                                        2) * config.w_expr_reg  # *1e-4
            losses['pose_reg'] = (torch.sum(pose**2) / 2) * config.w_pose_reg

            ## render
            albedos = flametex(tex) / 255.
            # Regularize the learned texture toward the mean texture.
            losses['texture_reg'] = loss_mse(
                albedos,
                tex_mean.repeat(args.batch_size, 1, 1, 1))  # optionally * 1e-3
            ops = render(vertices, trans_vertices, albedos, lights)
            predicted_images = ops['images']
            losses['photometric_texture'] = (image_masks * (predicted_images - images).abs()).mean() \
                                            * config.w_pho

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
                losses_to_plot[key].append(
                    losses[key].item())  # Store for plotting later.

            losses['all_loss'] = all_loss
            losses_to_plot['all_loss'].append(losses['all_loss'].item())

            optim.zero_grad()
            all_loss.backward()
            optim.step()

            pbar.set_description((
                f"total: {losses['all_loss']:.4f}; landmark_2d: {losses['landmark_2d']:.4f}; "
                f"landmark_3d: {losses['landmark_3d']:.4f}; "
                f"shape: {losses['shape_reg']:.4f}; express: {losses['expression_reg']:.4f}; "
                f"photo: {losses['photometric_texture']:.4f}; "))

            # visualize
            if k % modulo_save_imgs == 0:
                shape_images = render.render_shape(vertices, trans_vertices,
                                                   images)
                save_rendered_imgs(savefolder, k, images, predicted_images,
                                   shape_images, albedos, ops, landmarks_2d_gt,
                                   landmarks2d, landmarks3d)

            if k % modulo_save_model == 0:
                save_checkpoint(path=savefolder,
                                epoch=k + 1,
                                losses=losses_to_plot,
                                model=dfr)

    # Save final epoch renderings and checkpoints.
    #
    shape_images = render.render_shape(vertices, trans_vertices, images)
    save_rendered_imgs(savefolder, k + 1, images, predicted_images,
                       shape_images, albedos, ops, landmarks_2d_gt,
                       landmarks2d, landmarks3d)

    save_checkpoint(path=savefolder,
                    epoch=k + 1,
                    losses=losses_to_plot,
                    model=dfr)

    print("cam: ", cam)
    print("landmarks3d.mean: ", landmarks3d.mean())
    print("landmarks3d.min: ", landmarks3d.min())
    print("landmarks3d.max: ", landmarks3d.max())
Example #8
    def optimize(self, images, landmarks, image_masks, savefolder=None):
        bz = images.shape[0]
        pose = nn.Parameter(
            torch.zeros(bz, self.config.pose_params).float().to(self.device))
        exp = nn.Parameter(
            torch.zeros(bz,
                        self.config.expression_params).float().to(self.device))
        shape = nn.Parameter(
            torch.zeros(bz, self.config.shape_params).float().to(self.device))
        tex = nn.Parameter(
            torch.zeros(bz, self.config.tex_params).float().to(self.device))
        cam = torch.zeros(bz, self.config.camera_params)
        cam[:, 0] = 5.  # initial weak-perspective scale
        cam = nn.Parameter(cam.float().to(self.device))
        lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device))

        e_opt = torch.optim.Adam([shape, exp, pose, cam, tex, lights],
                                 lr=self.config.e_lr,
                                 weight_decay=self.config.e_wd)

        e_opt_rigid = torch.optim.Adam([pose, cam],
                                       lr=self.config.e_lr,
                                       weight_decay=self.config.e_wd)

        gt_landmark = landmarks

        # Rigid fitting of pose and camera with the 51 static face landmarks;
        # the 17 contour landmarks are excluded because their trajectory along
        # the face outline is not differentiable.
        for k in range(200):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = util.l2_distance(
                landmarks2d[:, 17:, :2],
                gt_landmark[:, 17:, :2]) * self.config.w_lmks

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss

            e_opt_rigid.zero_grad()
            all_loss.backward()
            e_opt_rigid.step()

            if self.config.verbose:
                loss_info = '----iter: {}, time: {}\n'.format(
                    k,
                    datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
                for key in losses.keys():
                    loss_info = loss_info + '{}: {}, '.format(
                        key, float(losses[key]))

                if k % self.config.print_freq == 0:
                    print(loss_info)

            if self.config.save_all_output and k % self.config.print_freq == 0:
                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(
                    images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks3d[visind]))

                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                              255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0),
                                        255).astype(np.uint8)

                savefolder_iter = os.path.join(
                    savefolder, 'inter_steps'
                )  # use a separate folder to save intermediate results
                util.check_mkdir(savefolder_iter)
                cv2.imwrite('{}/{}.jpg'.format(savefolder_iter, k), grid_image)

        # non-rigid fitting of all the parameters with 68 face landmarks, photometric loss and regularization terms.
        for k in range(200, 1500):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = util.l2_distance(
                landmarks2d[:, :, :2],
                gt_landmark[:, :, :2]) * self.config.w_lmks
            losses['shape_reg'] = (torch.sum(shape**2) /
                                   2) * self.config.w_shape_reg  # *1e-4
            losses['expression_reg'] = (torch.sum(exp**2) /
                                        2) * self.config.w_expr_reg  # *1e-4
            losses['pose_reg'] = (torch.sum(pose**2) /
                                  2) * self.config.w_pose_reg

            ## render
            albedos = self.flametex(tex) / 255.
            ops = self.render(vertices, trans_vertices, albedos, lights)
            predicted_images = ops['images']
            losses['photometric_texture'] = (
                image_masks *
                (predicted_images - images).abs()).mean() * self.config.w_pho

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss

            e_opt.zero_grad()
            all_loss.backward()
            e_opt.step()

            if self.config.verbose:
                loss_info = '----iter: {}, time: {}\n'.format(
                    k,
                    datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
                for key in losses.keys():
                    loss_info = loss_info + '{}: {}, '.format(
                        key, float(losses[key]))

                if k % self.config.print_freq == 0:
                    print(loss_info)

            # visualize
            if self.config.save_all_output and k % self.config.print_freq == 0:
                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(
                    images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks3d[visind]))
                grids['albedoimage'] = torchvision.utils.make_grid(
                    (ops['albedo_images'])[visind].detach().cpu())
                grids['render'] = torchvision.utils.make_grid(
                    predicted_images[visind].detach().float().cpu())
                shape_images = self.render.render_shape(
                    vertices, trans_vertices, images)
                grids['shape'] = torchvision.utils.make_grid(
                    F.interpolate(shape_images[visind],
                                  [224, 224])).detach().float().cpu()

                # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu()
                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                              255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0),
                                        255).astype(np.uint8)

                savefolder_iter = os.path.join(
                    savefolder, 'inter_steps'
                )  # use a separate folder to save intermediate results
                util.check_mkdir(savefolder_iter)
                cv2.imwrite('{}/{}.jpg'.format(savefolder_iter, k), grid_image)

        single_params = {
            'shape': shape.detach().cpu().numpy(),
            'exp': exp.detach().cpu().numpy(),
            'pose': pose.detach().cpu().numpy(),
            'cam': cam.detach().cpu().numpy(),
            'verts': trans_vertices.detach().cpu().numpy(),
            'albedos': albedos.detach().cpu().numpy(),
            'tex': tex.detach().cpu().numpy(),
            'lit': lights.detach().cpu().numpy()
        }
        return single_params
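
A hedged usage sketch for optimize; the PhotometricFitting wrapper name and tensor shapes here are assumptions based on how the method reads its inputs, not part of the example. The closure1/closure2 definitions that follow belong to a closure-style variant of the same fitting loop.

import torch

fitting = PhotometricFitting(config, device='cuda')  # assumed wrapper class
images = torch.rand(1, 3, 224, 224).cuda()           # [bz, 3, H, W] in [0, 1]
landmarks = torch.rand(1, 68, 2).cuda() * 2 - 1      # [bz, 68, 2], normalized
image_masks = torch.ones(1, 1, 224, 224).cuda()      # [bz, 1, H, W]
params = fitting.optimize(images, landmarks, image_masks, savefolder='./fit_out')
print(params['shape'].shape, params['cam'])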
        def closure2():
            nonlocal pose, cam, exp, shape, tex, lights
            nonlocal albedos, step, vertices, trans_vertices
            nonlocal prev_exp, prev_shape, prev_cam, prev_tex, prev_lits, prev_pose
            if torch.is_grad_enabled():
                e_opt.zero_grad()
            losses = {}
            p, e, s, c = self.config.prep_param(pose, exp, shape, cam)
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=s,
                expression_params=e,
                pose_params=p,
                cam_params=c)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = weighted_l2_distance(
                landmarks2d[:, :, :2], gt_landmark[:, :, :2],
                self.weights68) * self.config.w_lmks
            losses['shape_reg'] = (torch.sum(shape**2) /
                                   2) * self.config.w_shape_reg  # *1e-4
            losses['expression_reg'] = (torch.sum(exp**2) /
                                        2) * self.config.w_expr_reg  # *1e-4
            losses['pose_reg'] = (torch.sum(pose**2) /
                                  2) * self.config.w_pose_reg

            ## render
            albedos = self.flametex(tex) / 255.
            ops = self.render(vertices, trans_vertices, albedos, lights)
            predicted_images = ops['images']
            losses['photometric_texture'] = (
                image_masks *
                (ops['images'] - images).abs()).mean() * self.config.w_pho

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            if math.isnan(all_loss.item()):
                # Restore the last finite state and skip this step.
                exp = prev_exp
                shape = prev_shape
                pose = prev_pose
                cam = prev_cam
                lights = prev_lits
                tex = prev_tex
                print("rebooting...")
                return all_loss
            prev_exp = exp.clone()
            prev_shape = shape.clone()
            prev_pose = pose.clone()
            prev_cam = cam.clone()
            prev_tex = tex.clone()
            prev_lits = lights.clone()

            if all_loss.requires_grad:
                all_loss.backward()
            step += 1
            self.config.post_param(pose, exp, shape, cam)

            loss_info = '----iter: {}, time: {}\n'.format(
                step,
                datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(
                    key, float(losses[key]))

            if step % 10 == 0:
                print(loss_info)

            # visualize
            if step % 10 == 0:
                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(
                    images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks3d[visind]))
                grids['albedoimage'] = torchvision.utils.make_grid(
                    (ops['albedo_images'])[visind].detach().cpu())
                grids['render'] = torchvision.utils.make_grid(
                    predicted_images[visind].detach().float().cpu())
                shape_images = self.render.render_shape(
                    vertices, trans_vertices, images)
                grids['shape'] = torchvision.utils.make_grid(
                    F.interpolate(shape_images[visind],
                                  [224, 224])).detach().float().cpu()

                # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu()
                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                              255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0),
                                        255).astype(np.uint8)

                cv2.imwrite('{}/{}.jpg'.format(savefolder, step), grid_image)
            return all_loss
        def closure1():
            nonlocal step
            losses = {}
            if torch.is_grad_enabled():
                e_opt_rigid.zero_grad()
            p, e, s, c = self.config.prep_param(pose, exp, shape, cam)
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=s,
                expression_params=e,
                pose_params=p,
                cam_params=c)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = weighted_l2_distance(
                landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2],
                self.weights51) * self.config.w_lmks

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            if all_loss.requires_grad:
                all_loss.backward()
            step += 1
            self.config.post_param(pose, exp, shape, cam)

            loss_info = '----iter: {}, time: {}\n'.format(
                step,
                datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(
                    key, float(losses[key]))
            if step % 10 == 0:
                print(loss_info)

            if step % 10 == 0:
                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(
                    images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks3d[visind]))

                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                              255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0),
                                        255).astype(np.uint8)
                cv2.imwrite('{}/{}.jpg'.format(savefolder, step), grid_image)

            return all_loss
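
closure1 and closure2 follow the closure protocol that torch optimizers accept in step(): the callable re-evaluates the losses, calls backward(), and returns the loss, which is required for LBFGS and harmless for Adam. A sketch of the driving loop they imply, assuming the same e_opt_rigid/e_opt optimizers as above and that step and the prev_* snapshots are initialized beforehand; the stage lengths mirror the non-closure version of optimize:

step = 0  # advanced via `nonlocal step` inside the closures

for _ in range(200):        # rigid stage: pose and camera only
    e_opt_rigid.step(closure1)

for _ in range(200, 1500):  # non-rigid stage: all parameters + photometric loss
    e_opt.step(closure2)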