def _calc_detail_info(self, param, kp3d_24=False):
    """Split a packed parameter tensor and run the body model on it.

    ``param`` is sliced along its second axis into three groups: columns
    0-2 (camera), columns 3-74 (pose) and columns 75 onward (shape).  The
    groups are passed to ``self.smpl`` and the resulting vertices / joints
    are projected with ``util.batch_orth_proj``.

    Args:
        param: tensor indexed as ``param[:, a:b]``; assumed to be
            ``(batch, 75 + n_shape)`` -- TODO confirm against the caller.
        kp3d_24: when true, the 3-D joints are recomputed by calling
            ``self.smpl`` with ``get_org_joints=True`` and projected with
            ``mode='3d'`` instead of ``mode='j3d'``.  (Presumably this
            selects the model's original joint set -- confirm in the SMPL
            wrapper.)

    Returns:
        A 7-tuple
        ``((cam, pose, shape), verts, projected_j2d, j3d, Rs, verts_camed, j3d)``.
        The projected 3-D joints appear twice (index 3 and index 6); both
        slots are kept so existing callers that unpack either one keep
        working.
    """
    cam = param[:, 0:3].contiguous()
    pose = param[:, 3:75].contiguous()
    shape = param[:, 75:].contiguous()

    verts, j3d, Rs = self.smpl(beta=shape, param=pose, get_skin=True)

    # The joints are cloned before each projection so the two projections
    # never share a tensor that one of them might modify.
    projected_j2d = util.batch_orth_proj(j3d.clone(), cam, mode='2d')
    j3d = util.batch_orth_proj(j3d.clone(), cam, mode='j3d')
    verts_camed = util.batch_orth_proj(verts, cam, mode='v3d')

    if kp3d_24:
        _, j3d, _ = self.smpl(beta=shape, param=pose, get_org_joints=True)
        # Use the same qualified helper as the rest of this function; the
        # previous bare ``batch_orth_proj`` name is not bound by anything
        # visible here and would raise NameError when this branch runs.
        j3d = util.batch_orth_proj(j3d.clone(), cam, mode='3d')

    return ((cam, pose, shape), verts, projected_j2d, j3d, Rs, verts_camed, j3d)
def render_tex_and_normal(self, shapecode, expcode, posecode, texcode, lightcode, cam):
    """Render a textured image and a normal image for the given codes.

    The mesh comes from ``self.flame``, is projected with
    ``util.batch_orth_proj`` and has every coordinate after the first
    negated before it is handed to ``self.render``.

    Args:
        shapecode: forwarded to ``self.flame`` as ``shape_params``.
        expcode: forwarded to ``self.flame`` as ``expression_params``.
        posecode: forwarded to ``self.flame`` as ``pose_params``.
        texcode: forwarded to ``self.flametex`` to obtain the albedo.
        lightcode: forwarded to ``self.render`` as ``lights``.
        cam: camera argument for ``util.batch_orth_proj``.

    Returns:
        ``(textured_images, normal_images)`` where the first item is the
        ``'images'`` entry of the renderer output and the second is the
        result of ``self.render.render_normal`` on the projected vertices
        and the renderer's ``'normals'`` entry.
    """
    mesh_verts, _, _ = self.flame(
        shape_params=shapecode,
        expression_params=expcode,
        pose_params=posecode,
    )

    # Project, then negate every coordinate after the first one in place.
    projected = util.batch_orth_proj(mesh_verts, cam)
    projected[:, :, 1:] = -projected[:, :, 1:]

    albedo_maps = self.flametex(texcode)
    rendered = self.render(mesh_verts, projected, albedo_maps, lights=lightcode)

    textured_images = rendered['images']
    surface_normals = rendered['normals']
    normal_images = self.render.render_normal(projected, surface_normals)
    return textured_images, normal_images
def optimize(self, images, landmarks, image_masks, video_writer):
    """Fit shape, expression, pose, camera, texture and lighting to images.

    All parameters start from zero (the first camera component starts at
    5.0) and are optimised jointly with Adam for ``cfg.max_iter``
    iterations against a landmark loss plus a masked smooth-L1 photometric
    loss.  Every 10th iteration the losses are printed, the photometric
    loss is recorded, and a visualisation frame is written to
    ``video_writer``.  After the loop the recorded photometric losses are
    passed to ``util.draw_train_process``.

    Args:
        images: target image batch; ``images.shape[0]`` is the batch size.
        landmarks: ground-truth landmarks, indexed as
            ``landmarks[:, :, :2]``.  (Original author's note: 68 face
            landmarks.)
        image_masks: multiplicative mask applied to both the rendered and
            the target images in the photometric loss.
        video_writer: object with a ``write(frame)`` method that receives
            one ``uint8`` frame (channel order reversed with
            ``[2, 1, 0]``) per visualised iteration.

    Returns:
        dict of NumPy arrays with keys ``'shape'``, ``'exp'``, ``'pose'``,
        ``'cam'``, ``'verts'`` (projected vertices of the last iteration),
        ``'albedos'``, ``'tex'`` and ``'lit'``.

    Note:
        ``'verts'`` and ``'albedos'`` come from the last loop iteration,
        so ``cfg.max_iter`` must be at least 1.
    """
    bz = images.shape[0]

    def _zero_parameter(*trailing_dims):
        # Trainable zero tensor of shape (bz, *trailing_dims) on the model device.
        return nn.Parameter(torch.zeros(bz, *trailing_dims).float().to(self.device))

    shape = _zero_parameter(cfg.shape_params)
    tex = _zero_parameter(cfg.tex_params)
    exp = _zero_parameter(cfg.expression_params)
    pose = _zero_parameter(cfg.pose_params)
    cam = torch.zeros(bz, cfg.camera_params)
    cam[:, 0] = 5.
    cam = nn.Parameter(cam.float().to(self.device))
    lights = _zero_parameter(9, 3)

    e_opt = torch.optim.Adam(
        [shape, exp, pose, cam, tex, lights],
        lr=cfg.e_lr,
        weight_decay=cfg.e_wd,
    )

    gt_landmark = landmarks

    def _project_and_flip(points):
        # Project with the current camera, then negate every coordinate
        # after the first one in place.
        projected = util.batch_orth_proj(points, cam)
        projected[..., 1:] = -projected[..., 1:]
        return projected

    def _to_bgr_frame(grid):
        # CHW tensor -> HWC uint8 array with the channel order reversed.
        frame = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]]
        return np.clip(frame, 0, 255).astype(np.uint8)

    # Non-rigid fitting of all the parameters with the landmark and
    # photometric losses.
    all_train_iter = 0
    all_train_iters = []
    photometric_loss = []
    for k in range(cfg.max_iter):
        losses = {}
        vertices, landmarks2d, landmarks3d = self.flame(
            shape_params=shape, expression_params=exp, pose_params=pose)
        trans_vertices = _project_and_flip(vertices)
        landmarks2d = _project_and_flip(landmarks2d)
        landmarks3d = _project_and_flip(landmarks3d)

        losses['landmark'] = util.l2_distance(
            landmarks2d[:, :, :2], gt_landmark[:, :, :2])

        # Render with the current texture and lighting.
        albedos = self.flametex(tex) / 255.
        ops = self.render(vertices, trans_vertices, albedos, lights)
        predicted_images = ops['images']
        losses['photometric_texture'] = F.smooth_l1_loss(
            image_masks * ops['images'], image_masks * images) * cfg.w_pho

        all_loss = 0.
        for term in list(losses.values()):
            all_loss = all_loss + term
        losses['all_loss'] = all_loss

        e_opt.zero_grad()
        all_loss.backward()
        e_opt.step()

        if k % 10 != 0:
            continue

        # Everything below runs on reporting iterations only.  Formatting
        # the log line converts each loss with float(), so doing it here
        # instead of on every iteration avoids needless conversions.
        all_train_iter += 10
        all_train_iters.append(all_train_iter)
        # Keep a detached CPU copy: storing the live loss tensor would
        # keep its autograd history referenced for the whole run, and a
        # tensor that requires grad cannot be converted to NumPy by a
        # plotting helper.
        photometric_loss.append(losses['photometric_texture'].detach().cpu())

        loss_info = '----iter: {}, time: {}\n'.format(
            k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
        loss_info += ''.join(
            '{}: {}, '.format(key, float(value)) for key, value in losses.items())
        print(loss_info)

        # Panels are concatenated along dim 1 in insertion order.
        visind = range(bz)
        grids = {}
        grids['images'] = torchvision.utils.make_grid(images[visind]).detach().cpu()
        grids['landmarks_gt'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks[visind]))
        grids['landmarks2d'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks2d[visind]))
        grids['landmarks3d'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks3d[visind]))
        grids['albedoimage'] = torchvision.utils.make_grid(
            (ops['albedo_images'])[visind].detach().cpu())
        grids['render'] = torchvision.utils.make_grid(
            predicted_images[visind].detach().float().cpu())
        shape_images = self.render.render_shape(vertices, trans_vertices, images)
        grids['shape'] = torchvision.utils.make_grid(
            F.interpolate(shape_images[visind], [224, 224])).detach().float().cpu()

        grid = torch.cat(list(grids.values()), 1)
        video_writer.write(_to_bgr_frame(grid))

    single_params = {
        'shape': shape.detach().cpu().numpy(),
        'exp': exp.detach().cpu().numpy(),
        'pose': pose.detach().cpu().numpy(),
        'cam': cam.detach().cpu().numpy(),
        'verts': trans_vertices.detach().cpu().numpy(),
        'albedos': albedos.detach().cpu().numpy(),
        'tex': tex.detach().cpu().numpy(),
        'lit': lights.detach().cpu().numpy(),
    }
    util.draw_train_process(
        "training", all_train_iters, photometric_loss, 'photometric loss')
    return single_params
def optimize(self, images, landmarks, image_masks, all_param, video_writer, first_flag):
    """Refine an existing parameter set against one batch and emit a preview.

    Two Adam optimizers are built on every call: one over all six
    parameter tensors and one over shape / expression / pose / camera
    only.  When ``first_flag`` is true the full optimizer runs for
    ``cfg.max_iter`` iterations with landmark and photometric losses;
    otherwise the reduced optimizer runs for 50 iterations with the
    landmark loss alone.  After the loop, the last rendered image and the
    first input image are shown side by side with ``cv2.imshow`` and
    written to ``video_writer``.

    Args:
        images: target image batch; ``images[0]`` is used for the preview.
        landmarks: ground-truth landmarks, indexed as ``[:, :, :2]``.
        image_masks: mask multiplied into both sides of the photometric
            loss (only used when ``first_flag`` is true).
        all_param: sequence unpacked as
            ``(shape, tex, exp, pose, cam, lights)``.
        video_writer: object with a ``write(frame)`` method.
        first_flag: selects the full fit (true) or the short
            geometry-only refinement (false).

    Returns:
        ``[shape_para, tex_para, exp_para, pose_para, cam_para, lights_para]``
        -- the same objects that were passed in, as updated by the
        optimizer.
    """
    shape_para, tex_para, exp_para, pose_para, cam_para, lights_para = all_param

    full_optimizer = torch.optim.Adam(
        [shape_para, exp_para, pose_para, cam_para, tex_para, lights_para],
        lr=cfg.e_lr,
        weight_decay=cfg.e_wd,
    )
    geometry_optimizer = torch.optim.Adam(
        [shape_para, exp_para, pose_para, cam_para],
        lr=cfg.e_lr,
        weight_decay=cfg.e_wd,
    )
    optimizer = full_optimizer if first_flag else geometry_optimizer

    target_landmarks = landmarks
    n_iters = cfg.max_iter if first_flag else 50

    def _project_and_flip(points):
        # Project with the current camera, then negate every coordinate
        # after the first one in place.
        projected = util.batch_orth_proj(points, cam_para)
        projected[..., 1:] = -projected[..., 1:]
        return projected

    def _to_bgr_uint8(chw_tensor):
        # CHW tensor -> HWC uint8 array with the channel order reversed.
        hwc = chw_tensor.numpy().transpose(1, 2, 0).copy() * 255
        reordered = hwc[:, :, [2, 1, 0]]
        return np.minimum(np.maximum(reordered, 0), 255).astype(np.uint8)

    # Fallback preview used when the loop body never executes.
    preview = torch.squeeze(images)

    for step in range(n_iters):
        losses = {}
        vertices, landmarks2d, landmarks3d = self.flame(
            shape_params=shape_para,
            expression_params=exp_para,
            pose_params=pose_para,
        )
        trans_vertices = _project_and_flip(vertices)
        landmarks2d = _project_and_flip(landmarks2d)
        landmarks3d = _project_and_flip(landmarks3d)

        losses['landmark'] = util.l2_distance(
            landmarks2d[:, :, :2], target_landmarks[:, :, :2])

        # Rendering happens in both modes because the preview needs it.
        albedos = self.flametex(tex_para) / 255.
        ops = self.render(vertices, trans_vertices, albedos, lights_para)
        preview = torchvision.utils.make_grid(
            ops['images'][0].detach().float().cpu())

        if first_flag:
            losses['photometric_texture'] = F.smooth_l1_loss(
                image_masks * ops['images'], image_masks * images) * cfg.w_pho

        total = 0.
        for term in list(losses.values()):
            total = total + term
        losses['all_loss'] = total

        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        header = '----iter: {}, time: {}\n'.format(
            step, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
        body = ''.join(
            '{}: {}, '.format(name, float(value)) for name, value in losses.items())
        print(header + body)

    predicted_frame = _to_bgr_uint8(preview)
    target_frame = _to_bgr_uint8(
        torchvision.utils.make_grid(images[0].detach().float().cpu()))

    combine = np.concatenate((predicted_frame, target_frame), axis=1)
    cv2.imshow("tmp_image", combine)
    cv2.waitKey(1)
    video_writer.write(combine)

    return [shape_para, tex_para, exp_para, pose_para, cam_para, lights_para]
def optimize(self, images, landmarks, image_masks, savefolder=None):
    """Two-stage fit of model parameters to a batch of images.

    Stage 1 (iterations 0-199) optimises only pose and camera against the
    landmarks from index 17 onward.  (Original author's note: these are
    the 51 static face landmarks; the contour landmarks are left out
    because their trajectory is not differentiable.)

    Stage 2 (iterations 200-999) optimises shape, expression, pose,
    camera, texture and lighting against all landmarks, L2
    regularisers on shape / expression / pose, and a masked L1
    photometric loss.

    Every 10th iteration the losses are printed and, when ``savefolder``
    is given, a visualisation is written to ``<savefolder>/<k>.jpg``.

    Args:
        images: target image batch; ``images.shape[0]`` is the batch size.
        landmarks: ground-truth landmarks, indexed as ``[:, :, :2]``.
        image_masks: mask multiplied into the photometric residual.
        savefolder: directory for the snapshot images.  With the default
            ``None`` no snapshots are rendered or written (the path used
            to be formatted as ``'None/<k>.jpg'``).

    Returns:
        dict of NumPy arrays with keys ``'shape'``, ``'exp'``, ``'pose'``,
        ``'cam'``, ``'verts'`` (projected vertices of the last iteration),
        ``'albedos'``, ``'tex'`` and ``'lit'``.
    """
    bz = images.shape[0]
    shape = nn.Parameter(torch.zeros(bz, cfg.shape_params).float().to(self.device))
    tex = nn.Parameter(torch.zeros(bz, cfg.tex_params).float().to(self.device))
    exp = nn.Parameter(torch.zeros(bz, cfg.expression_params).float().to(self.device))
    pose = nn.Parameter(torch.zeros(bz, cfg.pose_params).float().to(self.device))
    cam = torch.zeros(bz, cfg.camera_params)
    cam[:, 0] = 5.
    cam = nn.Parameter(cam.float().to(self.device))
    lights = nn.Parameter(torch.zeros(bz, 9, 3).float().to(self.device))

    e_opt = torch.optim.Adam(
        [shape, exp, pose, cam, tex, lights],
        lr=cfg.e_lr,
        weight_decay=cfg.e_wd,
    )
    e_opt_rigid = torch.optim.Adam(
        [pose, cam],
        lr=cfg.e_lr,
        weight_decay=cfg.e_wd,
    )

    gt_landmark = landmarks
    visind = range(bz)
    save_snapshots = savefolder is not None

    def _project_and_flip(points):
        # Project with the current camera, then negate every coordinate
        # after the first one in place.
        projected = util.batch_orth_proj(points, cam)
        projected[..., 1:] = -projected[..., 1:]
        return projected

    def _forward():
        # Evaluate the model and project vertices and both landmark sets.
        verts, lmk2d, lmk3d = self.flame(
            shape_params=shape, expression_params=exp, pose_params=pose)
        return (verts, _project_and_flip(verts),
                _project_and_flip(lmk2d), _project_and_flip(lmk3d))

    def _step(optimizer, losses):
        # Sum the loss terms in insertion order, record the total, update.
        total = 0.
        for term in list(losses.values()):
            total = total + term
        losses['all_loss'] = total
        optimizer.zero_grad()
        total.backward()
        optimizer.step()

    def _log(k, losses):
        # Only called on reporting iterations: float() on each loss is a
        # conversion that is wasted when the line is never printed.
        header = '----iter: {}, time: {}\n'.format(
            k, datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'))
        print(header + ''.join(
            '{}: {}, '.format(key, float(value)) for key, value in losses.items()))

    def _landmark_panels(landmarks2d, landmarks3d):
        # Panels shared by both stages; dict order is the stacking order.
        panels = {}
        panels['images'] = torchvision.utils.make_grid(
            images[visind]).detach().cpu()
        panels['landmarks_gt'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks[visind]))
        panels['landmarks2d'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks2d[visind]))
        panels['landmarks3d'] = torchvision.utils.make_grid(
            util.tensor_vis_landmarks(images[visind], landmarks3d[visind]))
        return panels

    def _write_snapshot(k, panels):
        # Stack the panels along dim 1 and save with reversed channel order.
        grid = torch.cat(list(panels.values()), 1)
        grid_image = (grid.numpy().transpose(1, 2, 0).copy() * 255)[:, :, [2, 1, 0]]
        grid_image = np.minimum(np.maximum(grid_image, 0), 255).astype(np.uint8)
        cv2.imwrite('{}/{}.jpg'.format(savefolder, k), grid_image)

    # Stage 1: rigid fitting of pose and camera.
    for k in range(200):
        vertices, trans_vertices, landmarks2d, landmarks3d = _forward()
        losses = {}
        losses['landmark'] = util.l2_distance(
            landmarks2d[:, 17:, :2], gt_landmark[:, 17:, :2]) * cfg.w_lmks
        _step(e_opt_rigid, losses)

        if k % 10 == 0:
            _log(k, losses)
            if save_snapshots:
                _write_snapshot(k, _landmark_panels(landmarks2d, landmarks3d))

    # Stage 2: non-rigid fitting of all the parameters with every
    # landmark, the photometric loss and the regularisation terms.
    for k in range(200, 1000):
        vertices, trans_vertices, landmarks2d, landmarks3d = _forward()
        losses = {}
        losses['landmark'] = util.l2_distance(
            landmarks2d[:, :, :2], gt_landmark[:, :, :2]) * cfg.w_lmks
        losses['shape_reg'] = (torch.sum(shape ** 2) / 2) * cfg.w_shape_reg
        losses['expression_reg'] = (torch.sum(exp ** 2) / 2) * cfg.w_expr_reg
        losses['pose_reg'] = (torch.sum(pose ** 2) / 2) * cfg.w_pose_reg

        # Render with the current texture and lighting.
        albedos = self.flametex(tex) / 255.
        ops = self.render(vertices, trans_vertices, albedos, lights)
        predicted_images = ops['images']
        losses['photometric_texture'] = (
            image_masks * (ops['images'] - images).abs()).mean() * cfg.w_pho
        _step(e_opt, losses)

        if k % 10 == 0:
            _log(k, losses)
            if save_snapshots:
                panels = _landmark_panels(landmarks2d, landmarks3d)
                panels['albedoimage'] = torchvision.utils.make_grid(
                    (ops['albedo_images'])[visind].detach().cpu())
                panels['render'] = torchvision.utils.make_grid(
                    predicted_images[visind].detach().float().cpu())
                shape_images = self.render.render_shape(
                    vertices, trans_vertices, images)
                panels['shape'] = torchvision.utils.make_grid(
                    F.interpolate(shape_images[visind],
                                  [224, 224])).detach().float().cpu()
                _write_snapshot(k, panels)

    single_params = {
        'shape': shape.detach().cpu().numpy(),
        'exp': exp.detach().cpu().numpy(),
        'pose': pose.detach().cpu().numpy(),
        'cam': cam.detach().cpu().numpy(),
        'verts': trans_vertices.detach().cpu().numpy(),
        'albedos': albedos.detach().cpu().numpy(),
        'tex': tex.detach().cpu().numpy(),
        'lit': lights.detach().cpu().numpy(),
    }
    return single_params
def decode(self, codedict, epoch):
    """Decode a code dictionary into geometry, landmarks and albedo.

    Args:
        codedict: mapping read for the keys ``'images'``, ``'shape'``,
            ``'exp'``, ``'pose'``, ``'cam'``; ``'tex'`` when
            ``self.config.model.use_tex`` is set; ``'shape_shuffle'`` in
            the shape-consistency branch; ``'detail'`` when
            ``self.mode == 'train_detail'``.
        epoch: current epoch; the shape-consistency outputs are produced
            only when ``'coarse'`` is in ``self.mode`` and
            ``epoch > self.epoch_phase``.

    Returns:
        dict with ``'albedo'``, ``'verts'``, ``'trans_verts'``,
        ``'landmarks2d'``, ``'landmarks3d'`` and ``'normals'``; plus the
        ``*_shuffle`` entries in the shape-consistency branch; plus
        ``'displacement_map'``, ``'detail_verts'``,
        ``'detail_trans_verts'``, ``'detail_faces'`` and
        ``'uv_detail_normals'`` in ``'train_detail'`` mode.
    """
    images = codedict['images']
    batch_size = images.shape[0]

    # Base geometry from the shape / expression / pose codes.
    verts, landmarks2d, landmarks3d = self.flame(
        shape_params=codedict['shape'],
        expression_params=codedict['exp'],
        pose_params=codedict['pose'],
    )

    if self.config.model.use_tex:
        albedo = self.flametex(codedict['tex'])
    else:
        albedo = torch.zeros(
            [batch_size, 3, self.uv_size, self.uv_size], device=images.device)

    def _flip_after_first(points):
        # Negate every coordinate after the first one, in place.
        points[:, :, 1:] = -points[:, :, 1:]
        return points

    def _to_normalised_image_coords(points):
        # Flip, map to pixel coordinates, then divide by (image_size - 1).
        points = _flip_after_first(points)
        points = points * self.image_size / 2 + self.image_size / 2
        points /= (self.image_size - 1)
        return points

    def _project_all(mesh, lmk2d, lmk3d):
        # Same order as always: 2-D landmarks, 3-D landmarks, vertices.
        lmk2d = _to_normalised_image_coords(
            util.batch_orth_proj(lmk2d, codedict['cam'])[:, :, :2])
        lmk3d = _to_normalised_image_coords(
            util.batch_orth_proj(lmk3d, codedict['cam']))
        # The projected vertices are flipped but not rescaled.
        projected_mesh = _flip_after_first(
            util.batch_orth_proj(mesh, codedict['cam']))
        return lmk2d, lmk3d, projected_mesh

    landmarks2d, landmarks3d, trans_verts = _project_all(
        verts, landmarks2d, landmarks3d)

    normals = util.vertex_normals(
        verts, self.faces.expand(batch_size, -1, -1).to(self.device))

    output = {
        'albedo': albedo,
        'verts': verts,
        'trans_verts': trans_verts,
        'landmarks2d': landmarks2d,
        'landmarks3d': landmarks3d,
        'normals': normals,
    }

    # Shape consistency: decode again with the shuffled shape code.
    if 'coarse' in self.mode and epoch > self.epoch_phase:
        verts, landmarks2d, landmarks3d = self.flame(
            shape_params=codedict['shape_shuffle'],
            expression_params=codedict['exp'],
            pose_params=codedict['pose'],
        )
        landmarks2d, landmarks3d, trans_verts = _project_all(
            verts, landmarks2d, landmarks3d)

        output['landmarks2d_shuffle'] = landmarks2d
        output['landmarks3d_shuffle'] = landmarks3d
        output['verts_shuffle'] = verts
        output['trans_verts_shuffle'] = trans_verts

    if self.mode == 'train_detail':
        detail_input = torch.cat(
            [codedict['pose'][:, 3:], codedict['exp'], codedict['detail']], dim=1)
        uv_z = self.D_detail(detail_input)
        output['displacement_map'] = uv_z + self.fixed_uv_dis[None, None, :, :]

        renderer = self.unsupervised_losses_conductor.render
        dense_vertices, dense_faces = displacement2vertex(
            uv_z, verts, normals, renderer)
        uv_detail_normals = displacement2normal(uv_z, verts, normals, renderer)
        dense_trans_verts = _flip_after_first(
            util.batch_orth_proj(dense_vertices, codedict['cam']))

        output['detail_verts'] = dense_vertices
        output['detail_trans_verts'] = dense_trans_verts
        output['detail_faces'] = dense_faces
        output['uv_detail_normals'] = uv_detail_normals

    return output