def forward_step(th_scan_meshes, smpl, th_pose_3d=None):
    """Evaluate SMPL once and gather the losses of the scan registration.

    Args:
        th_scan_meshes: batch of scan meshes (each exposing ``.vertices``).
        smpl: batched SMPL module; ``smpl()`` yields four items, the first of
            which is iterated to build one mesh per batch element.
        th_pose_3d: optional 3D keypoint data; when not ``None`` an extra
            ``'pose_obj'`` term is computed with ``batch_get_pose_obj``.

    Returns:
        dict with keys ``'s2m'``, ``'m2s'``, ``'betas'``, ``'pose_pr'`` and,
        when keypoints are given, ``'pose_obj'``.
    """
    pose_prior = get_prior(smpl.gender)

    posed_verts, _, _, _ = smpl()
    model_meshes = [tm.from_tensors(vertices=vert, faces=smpl.faces) for vert in posed_verts]

    scan_points = [mesh.vertices for mesh in th_scan_meshes]
    model_points = [mesh.vertices for mesh in model_meshes]

    losses = {
        # scan points against the model surface, then model vertices against the scan surface
        's2m': batch_point_to_surface(scan_points, model_meshes),
        'm2s': batch_point_to_surface(model_points, th_scan_meshes),
        # mean squared shape coefficients, penalizing large betas
        'betas': torch.mean(smpl.betas ** 2, axis=1),
        'pose_pr': pose_prior(smpl.pose),
    }
    if th_pose_3d is not None:
        losses['pose_obj'] = batch_get_pose_obj(th_pose_3d, smpl)
    return losses
def forward_step_pose_only(smpl, th_pose_3d, prior_weight):
    """Compute the losses of the pose-only optimization stage.

    No scan geometry is involved here: only the weighted pose prior and the
    objective returned by ``batch_get_pose_obj`` are evaluated.

    Args:
        smpl: batched SMPL module (``.gender`` and ``.pose`` are read).
        th_pose_3d: 3D keypoint data forwarded to ``batch_get_pose_obj``.
        prior_weight: passed as the second argument of the pose prior.

    Returns:
        dict with keys ``'pose_pr'`` and ``'pose_obj'``.
    """
    weighted_prior = get_prior(smpl.gender)
    pose_prior_term = weighted_prior(smpl.pose, prior_weight)
    keypoint_term = batch_get_pose_obj(th_pose_3d, smpl, init_pose=False)
    return {'pose_pr': pose_prior_term, 'pose_obj': keypoint_term}
def forward_step(th_scan_meshes, smplx, scan_part_labels, smplx_part_labels,
                 search_tree=None, pen_distance=None, tri_filtering_module=None):
    """Run SMPL-X once and compute the registration losses against the scans.

    Args:
        th_scan_meshes: batch of scan meshes (each exposing ``.vertices``).
        smplx: batched SMPL-X module; the value returned by ``smplx()`` is
            iterated as one vertex set per batch element.
        scan_part_labels: per-scan tensor holding one part label per scan vertex.
        smplx_part_labels: per-sample tensor holding one part label per model vertex.
        search_tree, pen_distance, tri_filtering_module: forwarded unchanged to
            ``interpenetration_loss``.

    Returns:
        dict with keys ``'s2m'``, ``'m2s'``, ``'betas'``, ``'interpenetration'``
        and ``'part'`` (``torch.stack`` of one part loss per batch element).
        The pose-prior term is currently not part of this objective.
    """
    verts = smplx()
    th_smplx_meshes = [tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts]
    scan_verts = [sm.vertices for sm in th_scan_meshes]
    smplx_verts = [sm.vertices for sm in th_smplx_meshes]

    loss = dict()
    loss['s2m'] = batch_point_to_surface(scan_verts, th_smplx_meshes)
    loss['m2s'] = batch_point_to_surface(smplx_verts, th_scan_meshes)
    loss['betas'] = torch.mean(smplx.betas ** 2, axis=1)
    loss['interpenetration'] = interpenetration_loss(verts, smplx.faces, search_tree, pen_distance,
                                                     tri_filtering_module, 1.0)

    # Part loss: chamfer distance between same-labelled scan and model points.
    part_losses = []
    for n, (sc_v, sc_l) in enumerate(zip(scan_verts, scan_part_labels)):
        model_v, model_l = smplx_verts[n], smplx_part_labels[n]
        # Start from a zero tensor (not the int 0) so that a scan without any
        # label in range(NUM_PARTS) still produces a tensor for torch.stack.
        tot = sc_v.new_zeros(())
        for i in range(NUM_PARTS):
            scan_mask = sc_l == i
            if not scan_mask.any():
                continue
            sc_part_points = sc_v[torch.where(scan_mask)[0]].unsqueeze(0)
            sm_part_points = model_v[torch.where(model_l == i)[0]].unsqueeze(0)
            tot = tot + chamfer_distance(sc_part_points, sm_part_points, w1=1., w2=1.)
        # Averaged over all NUM_PARTS, not only over the parts present in the scan.
        part_losses.append(tot / NUM_PARTS)
    loss['part'] = torch.stack(part_losses)
    return loss
def forward_step_SMPL(th_scan_meshes, smpl, scan_part_labels, smpl_part_labels, args):
    """Run SMPL once and compute the registration losses against the scans.

    Args:
        th_scan_meshes: batch of scan meshes (each exposing ``.vertices``).
        smpl: batched SMPL module; ``smpl()`` yields four items, the first of
            which is iterated as one vertex set per batch element.
        scan_part_labels: per-scan tensor holding one part label per scan vertex.
        smpl_part_labels: per-sample tensor holding one part label per model vertex.
        args: options object; ``args.use_parts`` enables the ``'part'`` term.

    Returns:
        dict with keys ``'s2m'``, ``'m2s'``, ``'betas'``, ``'pose_pr'`` and,
        when ``args.use_parts`` is set, ``'part'`` (one entry per batch element).
    """
    num_parts = 14  # size of the part-label vocabulary used by the part term

    prior = get_prior(smpl.gender, precomputed=True)

    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts]
    scan_verts = [sm.vertices for sm in th_scan_meshes]
    smpl_verts = [sm.vertices for sm in th_smpl_meshes]

    loss = dict()
    loss['s2m'] = batch_point_to_surface(scan_verts, th_smpl_meshes)
    loss['m2s'] = batch_point_to_surface(smpl_verts, th_scan_meshes)
    loss['betas'] = torch.mean(smpl.betas ** 2, axis=1)
    loss['pose_pr'] = prior(smpl.pose)

    if args.use_parts:
        # Chamfer distance between same-labelled scan and model points.
        part_losses = []
        for n, (sc_v, sc_l) in enumerate(zip(scan_verts, scan_part_labels)):
            model_v, model_l = smpl_verts[n], smpl_part_labels[n]
            # Zero tensor instead of int 0 so torch.stack also works for a scan
            # that has none of the labels in range(num_parts).
            tot = sc_v.new_zeros(())
            for i in range(num_parts):
                scan_mask = sc_l == i
                if not scan_mask.any():
                    continue
                sc_part_points = sc_v[torch.where(scan_mask)[0]].unsqueeze(0)
                sm_part_points = model_v[torch.where(model_l == i)[0]].unsqueeze(0)
                tot = tot + chamfer_distance(sc_part_points, sm_part_points, w1=1., w2=1.)
            # Averaged over all parts, not only the ones present in the scan.
            part_losses.append(tot / num_parts)
        loss['part'] = torch.stack(part_losses)
    return loss
def _load_pose_keypoints_3d(pose_files):
    """Load per-scan 3D keypoint JSON files and derive the pose-prior weight.

    Each file must contain the keys 'pose_keypoints_3d', 'face_keypoints_3d',
    'hand_right_keypoints_3d' and 'hand_left_keypoints_3d'. Each is converted to
    a float32 tensor of rows with 4 values. A hand counts as not visible when
    the maximum of its 4th column is below ``HAND_VISIBLE``.

    :return: ``(th_pose_3d, prior_weight)``. ``th_pose_3d`` is a list with one
        dict per file. ``prior_weight`` is the result of ``get_prior_weight``,
        moved to CUDA.
    """
    keypoint_keys = ('pose_keypoints_3d', 'face_keypoints_3d',
                     'hand_right_keypoints_3d', 'hand_left_keypoints_3d')
    th_pose_3d, no_right_hand_visible, no_left_hand_visible = [], [], []
    for pose_file in pose_files:
        with open(pose_file) as f:
            pose_3d = json.load(f)
        no_right_hand_visible.append(
            np.max(np.array(pose_3d['hand_right_keypoints_3d']).reshape(-1, 4)[:, 3]) < HAND_VISIBLE)
        no_left_hand_visible.append(
            np.max(np.array(pose_3d['hand_left_keypoints_3d']).reshape(-1, 4)[:, 3]) < HAND_VISIBLE)
        for key in keypoint_keys:
            pose_3d[key] = torch.from_numpy(np.array(pose_3d[key]).astype(np.float32).reshape(-1, 4))
        th_pose_3d.append(pose_3d)
    prior_weight = get_prior_weight(no_right_hand_visible, no_left_hand_visible).cuda()
    return th_pose_3d, prior_weight


def fit_SMPLX(scans, pose_files=None, gender='male', save_path=None, display=None):
    """Register a batch of .obj scans with SMPL, optionally guided by 3D keypoints.

    :param scans: list of scan paths (loaded with ``tm.from_obj``).
    :param pose_files: optional list of keypoint JSON files, one per scan. Without
        them the keypoint-driven pose-only stage is skipped and only the
        pose-and-shape stage runs.
    :param gender: SMPL gender used for the faces and the pose prior.
    :param save_path: if given, the fitted meshes, the scans and the fitted
        parameters (as pickles) are written there.
    :param display: forwarded (as ``None`` or ``0``) to the optimizers.
    :return: (pose, betas, trans) of the fitted model as numpy arrays.
    """
    # Get SMPL faces
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('int64'), dtype=torch.long).cuda()

    batch_sz = len(scans)
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    # Initialize with the prior mean pose (first three pose entries left at zero).
    prior = get_prior(gender=gender)
    pose_init = torch.zeros((batch_sz, 72))
    pose_init[:, 3:] = prior.mean
    betas, trans = torch.zeros((batch_sz, 300)), torch.zeros((batch_sz, 3))
    smpl = th_batch_SMPL(batch_sz, betas, pose_init, trans, faces=th_faces).cuda()

    # Load scans (not centered). Landmarks must be moved accordingly if they ever are.
    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        th_scan = tm.from_obj(scan)
        th_scan.vertices = th_scan.vertices.cuda()
        th_scan.faces = th_scan.faces.cuda()
        th_scan.vertices.requires_grad = False
        th_scan.cuda()
        th_scan_meshes.append(th_scan)

    # NOTE(review): th_pose_3d is structured as [batch][key]; see
    # batch_get_pose_obj() in body_objectives.py for whether [key][batch] fits better.
    th_pose_3d = None
    if pose_files is not None:
        th_pose_3d, prior_weight = _load_pose_keypoints_3d(pose_files)
        # The pose-only stage needs th_pose_3d and prior_weight, which only
        # exist when keypoint files are given.
        optimize_pose_only(th_scan_meshes, smpl, pose_iterations, pose_steps_per_iter, th_pose_3d,
                           prior_weight, display=None if display is None else 0)

    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes, smpl, iterations, steps_per_iter, th_pose_3d,
                        display=None if display is None else 0)

    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts]

    pose_out = smpl.pose.cpu().detach().numpy()
    betas_out = smpl.betas.cpu().detach().numpy()
    trans_out = smpl.trans.cpu().detach().numpy()

    if save_path is not None:
        if not exists(save_path):
            os.makedirs(save_path)
        names = [split(s)[1] for s in scans]
        stems = [os.path.splitext(n)[0] for n in names]
        save_meshes(th_smpl_meshes, [join(save_path, s + '_smpl.obj') for s in stems])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])
        for p, b, t, s in zip(pose_out, betas_out, trans_out, stems):
            with open(join(save_path, s + '_smpl.pkl'), 'wb') as f:
                pkl.dump({'pose': p, 'betas': b, 'trans': t}, f)

    return pose_out, betas_out, trans_out
def fit_SMPL(scans, scan_labels, gender='male', save_path=None, scale_file=None, display=None,
             smpl_parts_file='/BS/bharat-3/work/IPNet/assets/smpl_parts_dense.pkl'):
    """Register a batch of scans with SMPL, guided by per-vertex part labels.

    :param scans: list of scan paths (loaded with ``Mesh(filename=...)``).
    :param scan_labels: list of ``.npy`` files, one per scan, each holding an
        integer part label per scan vertex.
    :param gender: SMPL gender used for the faces and the pose prior.
    :param save_path: if given, the fitted meshes, the scans and the fitted
        parameters (as pickles) are written there.
    :param scale_file: optional list of ``.npy`` files, one per scan. Entry 1 is
        added to the scan vertices, which are then multiplied by entry 0.
    :param display: forwarded (as ``None`` or ``0``) to the optimizers.
    :param smpl_parts_file: pickle mapping each SMPL part to its vertex indices.
        Parts are numbered in the pickle's iteration order.
    :return: (pose, betas, trans) of the fitted model as numpy arrays.
    """
    # Get SMPL faces
    sp = SmplPaths(gender=gender)
    smpl_faces = sp.get_faces()
    th_faces = torch.tensor(smpl_faces.astype('int64'), dtype=torch.long).to(DEVICE)

    # Per-vertex part label of the SMPL template.
    with open(smpl_parts_file, 'rb') as f:
        part_labels = pkl.load(f)
    labels = np.zeros((6890,), dtype='int32')
    for n, k in enumerate(part_labels):
        labels[part_labels[k]] = n
    labels = torch.tensor(labels).unsqueeze(0).to(DEVICE)

    # Per-vertex part labels of the scans.
    scan_part_labels = [torch.tensor(np.load(sc_l).astype('int32')).to(DEVICE) for sc_l in scan_labels]

    batch_sz = len(scans)
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30

    # Initialize with the prior mean pose (first three pose entries left at zero).
    prior = get_prior(gender=gender, precomputed=True)
    pose_init = torch.zeros((batch_sz, 72))
    pose_init[:, 3:] = prior.mean
    betas, trans = torch.zeros((batch_sz, 300)), torch.zeros((batch_sz, 3))
    smpl = th_batch_SMPL(batch_sz, betas, pose_init, trans, faces=th_faces).to(DEVICE)
    smpl_part_labels = torch.cat([labels] * batch_sz, axis=0)

    th_scan_meshes = []
    for scan in scans:
        print('scan path ...', scan)
        temp = Mesh(filename=scan)
        th_scan = tm.from_tensors(
            torch.tensor(temp.v.astype('float32'), requires_grad=False, device=DEVICE),
            torch.tensor(temp.f.astype('int32'), requires_grad=False, device=DEVICE).long())
        th_scan_meshes.append(th_scan)

    if scale_file is not None:
        for n, sc in enumerate(scale_file):
            # NOTE: allow_pickle=True may unpickle arbitrary objects; use trusted files only.
            dat = np.load(sc, allow_pickle=True)
            th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE)
            th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE)

    # Optimize pose first
    optimize_pose_only(th_scan_meshes, smpl, pose_iterations, pose_steps_per_iter, scan_part_labels,
                       smpl_part_labels, display=None if display is None else 0)
    # Optimize pose and shape
    optimize_pose_shape(th_scan_meshes, smpl, iterations, steps_per_iter, scan_part_labels,
                        smpl_part_labels, display=None if display is None else 0)

    verts, _, _, _ = smpl()
    th_smpl_meshes = [tm.from_tensors(vertices=v, faces=smpl.faces) for v in verts]

    pose_out = smpl.pose.cpu().detach().numpy()
    betas_out = smpl.betas.cpu().detach().numpy()
    trans_out = smpl.trans.cpu().detach().numpy()

    if save_path is not None:
        if not exists(save_path):
            os.makedirs(save_path)
        names = [split(s)[1] for s in scans]
        # Build output names from the stem. Replacing '.ply' would leave non-.ply
        # names unchanged, so the SMPL mesh, the scan copy and the pickle would
        # overwrite one another.
        stems = [os.path.splitext(n)[0] for n in names]
        save_meshes(th_smpl_meshes, [join(save_path, s + '_smpl.obj') for s in stems])
        save_meshes(th_scan_meshes, [join(save_path, n) for n in names])
        for p, b, t, s in zip(pose_out, betas_out, trans_out, stems):
            with open(join(save_path, s + '_smpl.pkl'), 'wb') as f:
                pkl.dump({'pose': p, 'betas': b, 'trans': t}, f)

    return pose_out, betas_out, trans_out
def _smpld_item_paths(dataset, idx):
    """Return ``(model_dict, filebase, folder_name)`` for dataset item ``idx``.

    ``filebase`` is the basename of ``model_dict['data_path']`` without its last
    four characters (presumably a 3-letter extension). ``folder_name`` is
    ``subset/subject/sequence``.
    """
    model_dict = dataset.get_model_dict(idx)
    filebase = os.path.basename(model_dict['data_path'])[:-4]
    folder_name = os.path.join(model_dict['subset'], model_dict['subject'], model_dict['sequence'])
    return model_dict, filebase, folder_name


def _smpld_load_mesh(path, device, divisor=None):
    """Load a mesh with trimesh (``process=False``) and convert it to a ``tm`` mesh.

    :param divisor: if given, vertex coordinates are divided by it after the
        float32 cast.
    :return: ``(tm_mesh, vertices)``. ``vertices`` is a numpy copy of the
        vertices as loaded, before the cast and division.
    """
    mesh = trimesh.load(path, process=False)
    vertices = mesh.vertices.astype('float32')
    if divisor is not None:
        vertices = vertices / divisor
    th_mesh = tm.from_tensors(torch.tensor(vertices, requires_grad=False, device=device),
                              torch.tensor(mesh.faces.astype('int64'), requires_grad=False, device=device))
    return th_mesh, np.array(mesh.vertices)


def _smpld_smpl_part_labels(num_joints, batch_size, device):
    """Build the ``(1, num_vertices)`` part-label row of the SMPL template.

    For 24 joints the labels come from the body model's skinning weights and
    are mapped to the 14-part space with ``SMPL2IPNET_IDX``. For 14 joints they
    are read from ``body_models/misc/smpl_parts_dense.pkl``.

    :return: ``(labels, parents, smpl2ipnet)``. ``parents`` and ``smpl2ipnet``
        are ``None`` unless ``num_joints == 24``.
    :raises ValueError: if ``num_joints`` is neither 14 nor 24.
    """
    if num_joints == 24:
        bm = BodyModel(bm_path='body_models/smpl/male/model.pkl', num_betas=10,
                       batch_size=batch_size).to(device)
        parents = bm.kintree_table[0].detach().cpu().numpy()
        smpl2ipnet = torch.from_numpy(SMPL2IPNET_IDX).to(device)
        # Convert 24 parts to 14 parts
        labels = smpl2ipnet[bm.weights.argmax(1)].clone().unsqueeze(0)
        return labels, parents, smpl2ipnet
    if num_joints == 14:
        with open('body_models/misc/smpl_parts_dense.pkl', 'rb') as f:
            part_labels = pkl.load(f)
        labels = np.zeros((6890,), dtype=np.int64)
        for n, k in enumerate(part_labels):
            labels[part_labels[k]] = n
        return torch.tensor(labels).to(device).unsqueeze(0), None, None
    raise ValueError('Got {} joints but number of joints can only be either 14 or 24'.format(num_joints))


def _smpld_fit(minimal_meshes, cloth_meshes, pose_init, scan_part_labels, smpl_labels,
               th_faces, gender, batch_size, device, args):
    """Fit SMPL to the minimal-body meshes, then optimize offsets to the clothed meshes.

    :return: ``(inner_vertices, outer_vertices)``. These are the SMPL vertices
        after the pose/shape fit and after ``optimize_offsets``, respectively.
    """
    iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30
    betas, trans = torch.zeros((batch_size, 10)), torch.zeros((batch_size, 3))
    smpl = th_batch_SMPL(batch_size, betas, pose_init, trans, faces=th_faces, gender=gender).to(device)
    smpl_part_labels = torch.cat([smpl_labels] * batch_size, axis=0)

    # Optimize pose first, then pose and shape
    optimize_pose_only(minimal_meshes, smpl, pose_iterations, pose_steps_per_iter, scan_part_labels,
                       smpl_part_labels, None, args)
    optimize_pose_shape(minimal_meshes, smpl, iterations, steps_per_iter, scan_part_labels,
                        smpl_part_labels, None, args)
    inner_vertices, _, _, _ = smpl()

    # Optimize vertices for SMPLD, starting from a detached copy of the fitted body
    init_smpl_meshes = [tm.from_tensors(vertices=v.clone().detach(), faces=smpl.faces)
                        for v in inner_vertices]
    optimize_offsets(cloth_meshes, smpl, init_smpl_meshes, 5, 10, args)
    outer_vertices, _, _, _ = smpl()
    return inner_vertices, outer_vertices


def _smpld_export_registration(register_dir_, filebase, faces, inner_vertices, outer_vertices, i):
    """Write the i-th inner/outer registration as PLY unless that file already exists.

    The vertex containers are only indexed for missing files, so they may be
    ``None`` when both files are already on disk.
    """
    if not os.path.exists(register_dir_):
        os.makedirs(register_dir_)
    for suffix, vertices in (('minimal.registered.ply', inner_vertices),
                             ('cloth.registered.ply', outer_vertices)):
        out_file = os.path.join(register_dir_, filebase + suffix)
        if not os.path.exists(out_file):
            registered_mesh = trimesh.Trimesh(vertices[i].detach().cpu().numpy().astype(np.float64),
                                              faces, process=False)
            registered_mesh.export(out_file)


def SMPLD_register(args):
    """Register SMPL (inner body) and SMPL+D (clothed) to the test split and evaluate.

    For every batch, the generated minimal/clothed meshes and predicted part
    labels are loaded from ``generation_dir``. With ``args.use_raw_scan`` the
    clothed meshes come from the raw scans instead. SMPL and SMPL+D are fitted
    and the results are written under ``generation_dir/registrations``.
    Without raw scans, the mean per-vertex distance to the ground-truth
    SMPL/SMPL+D vertices is logged.

    Item collection for a batch stops at the first item whose registration file
    already exists. This assumes a resumed job keeps the same batch size.
    """
    cfg = config.load_config(args.config, 'configs/default.yaml')
    out_dir = cfg['training']['out_dir']
    generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
    is_cuda = (torch.cuda.is_available() and not args.no_cuda)
    device = torch.device("cuda" if is_cuda else "cpu")

    # Logger and dataset for either a single subject/sequence or the whole split
    if args.subject_idx >= 0 and args.sequence_idx >= 0:
        logger, _ = create_logger(generation_dir,
                                  phase='reg_subject{}_sequence{}'.format(args.subject_idx, args.sequence_idx),
                                  create_tf_logs=False)
        dataset = config.get_dataset('test', cfg, sequence_idx=args.sequence_idx, subject_idx=args.subject_idx)
    else:
        logger, _ = create_logger(generation_dir, phase='reg_all', create_tf_logs=False)
        dataset = config.get_dataset('test', cfg)

    loader_batch_size = cfg['generation']['batch_size']
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=loader_batch_size, num_workers=1, shuffle=False)

    mesh_dir = os.path.join(generation_dir, 'meshes')  # posed and (optionally) unposed implicit outer/inner meshes
    label_dir = os.path.join(generation_dir, 'labels')  # part labels
    register_dir = os.path.join(generation_dir, 'registrations')  # registration outputs
    scan_dir = dataset.dataset_folder if args.use_raw_scan else None  # folder with CAPE raw scans

    # Batch-invariant SMPL data, computed once instead of per batch.
    smpl_labels, parents, smpl2ipnet = _smpld_smpl_part_labels(args.num_joints, loader_batch_size, device)
    th_faces = torch.tensor(smpl_faces.astype('int64'), dtype=torch.long).to(device)

    inner_dists = []
    outer_dists = []
    for data in tqdm(test_loader):
        idxs = data['idx'].cpu().numpy()
        batch_size = idxs.shape[0]

        all_posed_minimal_meshes = []
        all_posed_cloth_meshes = []
        all_posed_vertices = []
        all_unposed_vertices = []
        scan_part_labels = []
        batch_genders = []
        for idx in idxs:
            model_dict, filebase, folder_name = _smpld_item_paths(dataset, idx)
            batch_genders.append(model_dict['gender'])

            # TODO: we assume batch size stays the same if one resumes the job
            register_file = os.path.join(register_dir, folder_name, filebase + 'minimal.registered.ply')
            if os.path.exists(register_file):
                # batch already computed
                break

            mesh_dir_ = os.path.join(mesh_dir, folder_name)
            label_dir_ = os.path.join(label_dir, folder_name)

            # Part labels for each vertex (14 or 24 parts)
            label_dict = dict(np.load(os.path.join(label_dir_, filebase + '.minimal.npz')))
            scan_part_labels.append(torch.tensor(label_dict['part_labels'].astype(np.int64)).to(device))

            # Minimal implicit surfaces
            posed_mesh, posed_vertices = _smpld_load_mesh(
                os.path.join(mesh_dir_, filebase + '.minimal.posed.ply'), device)
            all_posed_vertices.append(posed_vertices)
            all_posed_minimal_meshes.append(posed_mesh)

            unposed_file = os.path.join(mesh_dir_, filebase + '.minimal.unposed.ply')
            if os.path.exists(unposed_file) and args.init_pose:
                unposed_mesh = trimesh.load(unposed_file, process=False)
                all_unposed_vertices.append(np.array(unposed_mesh.vertices))

            if args.use_raw_scan:
                # Raw scans; coordinates divided by 1000 on load
                scan_file = os.path.join(scan_dir, model_dict['subject'], model_dict['sequence'], filebase + '.ply')
                cloth_mesh, _ = _smpld_load_mesh(scan_file, device, divisor=1000)
            else:
                # Clothed implicit surfaces
                cloth_mesh, _ = _smpld_load_mesh(os.path.join(mesh_dir_, filebase + '.cloth.posed.ply'), device)
            all_posed_cloth_meshes.append(cloth_mesh)

        # A single gender is used per batch: that of the last visited item.
        gender = batch_genders[-1]

        # We assume loaded meshes are properly scaled and offset to the original SMPL space.
        if all_posed_minimal_meshes:
            if len(set(batch_genders)) > 1:
                logger.warning('Batch mixes genders {}; fitting all items with {}'.format(
                    sorted(set(batch_genders)), gender))
            if not all_unposed_vertices:
                # IPNet optimization without vertex translations: start from the prior mean pose
                if args.num_joints == 24:
                    scan_part_labels = [smpl2ipnet[lbl].clone() for lbl in scan_part_labels]
                prior = get_prior(gender=gender, precomputed=True)
                pose_init = torch.zeros((batch_size, 72))
                pose_init[:, 3:] = prior.mean
            else:
                # NASA+PTFs optimization: compute poses from implicit surfaces and correspondences.
                # TODO: we could also compute bone-lengths if we train PTFs to predict A-pose
                if parents is None or smpl2ipnet is None:
                    raise ValueError('Pose initialization from unposed meshes requires num_joints == 24, '
                                     'got {}'.format(args.num_joints))
                poses = compute_poses(all_posed_vertices, all_unposed_vertices, scan_part_labels, parents, args)
                # Convert 24 parts to 14 parts
                scan_part_labels = [smpl2ipnet[lbl].clone() for lbl in scan_part_labels]
                pose_init = torch.from_numpy(poses).float()
            inner_vertices, outer_vertices = _smpld_fit(
                all_posed_minimal_meshes, all_posed_cloth_meshes, pose_init, scan_part_labels,
                smpl_labels, th_faces, gender, batch_size, device, args)
        else:
            inner_vertices = outer_vertices = None

        if args.use_raw_scan:
            for i, idx in enumerate(idxs):
                _, filebase, folder_name = _smpld_item_paths(dataset, idx)
                _smpld_export_registration(os.path.join(register_dir, folder_name), filebase, smpl_faces,
                                           inner_vertices, outer_vertices, i)
            continue

        # Evaluate registered meshes against the ground truth
        gt_smpl_mesh = data['points.minimal_smpl_vertices'].to(device)
        gt_smpld_mesh = data['points.smpl_vertices'].to(device)
        if inner_vertices is None:
            # Nothing was fitted: registrations are assumed to exist from previous runs
            inner_list, outer_list = [], []
            for idx in idxs:
                _, filebase, folder_name = _smpld_item_paths(dataset, idx)
                register_dir_ = os.path.join(register_dir, folder_name)
                for suffix, bucket in (('minimal.registered.ply', inner_list),
                                       ('cloth.registered.ply', outer_list)):
                    registered_mesh = trimesh.load(os.path.join(register_dir_, filebase + suffix), process=False)
                    bucket.append(torch.tensor(registered_mesh.vertices.astype(np.float32),
                                               requires_grad=False, device=device))
            inner_vertices = torch.stack(inner_list, dim=0)
            outer_vertices = torch.stack(outer_list, dim=0)

        inner_dist = torch.norm(gt_smpl_mesh - inner_vertices, dim=2).mean(-1)
        outer_dist = torch.norm(gt_smpld_mesh - outer_vertices, dim=2).mean(-1)

        for i, idx in enumerate(idxs):
            _, filebase, folder_name = _smpld_item_paths(dataset, idx)
            # NOTE(review): distances are logged unscaled with a 'cm' label; since raw
            # scans are divided by 1000 on load, the data may be in meters — confirm.
            logger.info('Inner distance for input {}: {} cm'.format(filebase, inner_dist[i].item()))
            logger.info('Outer distance for input {}: {} cm'.format(filebase, outer_dist[i].item()))
            _smpld_export_registration(os.path.join(register_dir, folder_name), filebase, smpl_faces,
                                       inner_vertices, outer_vertices, i)

        inner_dists.extend(inner_dist.detach().cpu().numpy())
        outer_dists.extend(outer_dist.detach().cpu().numpy())

    # Only evaluated without raw scans; avoid logging the mean of an empty list (nan).
    if inner_dists:
        logger.info('Mean inner distance: {} cm'.format(np.mean(inner_dists)))
        logger.info('Mean outer distance: {} cm'.format(np.mean(outer_dists)))
def __init__(self, model, device, train_loader, val_loader, exp_name, opt_dict=None, optimizer='Adam',
             checkpoint_number=-1, train_supervised=False):
    """
    :param model: correspondence prediction network (moved to ``device``)
    :param device: cuda or cpu
    :param train_loader: training data loader
    :param val_loader: validation data loader
    :param exp_name: experiment name; outputs go to ``../experiments/<exp_name>``
        relative to this file
    :param opt_dict: dict containing optimization specific parameters, passed
        to ``parse_opt_dict``; ``None`` (default) means an empty dict
    :param optimizer: optimizer name passed to ``init_optimizer``
    :param checkpoint_number: load a specific checkpoint, -1 => load latest
    :param train_supervised: stored as ``self.train_supervised``
    """
    # A fresh dict per instance instead of a shared mutable default.
    if opt_dict is None:
        opt_dict = {}
    self.model = model.to(device)
    self.device = device
    self.opt_dict = self.parse_opt_dict(opt_dict)
    self.optimizer_type = optimizer
    self.optimizer = self.init_optimizer(optimizer, self.model.parameters(), learning_rate=0.001)
    self.train_data_loader = train_loader
    self.val_data_loader = val_loader
    self.train_supervised = train_supervised
    self.checkpoint_number = checkpoint_number

    # Load vsmpl
    self.vsmpl = VolumetricSMPL(
        '/BS/bharat-2/work/LearntRegistration/test_data/volumetric_smpl_function_64', device, 'male')
    sp = SmplPaths(gender='male')
    self.ref_smpl = sp.get_smpl()
    # NUM_POINTS points sampled on the male reference SMPL surface (kept on CPU here).
    self.template_points = torch.tensor(
        trimesh.Trimesh(vertices=self.ref_smpl.r, faces=self.ref_smpl.f).sample(NUM_POINTS).astype('float32'),
        requires_grad=False).unsqueeze(0)

    self.pose_prior = get_prior('male', precomputed=True)

    # Load smpl part labels: one label per template vertex, numbered in pickle order.
    with open('/BS/bharat-2/work/LearntRegistration/test_data/smpl_parts_dense.pkl', 'rb') as f:
        dat = pkl.load(f, encoding='latin-1')
    self.smpl_parts = np.zeros((6890, 1))
    for n, k in enumerate(dat):
        self.smpl_parts[dat[k]] = n

    self.exp_path = join(os.path.dirname(__file__), '../experiments/{}'.format(exp_name))
    self.checkpoint_path = join(self.exp_path, 'checkpoints/')
    if not os.path.exists(self.checkpoint_path):
        print(self.checkpoint_path)
        os.makedirs(self.checkpoint_path)
    self.writer = SummaryWriter(join(self.exp_path, 'summary'))

    self.val_min = None