def fitting(self, input_data_file):
    """Optimise ``self.xhr_rec`` starting from a pickled body and return it.

    The pickle is parsed by ``BodyParamParser.body_params_parse_fitting``
    into a body feature ``xh`` plus camera extrinsics/intrinsics, which are
    stored on ``self.cam_ext`` / ``self.cam_int``. The feature is converted
    with ``GeometryTransformer.convert_to_6D_rot``, copied into
    ``self.xhr_rec.data``, and ``self.num_iter`` optimizer steps are then run
    on the sum of the four terms returned by ``self.cal_loss``.

    Args:
        input_data_file: path of the pickle file with the body parameters.

    Returns:
        ``GeometryTransformer.convert_to_3D_rot(self.xhr_rec)`` evaluated
        after the last optimizer step.
    """
    # NOTE: pickle.load can execute arbitrary code; only feed trusted files.
    with open(input_data_file, 'rb') as fin:
        params_in = pickle.load(fin)

    parsed = BodyParamParser.body_params_parse_fitting(params_in)
    xh, self.cam_ext, self.cam_int = parsed

    # The converted input is both the fixed target and the starting point.
    target_6d = GeometryTransformer.convert_to_6D_rot(xh)
    self.xhr_rec.data = target_6d.clone()

    log_template = '[INFO][fitting] iter={:d}, l_rec={:f}, l_vposer={:f}, l_contact={:f}, l_collision={:f}'
    for step in range(self.num_iter):
        self.optimizer.zero_grad()

        terms = self.cal_loss(target_6d, self.cam_ext)
        l_rec, l_vposer, l_contact, l_collision = terms
        total = l_rec + l_vposer + l_contact + l_collision

        if self.verbose:
            print(log_template.format(step, *(term.item() for term in terms)))

        # NOTE(review): the graph is rebuilt by cal_loss every iteration, so
        # retain_graph=True looks unnecessary -- kept to preserve behaviour.
        total.backward(retain_graph=True)
        self.optimizer.step()

    print('[INFO][fitting] fitting finish, returning optimal value')
    return GeometryTransformer.convert_to_3D_rot(self.xhr_rec)
def cal_loss(self, xhr, cam_ext):
    """Compute reconstruction/VPoser terms and contact/non-collision scores.

    The body currently stored in ``self.xhr_rec`` is decoded to mesh
    vertices, transformed with ``cam_ext`` and looked up in the scene SDF
    grid (``self.s_sdf_batch``). Negative SDF samples are treated as
    body/scene penetration.

    Args:
        xhr: reference body feature compared against ``self.xhr_rec`` with an
            L1 loss.
        cam_ext: transform handed to ``GeometryTransformer.verts_transform``.

    Returns:
        Tuple ``(loss_rec, loss_vposer, loss_contact, loss_collision)``:

        * ``loss_rec`` -- ``self.weight_loss_rec`` times the L1 distance.
        * ``loss_vposer`` -- ``self.weight_loss_vposer`` times the mean
          square of columns 16:48 of the converted body feature.
        * ``loss_contact`` -- indicator: 1.0 if at least one SDF sample is
          negative, otherwise 0.0.
        * ``loss_collision`` -- fraction of SDF samples that are strictly
          positive (1.0 when no sample is negative).
    """
    ### reconstruction loss
    loss_rec = self.weight_loss_rec * F.l1_loss(xhr, self.xhr_rec)
    xh_rec = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

    ### vposer loss
    vposer_pose = xh_rec[:, 16:48]
    loss_vposer = self.weight_loss_vposer * torch.mean(vposer_pose ** 2)

    ### contact score and non-coll score
    body_param_rec = BodyParamParser.body_params_encapsulate_batch(xh_rec)
    joint_rot_batch = self.vposer.decode(
        body_param_rec['body_pose_vp'], output_type='aa'
    ).view(self.batch_size, -1)

    # The body model takes the decoded joint rotations instead of the
    # VPoser latent code, so that entry is dropped from the kwargs.
    body_param_ = {
        key: body_param_rec[key]
        for key in body_param_rec.keys()
        if key != 'body_pose_vp'
    }

    smplx_output = self.body_mesh_model(return_verts=True,
                                        body_pose=joint_rot_batch,
                                        **body_param_)
    body_verts_batch = smplx_output.vertices  # original note: [b, 10475, 3]
    body_verts_batch = GeometryTransformer.verts_transform(body_verts_batch, cam_ext)

    # Map vertex coordinates into the [-1, 1] range that grid_sample expects.
    s_grid_min_batch = self.s_grid_min_batch.unsqueeze(1)
    s_grid_max_batch = self.s_grid_max_batch.unsqueeze(1)
    norm_verts_batch = (body_verts_batch - s_grid_min_batch) / (s_grid_max_batch - s_grid_min_batch) * 2 - 1
    n_verts = norm_verts_batch.shape[1]

    # NOTE(review): align_corners is left at the library default, which
    # differs between PyTorch releases -- confirm against the version used.
    body_sdf_batch = F.grid_sample(self.s_sdf_batch.unsqueeze(1),
                                   norm_verts_batch[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3),
                                   padding_mode='border')

    # Normalise by the real number of SDF samples rather than a hard-coded
    # 10475, so the ratio stays in [0, 1] for any batch size or mesh
    # resolution (identical to the old value for one 10475-vertex body).
    n_samples = body_sdf_batch.numel()

    if body_sdf_batch.lt(0).sum().item() < 1:
        # No negative SDF entry: no contact, and every sample counts as
        # collision free.
        n_non_colliding = torch.tensor(n_samples, dtype=torch.float32, device=self.device)
        loss_contact = torch.tensor(0.0, dtype=torch.float32, device=self.device)
    else:
        n_non_colliding = (body_sdf_batch > 0).sum()
        loss_contact = torch.tensor(1.0, dtype=torch.float32, device=self.device)

    loss_collision = n_non_colliding.float() / float(n_samples)

    return loss_rec, loss_vposer, loss_contact, loss_collision
def eval_colllision(self):
    """Score every generated body found in ``self.input_data_file``.

    Files named ``body_gen_{index:06d}.pkl`` are probed for indices
    0..7999; missing indices are skipped. Each existing file is parsed,
    stored as the current body (``self.xhr_rec``) and evaluated with
    ``self.cal_loss`` using the first camera extrinsic right-multiplied by
    ``diag(1, -1, -1, 1)``.

    Side effects:
        Overwrites ``self.cam_ext``, ``self.cam_int`` and ``self.xhr_rec``
        for every file that is processed.

    Returns:
        ``(coll_list, cont_list)``: Python floats taken from the collision
        and contact terms of ``cal_loss``, one entry per processed file.
    """
    coll_list = []
    cont_list = []

    # The axis-flip matrix is constant, so build and upload it once instead
    # of once per file.
    T_mat = np.eye(4)
    T_mat[1, :] = np.array([0, -1, 0, 0])
    T_mat[2, :] = np.array([0, 0, -1, 0])
    T_mat = torch.tensor(T_mat, dtype=torch.float32, device=self.device).unsqueeze(0)

    for ii in range(8000):
        filename = os.path.join(self.input_data_file, 'body_gen_{:06d}.pkl'.format(ii))
        if not os.path.exists(filename):
            continue

        # NOTE: pickle.load can execute arbitrary code; only feed trusted files.
        with open(filename, 'rb') as f:
            body_param_input = pickle.load(f)

        xh, self.cam_ext, self.cam_int = BodyParamParser.body_params_parse_fitting(body_param_input)
        xhr = GeometryTransformer.convert_to_6D_rot(xh)
        # cal_loss reads the body from self.xhr_rec, so evaluate the loaded
        # body directly.
        self.xhr_rec = xhr

        trans = torch.matmul(self.cam_ext[:1], T_mat)
        _, _, loss_contact, loss_collision = self.cal_loss(xhr, trans)

        coll_list.append(loss_collision.item())
        cont_list.append(loss_contact.item())

    return coll_list, cont_list
def fitting(self, input_data_file):
    """Optimise ``self.xhr_rec`` from a pickled body using flipped extrinsics.

    The pickle is parsed into a body feature and camera parameters (stored
    on ``self.cam_ext`` / ``self.cam_int``). The feature, converted with
    ``GeometryTransformer.convert_to_6D_rot``, is copied into
    ``self.xhr_rec.data``. The first extrinsic matrix is right-multiplied by
    ``diag(1, -1, -1, 1)`` and that product is passed to ``self.cal_loss``
    in each of the ``self.num_iter`` optimizer steps.

    Args:
        input_data_file: path of the pickle file with the body parameters.

    Returns:
        ``GeometryTransformer.convert_to_3D_rot(self.xhr_rec)`` evaluated
        after the last optimizer step.
    """
    # NOTE: pickle.load can execute arbitrary code; only feed trusted files.
    with open(input_data_file, 'rb') as handle:
        raw_params = pickle.load(handle)

    xh, self.cam_ext, self.cam_int = BodyParamParser.body_params_parse_fitting(raw_params)
    xhr = GeometryTransformer.convert_to_6D_rot(xh)
    self.xhr_rec.data = xhr.clone()

    # Homogeneous matrix that negates the y and z axes.
    axis_flip = torch.tensor(np.diag([1.0, -1.0, -1.0, 1.0]),
                             dtype=torch.float32,
                             device=self.device).unsqueeze(0)
    trans = torch.matmul(self.cam_ext[:1], axis_flip)

    message = '[INFO][fitting] iter={:d}, l_rec={:f}, l_vposer={:f}, l_contact={:f}, l_collision={:f}'
    iteration = 0
    while iteration < self.num_iter:
        self.optimizer.zero_grad()

        rec, vposer, contact, collision = self.cal_loss(xhr, trans)
        objective = rec + vposer + contact + collision

        if self.verbose:
            print(message.format(iteration,
                                 rec.item(),
                                 vposer.item(),
                                 contact.item(),
                                 collision.item()))

        # NOTE(review): retain_graph=True is kept from the original; the
        # graph appears to be rebuilt on every call to cal_loss.
        objective.backward(retain_graph=True)
        self.optimizer.step()
        iteration += 1

    print('[INFO][fitting] fitting finish, returning optimal value')
    return GeometryTransformer.convert_to_3D_rot(self.xhr_rec)
def cal_loss(self, xs, xh, eps_g, eps_l, cam_ext, cam_int, max_d,
             scene_verts, scene_face, s_grid_min_batch, s_grid_max_batch,
             s_grid_sdf_batch, ep):
    """Compute all training loss terms for one batch.

    ``xh`` is normalised, converted to the 6D rotation form, passed through
    ``self.model_h`` together with ``eps_g``, ``eps_l`` and the scene
    feature ``xs``, and converted back. The reconstructed body is then
    decoded to mesh vertices, which feed the contact (chamfer) and SDF
    penetration terms.

    Args:
        xs: scene input forwarded to ``self.model_h``.
        xh: ground-truth body feature.
        eps_g, eps_l: noise inputs forwarded to ``self.model_h``.
        cam_ext: transform handed to ``GeometryTransformer.verts_transform``.
        cam_int, max_d: arguments of ``normalize_global_T`` /
            ``recover_global_T``.
        scene_verts: scene points for the chamfer distance.
        scene_face: unused in this method.
        s_grid_min_batch, s_grid_max_batch: bounds used to normalise vertex
            coordinates before sampling the SDF grid.
        s_grid_sdf_batch: SDF grid sampled with ``F.grid_sample``.
        ep: current epoch; it drives the KL annealing factor and switches
            the contact and penetration terms on once
            ``ep > 0.75 * self.epoch``.

    Returns:
        Tuple ``(loss_rec_t, loss_rec_p, loss_KL_g, loss_KL_l, loss_contact,
        loss_vposer, loss_sdf_pene)``.
    """
    # ---- encode / decode the body feature ----
    xh_norm = GeometryTransformer.normalize_global_T(xh, cam_int, max_d)
    xh_norm_6d = GeometryTransformer.convert_to_6D_rot(xh_norm)
    net_out = self.model_h(xh_norm_6d, eps_g, eps_l, xs)
    [rec_norm_6d, mu_g, logsigma2_g, mu_l, logsigma2_l] = net_out
    rec_norm = GeometryTransformer.convert_to_3D_rot(rec_norm_6d)
    xh_rec = GeometryTransformer.recover_global_T(rec_norm, cam_int, max_d)

    # ---- reconstruction: first three columns vs. the remaining ones ----
    trans_err_norm = F.l1_loss(rec_norm_6d[:, :3], xh_norm_6d[:, :3])
    trans_err_rec = F.l1_loss(xh_rec[:, :3], xh[:, :3])
    loss_rec_t = self.weight_loss_rec_h * (0.5 * trans_err_norm + 0.5 * trans_err_rec)
    loss_rec_p = self.weight_loss_rec_h * F.l1_loss(rec_norm_6d[:, 3:], xh_norm_6d[:, 3:])

    # ---- KL divergence, optionally annealed over the first 75% of epochs ----
    if self.loss_weight_anealing:
        fca = min(1.0, max(float(ep) / (self.epoch * 0.75), 0))
    else:
        fca = 1.0
    kl_scale = fca ** 2 * self.weight_loss_kl * 0.5
    loss_KL_g = kl_scale * torch.mean(torch.exp(logsigma2_g) + mu_g ** 2 - 1.0 - logsigma2_g)
    loss_KL_l = kl_scale * torch.mean(torch.exp(logsigma2_l) + mu_l ** 2 - 1.0 - logsigma2_l)

    # ---- VPoser regulariser on columns 16:48 of the reconstruction ----
    loss_vposer = self.weight_loss_vposer * torch.mean(xh_rec[:, 16:48] ** 2)

    # ---- reconstructed body mesh ----
    rec_params = BodyParamParser.body_params_encapsulate_batch(xh_rec)
    joint_rot_batch = self.vposer.decode(
        rec_params['body_pose_vp'], output_type='aa'
    ).view(self.batch_size, -1)
    # Everything except the VPoser latent is forwarded to the body model.
    model_kwargs = {
        name: rec_params[name]
        for name in rec_params.keys()
        if name not in ['body_pose_vp']
    }
    mesh_out = self.body_mesh_model(return_verts=True,
                                    body_pose=joint_rot_batch,
                                    **model_kwargs)
    body_verts = GeometryTransformer.verts_transform(mesh_out.vertices, cam_ext)  # [b, 10475, 3]

    # ---- contact: chamfer distance from the selected vertices to the scene ----
    vid, _ = GeometryTransformer.get_contact_id(
        body_segments_folder=self.contact_id_folder,
        contact_body_parts=self.contact_part)
    contact_verts = body_verts[:, vid, :]
    chamfer = ext.chamferDist()
    contact_dist, _ = chamfer(contact_verts.contiguous(), scene_verts.contiguous())

    # Contact and penetration terms are only active in the last quarter of
    # training.
    late_gate = 1.0 if ep > 0.75 * self.epoch else 0.0
    loss_contact = late_gate * self.weight_contact * torch.mean(
        torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))

    # ---- SDF scene penetration ----
    grid_lo = s_grid_min_batch.unsqueeze(1)
    grid_hi = s_grid_max_batch.unsqueeze(1)
    norm_verts = (body_verts - grid_lo) / (grid_hi - grid_lo) * 2 - 1
    n_verts = norm_verts.shape[1]
    body_sdf = F.grid_sample(s_grid_sdf_batch.unsqueeze(1),
                             norm_verts[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3),
                             padding_mode='border')

    penetrating = body_sdf < 0
    if penetrating.sum().item() >= 1:
        raw_pene = body_sdf[penetrating].abs().mean()
    else:
        # No penetrating vertices: the term is exactly zero.
        raw_pene = torch.tensor(0.0, dtype=torch.float32, device=self.device)
    loss_sdf_pene = late_gate * self.weight_collision * raw_pene

    return loss_rec_t, loss_rec_p, loss_KL_g, loss_KL_l, loss_contact, loss_vposer, loss_sdf_pene
def test(self):
    """Sample bodies for every camera file under ``self.test_data_path``.

    The most recently modified ``epoch-*.ckp`` checkpoint in
    ``self.ckpt_dir`` is loaded into ``self.model_h``. For each ``cam_*``
    file, the sibling ``depth_*`` and ``seg_*`` arrays are loaded and
    preprocessed, ``self.n_samples`` bodies are drawn from the model, and
    each one is pickled to ``self.outdir`` as ``body_gen_{index:06d}.pkl``
    with ``index = self.n_samples * file_index + sample_index``.

    Raises:
        IndexError: if no checkpoint matches ``epoch-*.ckp``.
    """
    self.model_h.eval()
    self.model_h.to(self.device)
    self.vposer.to(self.device)
    self.body_mesh_model.to(self.device)

    ## load checkpoints (latest by modification time)
    ckp_list = sorted(glob.glob(os.path.join(self.ckpt_dir, 'epoch-*.ckp')),
                      key=os.path.getmtime)
    ckp_path = ckp_list[-1]
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckp_path)
    print('[INFO] load checkpoints: ' + ckp_path)
    self.model_h.load_state_dict(checkpoint['model_h_state_dict'])

    ## read data from here!
    cam_file_list = sorted(glob.glob(self.test_data_path + '/cam_*'))

    for ii, cam_file in enumerate(cam_file_list):
        # Derive sibling file names from the base name only. Replacing on the
        # full path would also rewrite any "cam" inside a directory name.
        cam_dir, cam_name = os.path.split(cam_file)
        depth_file = os.path.join(cam_dir, cam_name.replace('cam', 'depth', 1))
        seg_file = os.path.join(cam_dir, cam_name.replace('cam', 'seg', 1))

        cam_params = np.load(cam_file, allow_pickle=True, encoding='latin1').item()
        depth0 = torch.tensor(np.load(depth_file), dtype=torch.float32, device=self.device)
        seg0 = torch.tensor(np.load(seg_file), dtype=torch.float32, device=self.device)

        cam_ext = torch.tensor(cam_params['cam_ext'], dtype=torch.float32,
                               device=self.device).unsqueeze(0)  # [1,4,4]
        cam_int = torch.tensor(cam_params['cam_int'], dtype=torch.float32,
                               device=self.device).unsqueeze(0)  # [1,3,3]

        depth, _, max_d = self.data_preprocessing(depth0, 'depth',
                                                  target_domain_size=[128, 128])  # [1,1,128,128]
        max_d = max_d.view(1)
        # NOTE(review): the segmentation map is preprocessed with the 'depth'
        # mode string, as in the original -- confirm this is intended.
        seg, _, _ = self.data_preprocessing(seg0, 'depth',
                                            target_domain_size=[128, 128])

        # Tile the single scene so that n_samples bodies are drawn for it.
        xs = torch.cat([depth, seg], dim=1)
        xs_batch = xs.repeat(self.n_samples, 1, 1, 1)
        max_d_batch = max_d.repeat(self.n_samples)
        cam_int_batch = cam_int.repeat(self.n_samples, 1, 1)
        cam_ext_batch = cam_ext.repeat(self.n_samples, 1, 1)

        xhnr_gen = self.model_h.sample(xs_batch)
        xhn_gen = GeometryTransformer.convert_to_3D_rot(xhnr_gen)
        xh_gen = GeometryTransformer.recover_global_T(xhn_gen, cam_int_batch, max_d_batch)
        body_param_list = BodyParamParser.body_params_encapsulate(xh_gen)

        # exist_ok avoids the check-then-create race of the original.
        os.makedirs(self.outdir, exist_ok=True)
        print('[INFO] save results to: ' + self.outdir)

        # Camera arrays are identical for every sample of this file.
        cam_ext_np = cam_ext_batch.detach().cpu().numpy()
        cam_int_np = cam_int_batch.detach().cpu().numpy()

        for jj, body_param in enumerate(body_param_list):
            body_param['cam_ext'] = cam_ext_np
            body_param['cam_int'] = cam_int_np
            outfilename = os.path.join(
                self.outdir,
                'body_gen_{:06d}.pkl'.format(self.n_samples * ii + jj))
            # The context manager closes the file even if pickling fails.
            with open(outfilename, 'wb') as outfile:
                pickle.dump(body_param, outfile)
def test(self, batch_gen):
    """Sample ``self.n_samples`` bodies for one test scene and pickle them.

    The most recently modified ``epoch-*.ckp`` checkpoint in
    ``self.ckpt_dir`` is loaded into ``self.model_h``. One batch of size 1
    is drawn from ``batch_gen`` and tiled ``self.n_samples`` times, and
    bodies are sampled with fresh Gaussian noise. Each body is written to
    ``<self.output_dir>/<scene_name>/body_gen_{index + 900:06d}.pkl``, where
    ``scene_name`` is the part before the first "_" of the name of the
    directory containing ``self.scene_file_path``.

    Args:
        batch_gen: object providing ``reset()`` and
            ``next_batch(batch_size=1)``. The returned sequence is indexed
            0..4 as depth, segmentation, max depth, camera intrinsics and
            camera extrinsics.

    Raises:
        IndexError: if no checkpoint matches ``epoch-*.ckp``.
    """
    self.model_h.eval()
    self.model_h.to(self.device)
    self.vposer.to(self.device)
    self.body_mesh_model.to(self.device)

    ## load checkpoints (latest by modification time)
    ckp_list = sorted(glob.glob(os.path.join(self.ckpt_dir, 'epoch-*.ckp')),
                      key=os.path.getmtime)
    ckp_path = ckp_list[-1]
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckp_path)
    print('[INFO] load checkpoints: ' + ckp_path)
    self.model_h.load_state_dict(checkpoint['model_h_state_dict'])

    ## get a batch of data for testing
    batch_gen.reset()
    test_data = batch_gen.next_batch(batch_size=1)
    depth_batch = test_data[0]
    seg_batch = test_data[1]
    max_d_batch = test_data[2]
    cam_int_batch = test_data[3]
    cam_ext_batch = test_data[4]

    ## pass data to network
    xs = torch.cat([depth_batch, seg_batch], dim=1)
    xs_n = xs.repeat(self.n_samples, 1, 1, 1)
    noise_batch_g = torch.randn([self.n_samples, self.model_h_latentD],
                                dtype=torch.float32, device=self.device)
    noise_batch_l = torch.randn([self.n_samples, self.model_h_latentD],
                                dtype=torch.float32, device=self.device)

    sampled = self.model_h.sample(xs_n, noise_batch_g, noise_batch_l)
    if self.use_cont_rot:
        xhn_gen = GeometryTransformer.convert_to_3D_rot(sampled)
    else:
        # Fix: the original never assigned xhn_gen on this branch, so the
        # next line raised NameError. Without continuous rotations the
        # sample is used as-is.
        xhn_gen = sampled

    xh_gen = GeometryTransformer.recover_global_T(xhn_gen, cam_int_batch, max_d_batch)
    body_param_list = BodyParamParser.body_params_encapsulate(xh_gen)

    # Name of the directory holding the scene file, up to the first "_".
    # os.path is used instead of splitting on "/" so this also works with
    # other path separators.
    scene_dir = os.path.dirname(os.path.abspath(self.scene_file_path))
    scene_name = os.path.basename(scene_dir).split("_")[0]
    outdir = os.path.join(self.output_dir, scene_name)
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(outdir, exist_ok=True)

    print('[INFO] save results to: ' + outdir)

    # Camera arrays are identical for every sample.
    cam_ext_np = cam_ext_batch.detach().cpu().numpy()
    cam_int_np = cam_int_batch.detach().cpu().numpy()

    for ii, body_param in enumerate(body_param_list):
        body_param['cam_ext'] = cam_ext_np
        body_param['cam_int'] = cam_int_np
        outfilename = os.path.join(outdir, 'body_gen_{:06d}.pkl'.format(ii + 900))
        # The context manager closes the file even if pickling fails.
        with open(outfilename, 'wb') as outfile:
            pickle.dump(body_param, outfile)
def cal_loss(self, xhr, cam_ext):
    """Compute the four loss terms used while fitting ``self.xhr_rec``.

    The body stored in ``self.xhr_rec`` is decoded to mesh vertices and
    transformed with ``cam_ext``. The selected contact vertices are compared
    to ``self.s_verts_batch`` with a chamfer distance, and all vertices are
    looked up in the SDF grid ``self.s_sdf_batch``.

    Args:
        xhr: reference body feature compared against ``self.xhr_rec`` with an
            L1 loss.
        cam_ext: transform handed to ``GeometryTransformer.verts_transform``.

    Returns:
        Tuple ``(loss_rec, loss_vposer, loss_contact, loss_collision)``,
        each already multiplied by its ``self.weight_*`` factor.
    """
    # ---- reconstruction ----
    loss_rec = self.weight_loss_rec * F.l1_loss(xhr, self.xhr_rec)
    xh_rec = GeometryTransformer.convert_to_3D_rot(self.xhr_rec)

    # ---- VPoser regulariser on columns 16:48 ----
    loss_vposer = self.weight_loss_vposer * torch.mean(xh_rec[:, 16:48] ** 2)

    # ---- body mesh of the current estimate ----
    rec_params = BodyParamParser.body_params_encapsulate_batch(xh_rec)
    joint_rot_batch = self.vposer.decode(
        rec_params['body_pose_vp'], output_type='aa'
    ).view(self.batch_size, -1)
    # Everything except the VPoser latent is forwarded to the body model.
    model_kwargs = {
        name: rec_params[name]
        for name in rec_params.keys()
        if name not in ['body_pose_vp']
    }
    mesh_out = self.body_mesh_model(return_verts=True,
                                    body_pose=joint_rot_batch,
                                    **model_kwargs)
    body_verts = GeometryTransformer.verts_transform(mesh_out.vertices, cam_ext)  # [b, 10475, 3]

    # ---- contact: chamfer distance from the selected vertices to the scene ----
    vid, _ = GeometryTransformer.get_contact_id(
        body_segments_folder=self.contact_id_folder,
        contact_body_parts=self.contact_part)
    contact_verts = body_verts[:, vid, :]
    chamfer = ext.chamferDist()
    contact_dist, _ = chamfer(contact_verts.contiguous(),
                              self.s_verts_batch.contiguous())
    loss_contact = self.weight_contact * torch.mean(
        torch.sqrt(contact_dist + 1e-4) / (torch.sqrt(contact_dist + 1e-4) + 1.0))

    # ---- SDF collision ----
    grid_lo = self.s_grid_min_batch.unsqueeze(1)
    grid_hi = self.s_grid_max_batch.unsqueeze(1)
    norm_verts = (body_verts - grid_lo) / (grid_hi - grid_lo) * 2 - 1
    n_verts = norm_verts.shape[1]
    body_sdf = F.grid_sample(self.s_sdf_batch.unsqueeze(1),
                             norm_verts[:, :, [2, 1, 0]].view(-1, n_verts, 1, 1, 3),
                             padding_mode='border')

    penetrating = body_sdf < 0
    if penetrating.sum().item() >= 1:
        loss_sdf_pene = body_sdf[penetrating].abs().mean()
    else:
        # No penetrating vertices: the term is exactly zero.
        loss_sdf_pene = torch.tensor(0.0, dtype=torch.float32, device=self.device)
    loss_collision = self.weight_collision * loss_sdf_pene

    return loss_rec, loss_vposer, loss_contact, loss_collision