Example #1
    def _calc_detail_info(self, param, kp3d_24=False):
        cam = param[:, 0:3].contiguous()
        pose = param[:, 3:75].contiguous()
        shape = param[:, 75:].contiguous()
        verts, j3d, Rs = self.smpl(beta=shape, param=pose, get_skin=True)
        projected_j2d = util.batch_orth_proj(j3d.clone(), cam, mode='2d')
        j3d = util.batch_orth_proj(j3d.clone(), cam, mode='j3d')
        verts_camed = util.batch_orth_proj(verts, cam, mode='v3d')
        if kp3d_24:
            _, j3d, _ = self.smpl(beta=shape, param=pose, get_org_joints=True)
            j3d = util.batch_orth_proj(j3d.clone(), cam, mode='j3d')

        # NOTE: j3d occupies two slots in the return tuple; both reference the
        # same tensor (the 24-joint projection when kp3d_24 is True).
        return ((cam, pose, shape), verts, projected_j2d, j3d, Rs, verts_camed,
                j3d)
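None of the examples on this page include `util.batch_orth_proj` itself. A minimal sketch of the weak-perspective projection it appears to perform, assuming `cam` packs `[scale, tx, ty]`; the `mode` argument used in Example #1 is a project-specific variant and is not modeled here:

    import torch

    def batch_orth_proj_sketch(X, camera):
        # X: [bz, n_points, 3] joints/vertices; camera: [bz, 3] = [s, tx, ty].
        camera = camera.clone().view(-1, 1, 3)
        # Translate in the image plane; keep depth unchanged.
        X_trans = torch.cat([X[:, :, :2] + camera[:, :, 1:], X[:, :, 2:]], dim=2)
        # Scale uniformly by the weak-perspective scale s.
        return camera[:, :, 0:1] * X_trans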
Example #2
    def render_tex_and_normal(self, shapecode, expcode, posecode, texcode,
                              lightcode, cam):
        verts, _, _ = self.flame(shape_params=shapecode,
                                 expression_params=expcode,
                                 pose_params=posecode)
        trans_verts = util.batch_orth_proj(verts, cam)
        trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]

        albedos = self.flametex(texcode)
        rendering_results = self.render(verts,
                                        trans_verts,
                                        albedos,
                                        lights=lightcode)
        textured_images, normals = rendering_results[
            'images'], rendering_results['normals']
        normal_images = self.render.render_normal(trans_verts, normals)
        return textured_images, normal_images
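The `trans_verts[:, :, 1:] = -trans_verts[:, :, 1:]` line converts from the model's y-up convention to the renderer's y-down screen convention (z is flipped along with y). A quick illustrative sanity check:

    import torch

    verts = torch.tensor([[[0.0, 0.5, 0.2]]])  # one point above the origin
    flipped = verts.clone()
    flipped[:, :, 1:] = -flipped[:, :, 1:]
    print(flipped)  # tensor([[[ 0.0000, -0.5000, -0.2000]]]): y now points down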
Example #3
    def optimize(self, images, landmarks, image_masks, video_writer):
        bz = images.shape[0]
        shape = nn.Parameter(torch.zeros(bz, cfg.shape_params).float().to(self.device))
        tex = nn.Parameter(torch.zeros(bz, cfg.tex_params).float().to(self.device))
        exp = nn.Parameter(torch.zeros(bz, cfg.expression_params).float().to(self.device))
        pose = nn.Parameter(torch.zeros(bz, cfg.pose_params).float().to(self.device))
        cam = torch.zeros(bz, cfg.camera_params)
        cam[:, 0] = 5.
        cam = nn.Parameter(cam.float().to(self.device))
        lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device))
        e_opt = torch.optim.Adam(
            [shape, exp, pose, cam, tex, lights],
            lr=cfg.e_lr,
            weight_decay=cfg.e_wd
        )

        gt_landmark = landmarks

        # non-rigid fitting of all the parameters with 68 face landmarks and a photometric loss.
        all_train_iter = 0
        all_train_iters = []
        photometric_loss = []
        for k in range(cfg.max_iter):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = - trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = - landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = - landmarks3d[..., 1:]
            losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2], gt_landmark[:, :, :2])

            # render
            albedos = self.flametex(tex) / 255.
            ops = self.render(vertices, trans_vertices, albedos, lights)
            predicted_images = ops['images']
            # losses['photometric_texture'] = (image_masks * (ops['images'] - images).abs()).mean() * config.w_pho
            losses['photometric_texture'] = F.smooth_l1_loss(image_masks * ops['images'],
                                                             image_masks * images) * cfg.w_pho

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            e_opt.zero_grad()
            all_loss.backward()
            e_opt.step()

            loss_info = '----iter: {}, time: {}\n'.format(k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(key, float(losses[key]))

            if k % 10 == 0:
                all_train_iter += 10
                all_train_iters.append(all_train_iter)
                # store a plain float so the list does not retain autograd graphs
                photometric_loss.append(float(losses['photometric_texture']))
                print(loss_info)

                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind], landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind], landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind], landmarks3d[visind]))
                grids['albedoimage'] = torchvision.utils.make_grid(
                    (ops['albedo_images'])[visind].detach().cpu())
                grids['render'] = torchvision.utils.make_grid(predicted_images[visind].detach().float().cpu())
                shape_images = self.render.render_shape(vertices, trans_vertices, images)
                grids['shape'] = torchvision.utils.make_grid(
                    F.interpolate(shape_images[visind], [224, 224])).detach().float().cpu()

                # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu()
                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8)
                video_writer.write(grid_image)

        single_params = {
            'shape': shape.detach().cpu().numpy(),
            'exp': exp.detach().cpu().numpy(),
            'pose': pose.detach().cpu().numpy(),
            'cam': cam.detach().cpu().numpy(),
            'verts': trans_vertices.detach().cpu().numpy(),
            'albedos': albedos.detach().cpu().numpy(),
            'tex': tex.detach().cpu().numpy(),
            'lit': lights.detach().cpu().numpy()
        }
        util.draw_train_process("training", all_train_iters, photometric_loss, 'photometric loss')
        # np.save("./test_results/model.npy", single_params)
        return single_params
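`util.l2_distance` is also not shown on this page. A plausible definition consistent with how it is called above (mean Euclidean distance over landmarks, then over the batch) would be:

    import torch

    def l2_distance_sketch(pred, gt):
        # pred, gt: [bz, n_landmarks, 2]; mean per-landmark L2 distance.
        return torch.sqrt(((pred - gt) ** 2).sum(dim=2)).mean(dim=1).mean()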
Example #4
    def optimize(self, images, landmarks, image_masks, all_param, video_writer,
                 first_flag):
        shape_para, tex_para, exp_para, pose_para, cam_para, lights_para = all_param
        e_opt = torch.optim.Adam(
            [shape_para, exp_para, pose_para, cam_para, tex_para, lights_para],
            lr=cfg.e_lr,
            weight_decay=cfg.e_wd)
        d_opt = torch.optim.Adam([shape_para, exp_para, pose_para, cam_para],
                                 lr=cfg.e_lr,
                                 weight_decay=cfg.e_wd)

        gt_landmark = landmarks
        max_iter = 50
        if first_flag:
            max_iter = cfg.max_iter

        tmp_predict = torch.squeeze(images)
        for k in range(0, max_iter):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape_para,
                expression_params=exp_para,
                pose_params=pose_para)
            trans_vertices = util.batch_orth_proj(vertices, cam_para)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam_para)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam_para)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = util.l2_distance(landmarks2d[:, :, :2],
                                                  gt_landmark[:, :, :2])

            # render
            albedos = self.flametex(tex_para) / 255.
            ops = self.render(vertices, trans_vertices, albedos, lights_para)
            tmp_predict = torchvision.utils.make_grid(
                ops['images'][0].detach().float().cpu())
            # losses['photometric_texture'] = (image_masks * (ops['images'] - images).abs()).mean() * config.w_pho
            if first_flag:
                losses['photometric_texture'] = F.smooth_l1_loss(
                    image_masks * ops['images'],
                    image_masks * images) * cfg.w_pho

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            if first_flag:
                e_opt.zero_grad()
                all_loss.backward()
                e_opt.step()
            else:
                d_opt.zero_grad()
                all_loss.backward()
                d_opt.step()
            loss_info = '----iter: {}, time: {}\n'.format(
                k,
                datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(
                    key, float(losses[key]))
            print(loss_info)

        # tmp_predict = torchvision.utils.make_grid(ops['images'][0].detach().float().cpu())
        tmp_predict = (tmp_predict.numpy().transpose(1, 2, 0).copy() *
                       255)[:, :, [2, 1, 0]]
        tmp_predict = np.minimum(np.maximum(tmp_predict, 0),
                                 255).astype(np.uint8)

        tmp_image = torchvision.utils.make_grid(
            images[0].detach().float().cpu())
        tmp_image = (tmp_image.numpy().transpose(1, 2, 0).copy() *
                     255)[:, :, [2, 1, 0]]
        tmp_image = np.minimum(np.maximum(tmp_image, 0), 255).astype(np.uint8)
        combine = np.concatenate((tmp_predict, tmp_image), axis=1)
        cv2.imshow("tmp_image", combine)
        cv2.waitKey(1)
        video_writer.write(combine)
        return [
            shape_para, tex_para, exp_para, pose_para, cam_para, lights_para
        ]
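Example #4 is the tracking variant of Example #3: `first_flag` selects a full non-rigid fit (`e_opt`, `cfg.max_iter` steps, with the photometric loss) on the first frame and a short 50-iteration, landmark-only refinement (`d_opt`) on subsequent frames, reusing the previous frame's parameters as initialization. A hypothetical per-frame driver loop (`tracker`, `loader`, and `init_params` are illustrative names, not part of the example):

    params = init_params  # (shape, tex, exp, pose, cam, lights) nn.Parameters
    for i, (frame, lmks, mask) in enumerate(loader):
        params = tracker.optimize(frame, lmks, mask, params, video_writer,
                                  first_flag=(i == 0))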
Example #5
    def optimize(self, images, landmarks, image_masks, savefolder=None):
        bz = images.shape[0]
        shape = nn.Parameter(
            torch.zeros(bz, cfg.shape_params).float().to(self.device))
        tex = nn.Parameter(
            torch.zeros(bz, cfg.tex_params).float().to(self.device))
        exp = nn.Parameter(
            torch.zeros(bz, cfg.expression_params).float().to(self.device))
        pose = nn.Parameter(
            torch.zeros(bz, cfg.pose_params).float().to(self.device))
        cam = torch.zeros(bz, cfg.camera_params)
        cam[:, 0] = 5.
        cam = nn.Parameter(cam.float().to(self.device))
        lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device))
        e_opt = torch.optim.Adam([shape, exp, pose, cam, tex, lights],
                                 lr=cfg.e_lr,
                                 weight_decay=cfg.e_wd)
        e_opt_rigid = torch.optim.Adam([pose, cam],
                                       lr=cfg.e_lr,
                                       weight_decay=cfg.e_wd)

        gt_landmark = landmarks

        # rigid fitting of pose and camera with the 51 static face landmarks;
        # the contour (jawline) landmarks are excluded because they slide along
        # the face outline, so their trajectory is not differentiable.
        for k in range(200):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = util.l2_distance(
                landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2]) * cfg.w_lmks

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            e_opt_rigid.zero_grad()
            all_loss.backward()
            e_opt_rigid.step()

            loss_info = '----iter: {}, time: {}\n'.format(
                k,
                datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(
                    key, float(losses[key]))
            if k % 10 == 0:
                print(loss_info)

            if k % 10 == 0:
                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(
                    images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks3d[visind]))

                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                              255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0),
                                        255).astype(np.uint8)
                cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image)

        # non-rigid fitting of all the parameters with 68 face landmarks, photometric loss and regularization terms.
        for k in range(200, 1000):
            losses = {}
            vertices, landmarks2d, landmarks3d = self.flame(
                shape_params=shape, expression_params=exp, pose_params=pose)
            trans_vertices = util.batch_orth_proj(vertices, cam)
            trans_vertices[..., 1:] = -trans_vertices[..., 1:]
            landmarks2d = util.batch_orth_proj(landmarks2d, cam)
            landmarks2d[..., 1:] = -landmarks2d[..., 1:]
            landmarks3d = util.batch_orth_proj(landmarks3d, cam)
            landmarks3d[..., 1:] = -landmarks3d[..., 1:]

            losses['landmark'] = util.l2_distance(
                landmarks2d[:, :, :2], gt_landmark[:, :, :2]) * cfg.w_lmks
            losses['shape_reg'] = (torch.sum(shape**2) /
                                   2) * cfg.w_shape_reg  # *1e-4
            losses['expression_reg'] = (torch.sum(exp**2) /
                                        2) * cfg.w_expr_reg  # *1e-4
            losses['pose_reg'] = (torch.sum(pose**2) / 2) * cfg.w_pose_reg

            ## render
            albedos = self.flametex(tex) / 255.
            ops = self.render(vertices, trans_vertices, albedos, lights)
            predicted_images = ops['images']
            losses['photometric_texture'] = (
                image_masks *
                (ops['images'] - images).abs()).mean() * cfg.w_pho

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            e_opt.zero_grad()
            all_loss.backward()
            e_opt.step()

            loss_info = '----iter: {}, time: {}\n'.format(
                k,
                datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
            for key in losses.keys():
                loss_info = loss_info + '{}: {}, '.format(
                    key, float(losses[key]))

            if k % 10 == 0:
                print(loss_info)

            # visualize
            if k % 10 == 0:
                grids = {}
                visind = range(bz)  # [0]
                grids['images'] = torchvision.utils.make_grid(
                    images[visind]).detach().cpu()
                grids['landmarks_gt'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks[visind]))
                grids['landmarks2d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks2d[visind]))
                grids['landmarks3d'] = torchvision.utils.make_grid(
                    util.tensor_vis_landmarks(images[visind],
                                              landmarks3d[visind]))
                grids['albedoimage'] = torchvision.utils.make_grid(
                    (ops['albedo_images'])[visind].detach().cpu())
                grids['render'] = torchvision.utils.make_grid(
                    predicted_images[visind].detach().float().cpu())
                shape_images = self.render.render_shape(
                    vertices, trans_vertices, images)
                grids['shape'] = torchvision.utils.make_grid(
                    F.interpolate(shape_images[visind],
                                  [224, 224])).detach().float().cpu()

                # grids['tex'] = torchvision.utils.make_grid(F.interpolate(albedos[visind], [224, 224])).detach().cpu()
                grid = torch.cat(list(grids.values()), 1)
                grid_image = (grid.numpy().transpose(1, 2, 0).copy() *
                              255)[:, :, [2, 1, 0]]
                grid_image = np.minimum(np.maximum(grid_image, 0),
                                        255).astype(np.uint8)

                cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image)

        single_params = {
            'shape': shape.detach().cpu().numpy(),
            'exp': exp.detach().cpu().numpy(),
            'pose': pose.detach().cpu().numpy(),
            'cam': cam.detach().cpu().numpy(),
            'verts': trans_vertices.detach().cpu().numpy(),
            'albedos': albedos.detach().cpu().numpy(),
            'tex': tex.detach().cpu().numpy(),
            'lit': lights.detach().cpu().numpy()
        }
        return single_params
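The `shape_reg`, `expression_reg`, and `pose_reg` terms in the non-rigid stage are plain quadratic (isotropic Gaussian) priors on the FLAME codes. An equivalent helper, shown only for illustration:

    def quadratic_prior(code, weight):
        # 0.5 * weight * ||code||^2, matching e.g.
        # (torch.sum(shape ** 2) / 2) * cfg.w_shape_reg
        return 0.5 * weight * (code ** 2).sum()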
Example #6
    def decode(self, codedict, epoch):
        images = codedict['images']
        batch_size = images.shape[0]
        
        ## decode
        verts, landmarks2d, landmarks3d = self.flame(shape_params=codedict['shape'], \
            expression_params=codedict['exp'], pose_params=codedict['pose'])
        if self.config.model.use_tex:
            albedo = self.flametex(codedict['tex'])
        else:
            albedo = torch.zeros([batch_size, 3, self.uv_size, self.uv_size], device=images.device) 

        ## projection
        landmarks2d = util.batch_orth_proj(landmarks2d, codedict['cam'])[:,:,:2]
        landmarks2d[:,:,1:] = -landmarks2d[:,:,1:]
        landmarks2d = landmarks2d*self.image_size/2 + self.image_size/2
        landmarks2d /= (self.image_size - 1)
        landmarks3d = util.batch_orth_proj(landmarks3d, codedict['cam'])
        landmarks3d[:,:,1:] = -landmarks3d[:,:,1:]
        landmarks3d = landmarks3d*self.image_size/2 + self.image_size/2
        landmarks3d /= (self.image_size - 1)
        trans_verts = util.batch_orth_proj(verts, codedict['cam'])
        trans_verts[:,:,1:] = -trans_verts[:,:,1:]
        # trans_verts = trans_verts*self.image_size/2 + self.image_size/2
        # trans_verts /= (self.image_size - 1)
        normals = util.vertex_normals(verts, self.faces.expand(batch_size, -1, -1).to(self.device))

        output = {'albedo': albedo, 'verts': verts, 'trans_verts': trans_verts, \
                    'landmarks2d': landmarks2d, 'landmarks3d': landmarks3d, 'normals': normals}

        # shape consistency
        if 'coarse' in self.mode and epoch>self.epoch_phase:
            verts, landmarks2d, landmarks3d = self.flame(shape_params=codedict['shape_shuffle'], \
                expression_params=codedict['exp'], pose_params=codedict['pose'])

            ## projection
            landmarks2d = util.batch_orth_proj(landmarks2d, codedict['cam'])[:,:,:2]
            landmarks2d[:,:,1:] = -landmarks2d[:,:,1:]
            landmarks2d = landmarks2d*self.image_size/2 + self.image_size/2
            landmarks2d /= (self.image_size - 1)
            landmarks3d = util.batch_orth_proj(landmarks3d, codedict['cam'])
            landmarks3d[:,:,1:] = -landmarks3d[:,:,1:]
            landmarks3d = landmarks3d*self.image_size/2 + self.image_size/2
            landmarks3d /= (self.image_size - 1)
            trans_verts = util.batch_orth_proj(verts, codedict['cam'])
            trans_verts[:,:,1:] = -trans_verts[:,:,1:]
            # trans_verts = trans_verts*self.image_size/2 + self.image_size/2
            # trans_verts /= (self.image_size - 1)
            # normals = util.vertex_normals(verts, self.faces.expand(batch_size, -1, -1))

            output['landmarks2d_shuffle'] = landmarks2d
            output['landmarks3d_shuffle'] = landmarks3d
            output['verts_shuffle'] = verts
            output['trans_verts_shuffle'] = trans_verts
            # output['normals_shuffle'] = normals

        if self.mode == 'train_detail':
            uv_z = self.D_detail(torch.cat([codedict['pose'][:,3:], codedict['exp'], \
                codedict['detail']], dim=1))
            output['displacement_map'] = uv_z+self.fixed_uv_dis[None,None,:,:]
            dense_vertices, dense_faces = displacement2vertex(uv_z, verts, normals, self.unsupervised_losses_conductor.render)
            uv_detail_normals = displacement2normal(uv_z, verts, normals, self.unsupervised_losses_conductor.render)
            dense_trans_verts = util.batch_orth_proj(dense_vertices, codedict['cam'])
            dense_trans_verts[:,:,1:] = -dense_trans_verts[:,:,1:]

            output['detail_verts'] = dense_vertices
            output['detail_trans_verts'] = dense_trans_verts
            output['detail_faces'] = dense_faces
            output['uv_detail_normals'] = uv_detail_normals

        return output
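The landmark post-processing in `decode` maps NDC coordinates in [-1, 1] to pixel coordinates and then divides by `image_size - 1`, giving values in roughly [0, 1] (the right edge slightly exceeds 1). A quick numeric check, assuming `image_size = 224`:

    import torch

    image_size = 224
    ndc = torch.tensor([-1.0, 0.0, 1.0])
    pix = ndc * image_size / 2 + image_size / 2  # -> [0., 112., 224.]
    norm = pix / (image_size - 1)                # -> [0.0000, 0.5022, 1.0045]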