def get_interpenetration_module():
    from mesh_intersection.bvh_search_tree import BVH
    import mesh_intersection.loss as collisions_loss
    from mesh_intersection.filter_faces import FilterFaces

    max_collisions = 8  # 128
    df_cone_height = 0.5  # 0.0001
    point2plane = False
    penalize_outside = True
    part_segm_fn = 'smplx_parts_segm.pkl'
    ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2", "12,22"]

    search_tree = BVH(max_collisions=max_collisions)
    pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
        sigma=df_cone_height, point2plane=point2plane,
        vectorized=True, penalize_outside=penalize_outside)

    if part_segm_fn:
        part_segm_fn = os.path.expandvars(part_segm_fn)
        with open(part_segm_fn, 'rb') as faces_parents_file:
            face_segm_data = pkl.load(faces_parents_file, encoding='latin1')
        faces_segm = face_segm_data['segm']
        faces_parents = face_segm_data['parents']
        tri_filtering_module = FilterFaces(
            faces_segm=faces_segm, faces_parents=faces_parents,
            ign_part_pairs=ign_part_pairs).cuda()

    return search_tree, pen_distance, tri_filtering_module
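# A minimal usage sketch (not part of the original code): the three modules returned
# above are combined the same way as in the fitting examples further below — detect
# candidate triangle pairs with the BVH, drop ignored part pairs, then score the rest.
# `verts`, `faces_idx`, and `coll_loss_weight` are assumed placeholders.
def example_penetration_loss(verts, faces_idx, coll_loss_weight=1e-4):
    search_tree, pen_distance, tri_filtering_module = get_interpenetration_module()
    # Gather per-face triangles from the flattened vertex tensor, as done elsewhere here
    triangles = verts.view([-1, 3])[faces_idx]
    with torch.no_grad():
        collision_idxs = search_tree(triangles)
        collision_idxs = tri_filtering_module(collision_idxs)
    return coll_loss_weight * pen_distance(triangles, collision_idxs)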
def __init__(self, bm, max_collisions=8, sigma=1e-3, filter_faces=True):
    super(BodyInterpenetration, self).__init__()
    self.bm = bm
    self.model_type = bm.model_type
    nv = bm.shapedirs.shape[0]
    device = bm.f.device
    if 'cuda' not in str(device):
        raise NotImplementedError('Interpenetration term is only available for body models on GPU.')
    try:
        import mesh_intersection
    except ImportError:
        raise ImportError('Optional package mesh_intersection is required for this functionality. '
                          'Please install from https://github.com/vchoutas/torch-mesh-isect.')
    from mesh_intersection.bvh_search_tree import BVH
    from mesh_intersection.loss import DistanceFieldPenetrationLoss

    self.search_tree = BVH(max_collisions=max_collisions)
    self.pen_distance = DistanceFieldPenetrationLoss(
        sigma=sigma, point2plane=False, vectorized=True, penalize_outside=True)

    self.filter_faces = None
    if filter_faces:
        if self.model_type == 'mano':
            import sys
            sys.stderr.write('Filter faces is not available for the MANO model yet!')
        else:
            import pickle
            from mesh_intersection.filter_faces import FilterFaces
            # ign_part_pairs: pairs of body parts for which collisions are ignored.
            # 1: LeftThigh, 2: RightThigh, 6: Spine1, 9: Spine2, 12: Neck,
            # 15: Head, 16: LeftUpperArm, 17: RightUpperArm, 22: Jaw
            ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2"] + \
                (["12,15"] if self.model_type in ['smpl', 'smplh'] else ["12,22"])
            part_segm_fname = os.path.join(
                os.path.dirname(__file__),
                'parts_segm/%s/parts_segm.pkl' %
                ('smplh' if self.model_type in ['smpl', 'smplh'] else self.model_type))
            with open(part_segm_fname, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file, encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            self.filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # Per-sample face indices into the flattened (batch_size * nv, 3) vertex tensor
    batched_f = bm.f.clone().unsqueeze(0).repeat([bm.batch_size, 1, 1]).type(torch.long)
    self.faces_ids = batched_f + (torch.arange(bm.batch_size, dtype=torch.long).to(device) * nv)[:, None, None]
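# Hedged sketch of how a forward pass could use the members set up above; the class's
# actual forward() is not included in this excerpt, so treat this as an illustration
# only. It follows the same triangle-gathering pattern used by the other snippets.
def forward_sketch(self, verts):
    triangles = verts.view([-1, 3])[self.faces_ids]
    with torch.no_grad():
        collision_idxs = self.search_tree(triangles)
        if self.filter_faces is not None:
            collision_idxs = self.filter_faces(collision_idxs)
    return self.pen_distance(triangles, collision_idxs)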
def highlight_self_intersections(mesh_path):
    '''
    Creates an .obj file in which self-intersecting faces are highlighted.
    mesh_path: path to the .obj object
    '''
    mesh = trimesh.load(mesh_path)
    vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device='cuda')
    faces = torch.tensor(mesh.faces.astype(np.int64), dtype=torch.long, device='cuda')
    batch_size = 1
    triangles = vertices[faces].unsqueeze(dim=0)

    m = BVH(max_collisions=8)
    outputs = m(triangles)
    outputs = outputs.detach().cpu().numpy().squeeze()
    collisions = outputs[outputs[:, 0] >= 0, :]

    print('Number of collisions = ', collisions.shape[0])
    print('Percentage of collisions (%)',
          collisions.shape[0] / float(triangles.shape[1]) * 100)

    # Colour the receiver/intruder faces red and the rest of the mesh white
    recv_faces = mesh.faces[collisions[:, 0]]
    intr_faces = mesh.faces[collisions[:, 1]]
    mesh1 = trimesh.Trimesh(mesh.vertices, recv_faces)
    mesh2 = trimesh.Trimesh(mesh.vertices, intr_faces)
    inter_mesh = merge_meshes(mesh1, mesh2)
    fci = np.ones((2 * recv_faces.shape[0], 3)) * np.array([1, 0, 0])
    fcf = np.ones((mesh.faces.shape[0], 3))
    mesh = merge_meshes(inter_mesh, mesh)
    final_mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces,
                                 face_colors=np.vstack((fci, fcf)))
    # final_mesh.export(mesh_path[:-4] + '_intersection.obj')
    if collisions.shape[0] > 0:
        print(mesh_path)
        final_mesh.export('intersection.obj')
def calculate_non_manifold_face_intersection(mesh_path):
    '''
    Returns the scores for non-manifold faces and the amount of self-intersection.
    mesh_path: path to the .obj mesh object
    Returns:
        nv: number of vertices
        ne: number of edges
        nf: number of faces
        nm_faces: number of instances of non-manifold faces
        mesh_isect: number of instances of self-intersections
                    (only 1 of the two triangles in each pair is counted)
    '''
    mesh = trimesh.load(mesh_path)
    f_adj = mesh.face_adjacency
    fn = mesh.face_normals

    # Count adjacent face pairs whose normals point in opposite directions
    count = 0
    for f in range(f_adj.shape[0]):
        if fn[f_adj[f, 0]] @ fn[f_adj[f, 1]] < 0:
            count += 1

    vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device='cuda')
    faces = torch.tensor(mesh.faces.astype(np.int64), dtype=torch.long, device='cuda')
    batch_size = 1
    triangles = vertices[faces].unsqueeze(dim=0).contiguous()

    m = BVH(max_collisions=8)
    outputs = m(triangles)
    outputs = outputs.detach().cpu().numpy().squeeze()
    collisions = outputs[outputs[:, 0] >= 0, :]

    return (mesh.vertices.shape[0], mesh.edges.shape[0], mesh.faces.shape[0],
            count, collisions.shape[0])
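# Example (assumed usage, not from the original code): run the two metric helpers
# above on a mesh file and report the counts. 'scan.obj' is a placeholder path.
if __name__ == '__main__':
    nv, ne, nf, nm_faces, mesh_isect = calculate_non_manifold_face_intersection('scan.obj')
    print('verts/edges/faces:', nv, ne, nf)
    print('non-manifold face pairs:', nm_faces, 'self-intersections:', mesh_isect)
    highlight_self_intersections('scan.obj')  # writes intersection.obj if any collisions exist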
def iterative_solve_collision(verts_tensor, face_tensor, face_segm_tensor, edge_target,
                              device, iter_num=10, max_collisions=8, step_size=0.5,
                              w_data=1, w_lap=0.1, w_el=0.1):
    search_tree = BVH(max_collisions=max_collisions)
    bs, nv = verts_tensor.shape[:2]
    bs, nf = face_tensor.shape[:2]
    faces_idx = face_tensor + (torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None]
    laplaceloss = LaplacianLoss(faces_idx, verts_tensor, toref=True)

    # detect and remove mesh interpenetration between body and garments
    for i in range(iter_num):
        triangles = verts_tensor.view([-1, 3])[faces_idx]
        collision_idxs = search_tree(triangles)
        # (collide_num, 2): batch and face index of valid collision pairs in collision_idxs
        val_id = collision_idxs[:, :, 0].ge(0).nonzero()
        # (collide_num, 2): face ids of the colliding pairs
        val_col_idx = collision_idxs[val_id[:, 0], val_id[:, 1]]
        # (collide_num,): intruder and receiver face types
        face_type_a = face_segm_tensor[val_col_idx[:, 0]]
        face_type_b = face_segm_tensor[val_col_idx[:, 1]]
        # Part labels: 1: 'larm', 2: 'rarm', 5: 'shirt', 6: 'pant', 7: 'lleg', 8: 'rleg'
        # The BVH stores colliding faces in ascending order, so face types are also ascending
        lleg_mask = (face_type_a == 6) & (face_type_b == 7)
        rleg_mask = (face_type_a == 6) & (face_type_b == 8)
        leg_mask = (lleg_mask + rleg_mask).ge(1).to(collision_idxs.dtype)
        leg_mask_idx = leg_mask.nonzero().reshape(-1)
        leg_num = leg_mask_idx.shape[0]
        # stop when body and garment have no collision pairs left
        if leg_num <= 0:
            break
        # all valid collision pairs -> body/garment collision pairs (face tensor indices)
        body_garm_col_idx = torch.zeros(leg_num, 2).long().to(device)
        body_garm_col_idx[0:leg_num] = val_col_idx[leg_mask_idx]
        # garment faces: column 0
        col_garment_face = face_tensor[0, body_garm_col_idx[:, 0]]   # (filtered collide_num, 3)
        col_garment_verts = verts_tensor[0, col_garment_face]        # (filtered collide_num, 3, 3)
        # body faces: column 1
        col_body_face = face_tensor[0, body_garm_col_idx[:, 1]]      # (filtered collide_num, 3)
        col_body_verts = verts_tensor[0, col_body_face]              # (filtered collide_num, 3, 3)
        # compute colliding garment face normals
        col_garment_face_norm = getFaceNormals(col_garment_verts)
        # point-to-face distance from colliding body verts to the corresponding garment face
        # (filtered collide_num, 3)
        p2s = torch.sum((col_body_verts - col_garment_verts[:, 0:1]) *
                        col_garment_face_norm.unsqueeze_(1), -1)
        # for every tri-tri collision, only select the outside body vert with the smallest p2s
        outside_mask = p2s.gt(0).float()
        bound_p2s = outside_mask * p2s + (1 - outside_mask) * 100
        min_val, min_ind = torch.min(bound_p2s, dim=1)
        col_ind = torch.arange(min_ind.shape[0]).long().to(device)
        out_verts_fidx = torch.cat((col_ind.unsqueeze_(-1), min_ind.unsqueeze_(-1)), -1)
        # vertex indices of the outside body verts, (outside_vnum,)
        out_verts_vidx = col_body_face[out_verts_fidx[:, 0], out_verts_fidx[:, 1]]
        # remove duplicates
        out_verts_vidx_nodup, inverse_ind, counts = torch.unique(
            out_verts_vidx, return_inverse=True, return_counts=True)
        print("outside verts num: {}".format(out_verts_vidx_nodup.shape[0]))

        # compute offset: average p2s per outside body vertex
        out_verts_p2s = p2s[out_verts_fidx[:, 0], out_verts_fidx[:, 1]]
        offset = torch.zeros(out_verts_vidx_nodup.shape).float().to(device)
        offset.index_add_(0, inverse_ind, out_verts_p2s)
        offset = offset / counts.float()

        # compute direction: average garment face normal per outside body vertex
        out_face_norm = col_garment_face_norm[out_verts_fidx[:, 0]]
        direction = torch.zeros(out_verts_vidx_nodup.shape[0], 3).float().to(device)
        direction.index_add_(0, inverse_ind, out_face_norm)
        direction = direction / counts.unsqueeze_(-1).float()
        verts_tensor[0, out_verts_vidx_nodup] -= step_size * direction * offset.unsqueeze_(-1)

        # optimize non-detected vertices
        all_vid = np.arange(verts_tensor.shape[1]).tolist()
        out_verts_vidx_nodup_list = out_verts_vidx_nodup.cpu().numpy().tolist()
        optim_vid = [i for i in all_vid if i not in out_verts_vidx_nodup_list]
        optim_vid = torch.Tensor(optim_vid).long().to(device)

        # prepare params and optimizer
        pred_verts = verts_tensor.clone()
        params = pred_verts[:, optim_vid].requires_grad_()
        optimizer = optim.LBFGS([params], lr=0.001,
                                line_search_fn='strong_wolfe', max_iter=20)

        # run optimization
        for i in range(5):
            def closure():
                optimizer.zero_grad()
                pred_verts = verts_tensor.clone()
                pred_verts[:, optim_vid] = params
                lap_loss = laplaceloss(pred_verts)
                edge_loss = compute_edge_loss(pred_verts, faces_idx[0], edge_target)
                data_loss = F.mse_loss(pred_verts, verts_tensor)
                loss = w_data * data_loss + w_lap * lap_loss + w_el * edge_loss
                loss.backward()
                return loss
            optimizer.step(closure)
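# The solver above calls getFaceNormals, which is not included in this collection.
# A minimal sketch consistent with how it is used there (input: (N, 3, 3) triangle
# vertices, output: (N, 3) unit face normals) — an assumption, not the original helper.
def getFaceNormals(tri_verts):
    e1 = tri_verts[:, 1] - tri_verts[:, 0]
    e2 = tri_verts[:, 2] - tri_verts[:, 0]
    normals = torch.cross(e1, e2, dim=-1)
    return normals / (normals.norm(dim=-1, keepdim=True) + 1e-12)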
def run(model, ds, niter, args, epoch_start=0, use_adam=True): use_collision = True if args.data[-1] == "A" or ds.nmaterials == 2: use_collision = False if use_collision: from mesh_intersection.bvh_search_tree import BVH search_tree = BVH(max_collisions=16) #pen_distance = DistanceFieldPenetrationLoss(sigma=0.5) print("@ model.mus", model.mus) params = get_params(model) if use_adam == True: opt = torch.optim.Adam(params, args.lr, betas=(0.9, 0.99)) else: opt = torch.optim.SGD(params, lr=args.lr, momentum=0.9, nesterov=True) # loop = tqdm.tqdm(list(range(0,args.niter))) if epoch_start == 0: f = open(args.dresult + 'log.txt', 'w') else: f = open(args.dresult + 'log.txt', 'a') log = f"@ statistics of mesh: {model.vertices.shape[0]}, {model.faces.shape[0]}\n" # full-batch case if args.b == 0: idx_angles_full = torch.LongTensor(np.arange(ds.nangles)) p_full = ds.p.cuda() ds_loader = [[idx_angles_full, p_full]] # mini-batch case else: ds_loader = torch.utils.data.DataLoader(ds, batch_size=args.b, shuffle=True, num_workers=0, drop_last=True, pin_memory=True) mask_bg = ds.p < 1e-5 mask_bg = mask_bg.cuda() use_silhouette = False # if use_silhouette: # mask_bg = ds.p < 1e-5 # mask_bg = mask_bg.cuda() # mask_bg = (p_batch > p_full.min()+0.05) #mask_bg = 1 ledge = 0 llap = 0. lflat = 0. for epoch in range(epoch_start, niter): if epoch + 100 == niter and niter > 400: args.lr *= 0.5 print("@ args.lr", args.lr) # if epoch % 20 == 0 or epoch == niter-1: start = time.time() for idx_angles, p_batch in ds_loader: displace_prev = model.displace.data.clone() if args.b > 0: p_batch = p_batch.cuda() opt.zero_grad() phat, mask_valid, edge_loss, lap_loss, flat_loss = model( idx_angles, args.wedge) # full angles # phat[~mask_valid] = 0.0 # mask_valid = mask_valid + mask_bg if 1: # l2 loss data_loss = (p_batch - phat)[mask_valid].pow(2).mean() if use_silhouette: idx_bg = (~mask_valid) * mask_bg[idx_angles] nbg = torch.sum(idx_bg) if nbg: print("sum(idx_bg), min, max", nbg, torch.min(phat[idx_bg]).item(), torch.max(phat[idx_bg]).item()) # print(phat[idx_bg]) data_loss += (phat[idx_bg]).pow(2).mean() # data_loss += 10 * torch.abs(phat[idx_bg]).mean() # add for the invalid pixels # data_loss += (p_batch)[~mask_valid].pow(2).mean() else: # student t misfit sigma = 1 data_loss = torch.log(1 + ( (p_batch - phat)[mask_valid] / sigma)**2).mean() loss = data_loss + args.wedge * edge_loss + args.wlap * lap_loss + args.wflat * flat_loss loss.backward() opt.step() loss_now = loss.item() model.mus.data.clamp_(min=0.0) if use_collision == False: continue # if epoch % 20 == 0 or epoch == niter-1: elpased_time = time.time() - start if epoch > epoch_start and args.b == 0: dloss = (loss_prev - loss_now) # should be positive if dloss < 1e-11 and dloss > 0: print('! 
converged') break loss_prev = loss_now if args.wedge > 0.: ledge = edge_loss.item() if args.wlap > 0.: llap = lap_loss.item() if args.wflat > 0.: lflat = flat_loss.item() log += f'~ {epoch:03d} l2_loss: {data_loss.item():.8f} edge: {ledge:.6f} lap: {llap:.6f} flat: {lflat:.6f} mus: {str(model.mus.cpu().detach().numpy())} time: {elpased_time:.4f}\n' #log += f'center: {model.center[0,0].item():.6f} {model.center[0,1].item():.6f} {model.center[0,2].item():.6f}' # f.write(log+"\n") if epoch % 60 == 0 or epoch == niter - 1: if torch.sum(~mask_valid) > 15000 and epoch > 100: assert 0, "consider increasing regularization" print(log) if args.b == 0: res_np = ds.p_original - phat.detach().cpu().numpy() res_scalar = np.mean(res_np**2) f.write(f"~ res_np: {res_scalar}\n") util_vis.save_sino_as_img(args.dresult + f'{epoch:04d}_sino_res.png', res_np, cmap='coolwarm') phat[~mask_valid] = 0. print(phat.min(), phat.max()) if args.verbose == 0: continue vv = model.vertices.cpu() + model.displace.detach().cpu() ff = model.faces.cpu() labels_v, labels_f = model.labels_v_np, model.labels.cpu().numpy() # util_vis.save_vf_as_img_labels(args.dresult+f'{epoch:04d}_render.png', vv, ff, labels_v, labels_f) util_vis.save_sino_as_img(args.dresult + f'{epoch:04d}_sino.png', phat.detach().cpu().numpy()) util_mesh.save_mesh(args.dresult + f'{epoch:04d}.obj', vv.numpy(), ff.numpy(), labels_v, labels_f) if args.data == "3nanoC": import subprocess subprocess.Popen([ "python", "../plot/compute_volume.py", args.dresult + f'{epoch:04d}.obj' ]) if epoch == niter - 1: util_mesh.save_mesh(args.dresult + 'mesh.obj', vv.numpy(), ff.numpy(), labels_v, labels_f) util_vis.save_sino_as_img( args.dresult + f'{epoch:04d}_data.png', ds.p.cuda()) f.write(log + "\n")
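# Note: run() above constructs the BVH when use_collision is True, but the excerpt
# never applies it inside the training loop (the pen_distance line is commented out).
# Below is a hedged sketch of how the collision term is used elsewhere in this
# collection; the attribute names (model.vertices, model.displace, model.faces)
# follow run() itself, while pen_distance and w_coll are assumed placeholders.
def collision_penalty(model, search_tree, pen_distance, w_coll=1e-4):
    verts = (model.vertices + model.displace).unsqueeze(0)  # (1, nv, 3)
    faces_idx = model.faces.unsqueeze(0).long()             # (1, nf, 3); batch of one, no offset needed
    triangles = verts.view([-1, 3])[faces_idx]              # (1, nf, 3, 3)
    with torch.no_grad():
        collision_idxs = search_tree(triangles)
    return w_coll * pen_distance(triangles, collision_idxs)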
# sub_p3d_mesh = subdivide(p3d_mesh)
# final_verts, final_faces = sub_p3d_mesh.get_mesh_verts_faces(0)
# save_obj('./subdivided_helge_scaled.obj', final_verts, final_faces)

print('Number of triangles = ', input_mesh.faces.shape[0])
vertices = torch.tensor(input_mesh.vertices, dtype=torch.float32, device=device)
faces = torch.tensor(input_mesh.faces.astype(np.int64), dtype=torch.long, device=device)
batch_size = 1
triangles = vertices[faces].unsqueeze(dim=0)

m = BVH(max_collisions=max_collisions)
torch.cuda.synchronize()
start = time.time()
outputs = m(triangles)
torch.cuda.synchronize()
print('Elapsed time', time.time() - start)

outputs = outputs.detach().cpu().numpy().squeeze()
collisions = outputs[outputs[:, 0] >= 0, :]
print(collisions.shape)
print('Number of collisions = ', collisions.shape[0])
print('Percentage of collisions (%)',
      collisions.shape[0] / float(triangles.shape[1]) * 100)
def main(): description = 'Example script for untangling SMPL self intersections' parser = argparse.ArgumentParser(description=description, prog='Batch SMPL-Untangle') parser.add_argument('--param_fn', type=str, nargs='*', required=True, help='The pickle file with the model parameters') parser.add_argument('--interactive', default=True, type=lambda arg: arg.lower() in ['true', '1'], help='Display the mesh during the optimization' + ' process') parser.add_argument('--delay', type=int, default=50, help='The delay for the animation callback in ms') parser.add_argument('--model_folder', type=str, default='models', help='The path to the LBS model') parser.add_argument('--model_type', type=str, default='smpl', choices=['smpl', 'smplx', 'smplh'], help='The type of model to create') parser.add_argument('--point2plane', default=False, type=lambda arg: arg.lower() in ['true', '1'], help='Use point to distance') parser.add_argument('--optimize_pose', default=True, type=lambda arg: arg.lower() in ['true', '1'], help='Enable optimization over the joint pose') parser.add_argument('--optimize_shape', default=False, type=lambda arg: arg.lower() in ['true', '1'], help='Enable optimization over the shape of the model') parser.add_argument('--sigma', default=0.5, type=float, help='The height of the cone used to calculate the' + ' distance field loss') parser.add_argument('--lr', default=1, type=float, help='The learning rate for SGD') parser.add_argument('--coll_loss_weight', default=1e-4, type=float, help='The weight for the collision loss') parser.add_argument('--pose_reg_weight', default=0, type=float, help='The weight for the pose regularizer') parser.add_argument('--shape_reg_weight', default=0, type=float, help='The weight for the shape regularizer') parser.add_argument('--max_collisions', default=8, type=int, help='The maximum number of bounding box collisions') parser.add_argument('--part_segm_fn', default='', type=str, help='The file with the part segmentation for the' + ' faces of the model') parser.add_argument('--print_timings', default=False, type=lambda arg: arg.lower() in ['true', '1'], help='Print timings for all the operations') args = parser.parse_args() model_folder = args.model_folder model_type = args.model_type param_fn = args.param_fn interactive = args.interactive delay = args.delay point2plane = args.point2plane # optimize_shape = args.optimize_shape # optimize_pose = args.optimize_pose lr = args.lr coll_loss_weight = args.coll_loss_weight pose_reg_weight = args.pose_reg_weight shape_reg_weight = args.shape_reg_weight max_collisions = args.max_collisions sigma = args.sigma part_segm_fn = args.part_segm_fn print_timings = args.print_timings if interactive: import trimesh import pyrender device = torch.device('cuda') batch_size = len(param_fn) params_dict = defaultdict(lambda: []) for idx, fn in enumerate(param_fn): with open(fn, 'rb') as param_file: data = pickle.load(param_file, encoding='latin1') assert 'betas' in data, \ 'No key for shape parameter in provided pickle file' assert 'global_pose' in data, \ 'No key for the global pose in the given pickle file' assert 'pose' in data, \ 'No key for the pose of the joints in the given pickle file' for key, val in data.items(): params_dict[key].append(val) params = {} for key in params_dict: params[key] = np.stack(params_dict[key], axis=0).astype(np.float32) if len(params[key].shape) < 2: params[key] = params[key][np.newaxis] if 'global_pose' in params: params['global_orient'] = params['global_pose'] if 'pose' in params: 
params['body_pose'] = params['pose'] if part_segm_fn: # Read the part segmentation with open(part_segm_fn, 'rb') as faces_parents_file: data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = data['segm'] faces_parents = data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces(faces_segm, faces_parents).to(device=device) # Create the body model body = create(model_folder, batch_size=batch_size, model_type=model_type).to(device=device) body.reset_params(**params) # Clone the given pose to use it as a target for regularization init_pose = body.body_pose.clone().detach() # Create the search tree search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma, point2plane=point2plane, vectorized=True) mse_loss = nn.MSELoss(reduction='sum').to(device=device) face_tensor = torch.tensor(body.faces.astype(np.int64), dtype=torch.long, device=device).unsqueeze_(0).repeat([batch_size, 1, 1]) with torch.no_grad(): output = body(get_skin=True) verts = output.vertices bs, nv = verts.shape[:2] bs, nf = face_tensor.shape[:2] faces_idx = face_tensor + \ (torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None] optimizer = torch.optim.SGD([body.body_pose], lr=lr) if interactive: # Plot the initial mesh with torch.no_grad(): output = body(get_skin=True) verts = output.vertices np_verts = verts.detach().cpu().numpy() def create_mesh(vertices, faces, color=(0.3, 0.3, 0.3, 1.0), wireframe=False): tri_mesh = trimesh.Trimesh(vertices=vertices, faces=faces) rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]) tri_mesh.apply_transform(rot) material = pyrender.MetallicRoughnessMaterial( metallicFactor=0.0, alphaMode='BLEND', baseColorFactor=color) return pyrender.Mesh.from_trimesh( tri_mesh, material=material) scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0], ambient_light=(1.0, 1.0, 1.0)) for bidx in range(np_verts.shape[0]): curr_verts = np_verts[bidx].copy() body_mesh = create_mesh(curr_verts, body.faces, color=(0.3, 0.3, 0.3, 0.99), wireframe=True) pose = np.eye(4) pose[0, 3] = bidx * 2 scene.add(body_mesh, name='body_mesh_{:03d}'.format(bidx), pose=pose) viewer = pyrender.Viewer(scene, use_raymond_lighting=True, viewport_size=(1200, 800), cull_faces=False, run_in_thread=True) query_names = ['recv_mesh', 'intr_mesh', 'body_mesh'] step = 0 while True: optimizer.zero_grad() if print_timings: start = time.time() if print_timings: torch.cuda.synchronize() output = body(get_skin=True) verts = output.vertices if print_timings: torch.cuda.synchronize() print('Body model forward: {:5f}'.format(time.time() - start)) if print_timings: torch.cuda.synchronize() start = time.time() triangles = verts.view([-1, 3])[faces_idx] if print_timings: torch.cuda.synchronize() print('Triangle indexing: {:5f}'.format(time.time() - start)) with torch.no_grad(): if print_timings: start = time.time() collision_idxs = search_tree(triangles) if print_timings: torch.cuda.synchronize() print('Collision Detection: {:5f}'.format(time.time() - start)) if part_segm_fn: if print_timings: start = time.time() collision_idxs = filter_faces(collision_idxs) if print_timings: torch.cuda.synchronize() print('Collision filtering: {:5f}'.format(time.time() - start)) if print_timings: start = time.time() pen_loss = coll_loss_weight * \ pen_distance(triangles, collision_idxs) if print_timings: torch.cuda.synchronize() print('Penetration loss: {:5f}'.format(time.time() - start)) shape_reg_loss = 
torch.tensor(0, device=device, dtype=torch.float32) if shape_reg_weight > 0: shape_reg_loss = shape_reg_weight * torch.sum(body.betas ** 2) pose_reg_loss = torch.tensor(0, device=device, dtype=torch.float32) if pose_reg_weight > 0: pose_reg_loss = pose_reg_weight * \ mse_loss(body.pose, init_pose) loss = pen_loss + pose_reg_loss + shape_reg_loss np_loss = loss.detach().cpu().squeeze().tolist() if type(np_loss) != list: np_loss = [np_loss] msg = '{:.5f} ' * len(np_loss) print('Loss per model:', msg.format(*np_loss)) if print_timings: start = time.time() loss.backward(torch.ones_like(loss)) if print_timings: torch.cuda.synchronize() print('Backward pass: {:5f}'.format(time.time() - start)) if interactive: with torch.no_grad(): output = body(get_skin=True) verts = output.vertices np_verts = verts.detach().cpu().numpy() np_collision_idxs = collision_idxs.detach().cpu().numpy() np_receivers = np_collision_idxs[:, :, 0] np_intruders = np_collision_idxs[:, :, 1] viewer.render_lock.acquire() for node in scene.get_nodes(): if node.name is None: continue if any([query in node.name for query in query_names]): scene.remove_node(node) for bidx in range(batch_size): recv_faces_idxs = np_receivers[bidx][np_receivers[bidx] >= 0] intr_faces_idxs = np_intruders[bidx][np_intruders[bidx] >= 0] recv_faces = body.faces[recv_faces_idxs] intr_faces = body.faces[intr_faces_idxs] curr_verts = np_verts[bidx].copy() body_mesh = create_mesh(curr_verts, body.faces, color=(0.3, 0.3, 0.3, 0.99), wireframe=True) pose = np.eye(4) pose[0, 3] = bidx * 2 scene.add(body_mesh, name='body_mesh_{:03d}'.format(bidx), pose=pose) if len(intr_faces) > 0: intr_mesh = create_mesh(curr_verts, intr_faces, color=(0.9, 0.0, 0.0, 1.0)) scene.add(intr_mesh, name='intr_mesh_{:03d}'.format(bidx), pose=pose) if len(recv_faces) > 0: recv_mesh = create_mesh(curr_verts, recv_faces, color=(0.0, 0.9, 0.0, 1.0)) scene.add(recv_mesh, name='recv_mesh_{:03d}'.format(bidx), pose=pose) viewer.render_lock.release() if not viewer.is_active: break time.sleep(delay / 1000) optimizer.step() step += 1
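# Example invocation of the untangling script above, assuming it is saved as
# batch_smpl_untangle.py (the script name and the .pkl file names are placeholders;
# the flags come from the argparse definitions in main()):
#
#   python batch_smpl_untangle.py \
#       --param_fn pose1.pkl pose2.pkl \
#       --model_folder models --model_type smpl \
#       --part_segm_fn smpl_parts_segm.pkl \
#       --max_collisions 8 --sigma 0.5 --lr 1 --coll_loss_weight 1e-4 \
#       --interactive true --print_timings false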
def non_linear_solver(setting, data, batch_size=1, data_weights=None, body_pose_prior_weights=None, shape_weights=None, coll_loss_weights=None, use_joints_conf=False, use_3d=False, rho=100, interpenetration=False, loss_type='smplify', visualize=False, use_vposer=True, interactive=True, use_cuda=True, is_seq=False, **kwargs): assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1' views = setting['views'] device = setting['device'] dtype = setting['dtype'] vposer = setting['vposer'] keypoints = data['keypoints'] joint_weights = setting['joints_weight'] model = setting['model'] camera = setting['camera'] pose_embedding = setting['pose_embedding'] seq_start = setting['seq_start'] assert (len(data_weights) == len(body_pose_prior_weights) and len(shape_weights) == len(body_pose_prior_weights) and len(coll_loss_weights) == len(body_pose_prior_weights)), "Number of weight must match" # process keypoints keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :, :2] if use_joints_conf: joints_conf = [] for v in keypoint_data: conf = v[:, :, 2].reshape(1, -1) conf = conf.to(device=device, dtype=dtype) joints_conf.append(conf) if use_3d: joints_data = torch.tensor(joints3d, dtype=dtype) gt_joints3d = joints_data[:, :3] if use_joints_conf: joints3d_conf = joints_data[:, 3].reshape(1, -1).to(device=device, dtype=dtype) if not use_hip: joints3d_conf[0][11] = 0 joints3d_conf[0][12] = 0 gt_joints3d = gt_joints3d.to(device=device, dtype=dtype) else: gt_joints3d = None joints3d_conf = None # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) # Create the search tree search_tree = None pen_distance = None filter_faces = None # we do not use this term at this time if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, 'Interpenetration term can only be used with CUDA' assert torch.cuda.is_available(), \ 'No CUDA Device! 
Interpenetration term can only be used' + \ ' with CUDA' search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).to(device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = { 'data_weight': data_weights, 'body_pose_weight': body_pose_prior_weights, 'shape_weight': shape_weights } if interpenetration: opt_weights_dict['coll_loss_weight'] = coll_loss_weights # get weights for each stage keys = opt_weights_dict.keys() opt_weights = [ dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None)) ] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) # create fitting loss loss = fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=setting['body_pose_prior'], shape_prior=setting['shape_prior'], angle_prior=setting['angle_prior'], interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, dtype=dtype, use_3d=use_3d, **kwargs) loss = loss.to(device=device) monitor = fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, **kwargs) # with fitting.FittingMonitor( # batch_size=batch_size, visualize=visualize, **kwargs) as monitor: H, W, _ = data['img'][0].shape data_weight = 500 / H # Reset the parameters to estimate the initial translation of the # body model # body_model.reset_params(body_pose=body_mean_pose, transl=init['init_t'], global_orient=init['init_r'], scale=init['init_s'], betas=init['init_betas']) # we do not change rotation in multi-view task orientations = [model.global_orient] # store here the final error for both orientations, # and pick the orientation resulting in the lowest error results = [] # Step 1: Optimize the full model final_loss_val = 0 opt_start = time.time() # # initial value for non-linear solve # new_params = defaultdict(global_orient=model.global_orient, # # body_pose=body_mean_pose, # transl=model.transl, # scale=model.scale, # betas=model.betas, # ) # if vposer is not None: # with torch.no_grad(): # pose_embedding.fill_(0) # model.reset_params(**new_params) for opt_idx, curr_weights in enumerate(tqdm(opt_weights, desc='Stage')): # pass stage1 and stage2 if it is a sequence if not seq_start and is_seq: if opt_idx < 2: continue elif opt_idx == 2: curr_weights['body_pose_weight'] *= 0.15 body_params = list(model.parameters()) final_params = list(filter(lambda x: x.requires_grad, body_params)) if vposer is not None: final_params.append(pose_embedding) body_optimizer, body_create_graph = optim_factory.create_optimizer( final_params, **kwargs) body_optimizer.zero_grad() curr_weights['data_weight'] = data_weight curr_weights['bending_prior_weight'] = ( 3.17 * curr_weights['body_pose_weight']) 
loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, model, camera=camera, gt_joints=gt_joints, joints_conf=joints_conf, gt_joints3d=gt_joints3d, joints3d_conf=joints3d_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, return_verts=True, return_full_pose=True, use_3d=use_3d) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting(body_optimizer, closure, final_params, model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write('Stage {:03d} done after {:.4f} seconds'.format( opt_idx, elapsed)) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write('Body fitting done after {:.4f} seconds'.format(elapsed)) tqdm.write('Body final loss val = {:.5f}'.format(final_loss_val)) # Get the result of the fitting process result = { key: val.detach().cpu().numpy() for key, val in model.named_parameters() } result['loss'] = final_loss_val result['pose_embedding'] = pose_embedding return result
def fit_single_frame( img, keypoints, init_trans, scan, scene_name, body_model, camera, joint_weights, body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior, shape_prior, expr_prior, angle_prior, result_fn='out.pkl', mesh_fn='out.obj', body_scene_rendering_fn='body_scene.png', out_img_fn='overlay.png', loss_type='smplify', use_cuda=True, init_joints_idxs=(9, 12, 2, 5), use_face=True, use_hands=True, data_weights=None, body_pose_prior_weights=None, hand_pose_prior_weights=None, jaw_pose_prior_weights=None, shape_weights=None, expr_weights=None, hand_joints_weights=None, face_joints_weights=None, depth_loss_weight=1e2, interpenetration=True, coll_loss_weights=None, df_cone_height=0.5, penalize_outside=True, max_collisions=8, point2plane=False, part_segm_fn='', focal_length_x=5000., focal_length_y=5000., side_view_thsh=25., rho=100, vposer_latent_dim=32, vposer_ckpt='', use_joints_conf=False, interactive=True, visualize=False, save_meshes=True, degrees=None, batch_size=1, dtype=torch.float32, ign_part_pairs=None, left_shoulder_idx=2, right_shoulder_idx=5, #################### ### PROX render_results=True, camera_mode='moving', ## Depth s2m=False, s2m_weights=None, m2s=False, m2s_weights=None, rho_s2m=1, rho_m2s=1, init_mode=None, trans_opt_stages=None, viz_mode='mv', #penetration sdf_penetration=False, sdf_penetration_weights=0.0, sdf_dir=None, cam2world_dir=None, #contact contact=False, rho_contact=1.0, contact_loss_weights=None, contact_angle=15, contact_body_parts=None, body_segments_dir=None, load_scene=False, scene_dir=None, **kwargs): assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1' body_model.reset_params() body_model.transl.requires_grad = True device = torch.device('cuda') if use_cuda else torch.device('cpu') if visualize: pil_img.fromarray((img * 255).astype(np.uint8)).show() if degrees is None: degrees = [0, 90, 180, 270] if data_weights is None: data_weights = [ 1, ] * 5 if body_pose_prior_weights is None: body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78] msg = ('Number of Body pose prior weights {}'.format( len(body_pose_prior_weights)) + ' does not match the number of data term weights {}'.format( len(data_weights))) assert (len(data_weights) == len(body_pose_prior_weights)), msg if use_hands: if hand_pose_prior_weights is None: hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights does not match the' + ' number of hand pose prior weights') assert ( len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg if hand_joints_weights is None: hand_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of hand joint distance weights') assert ( len(hand_joints_weights) == len(body_pose_prior_weights)), msg if shape_weights is None: shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Shape prior weights = {}') assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format( len(shape_weights), len(body_pose_prior_weights)) if use_face: if jaw_pose_prior_weights is None: jaw_pose_prior_weights = [[x] * 3 for x in shape_weights] else: jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')), jaw_pose_prior_weights) jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights] msg = ('Number of Body pose prior weights does not match the' + ' number of jaw pose prior weights') assert ( len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), 
msg if expr_weights is None: expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Expression prior weights = {}') assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format( len(body_pose_prior_weights), len(expr_weights)) if face_joints_weights is None: face_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of face joint distance weights') assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg if coll_loss_weights is None: coll_loss_weights = [0.0] * len(body_pose_prior_weights) msg = ('Number of Body pose prior weights does not match the' + ' number of collision loss weights') assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg use_vposer = kwargs.get('use_vposer', True) vposer, pose_embedding = [ None, ] * 2 if use_vposer: pose_embedding = torch.zeros([batch_size, 32], dtype=dtype, device=device, requires_grad=True) vposer_ckpt = osp.expandvars(vposer_ckpt) vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot') vposer = vposer.to(device=device) vposer.eval() if use_vposer: body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype) else: body_mean_pose = body_pose_prior.get_mean().detach().cpu() keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :2] if use_joints_conf: joints_conf = keypoint_data[:, :, 2].reshape(1, -1) # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) if use_joints_conf: joints_conf = joints_conf.to(device=device, dtype=dtype) scan_tensor = None if scan is not None: scan_tensor = torch.tensor(scan.get('points'), device=device, dtype=dtype).unsqueeze(0) # load pre-computed signed distance field sdf = None sdf_normals = None grid_min = None grid_max = None voxel_size = None if sdf_penetration: with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f: sdf_data = json.load(f) grid_min = torch.tensor(np.array(sdf_data['min']), dtype=dtype, device=device) grid_max = torch.tensor(np.array(sdf_data['max']), dtype=dtype, device=device) grid_dim = sdf_data['dim'] voxel_size = (grid_max - grid_min) / grid_dim sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape( grid_dim, grid_dim, grid_dim) sdf = torch.tensor(sdf, dtype=dtype, device=device) if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')): sdf_normals = np.load( osp.join(sdf_dir, scene_name + '_normals.npy')).reshape( grid_dim, grid_dim, grid_dim, 3) sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device) else: print("Normals not found...") with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f: cam2world = np.array(json.load(f)) R = torch.tensor(cam2world[:3, :3].reshape(3, 3), dtype=dtype, device=device) t = torch.tensor(cam2world[:3, 3].reshape(1, 3), dtype=dtype, device=device) # Create the search tree search_tree = None pen_distance = None filter_faces = None if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, 'Interpenetration term can only be used with CUDA' assert torch.cuda.is_available(), \ 'No CUDA Device! 
Interpenetration term can only be used' + \ ' with CUDA' search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).to(device=device) # load vertix ids of contact parts contact_verts_ids = ftov = None if contact: contact_verts_ids = [] for part in contact_body_parts: with open(os.path.join(body_segments_dir, part + '.json'), 'r') as f: data = json.load(f) contact_verts_ids.append(list(set(data["verts_ind"]))) contact_verts_ids = np.concatenate(contact_verts_ids) vertices = body_model(return_verts=True, body_pose=torch.zeros((batch_size, 63), dtype=dtype, device=device)).vertices vertices_np = vertices.detach().cpu().numpy().squeeze() body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape( -1, 3) m = Mesh(v=vertices_np, f=body_faces_np) ftov = m.faces_by_vertex(as_sparse_matrix=True) ftov = sparse.coo_matrix(ftov) indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device) values = torch.FloatTensor(ftov.data).to(device) shape = ftov.shape ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape)) # Read the scene scan if any scene_v = scene_vn = scene_f = None if scene_name is not None: if load_scene: scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply')) scene.vn = scene.estimate_vertex_normals() scene_v = torch.tensor(scene.v[np.newaxis, :], dtype=dtype, device=device).contiguous() scene_vn = torch.tensor(scene.vn[np.newaxis, :], dtype=dtype, device=device) scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :], dtype=torch.long, device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = { 'data_weight': data_weights, 'body_pose_weight': body_pose_prior_weights, 'shape_weight': shape_weights } if use_face: opt_weights_dict['face_weight'] = face_joints_weights opt_weights_dict['expr_prior_weight'] = expr_weights opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights if use_hands: opt_weights_dict['hand_weight'] = hand_joints_weights opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights if interpenetration: opt_weights_dict['coll_loss_weight'] = coll_loss_weights if s2m: opt_weights_dict['s2m_weight'] = s2m_weights if m2s: opt_weights_dict['m2s_weight'] = m2s_weights if sdf_penetration: opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights if contact: opt_weights_dict['contact_loss_weight'] = contact_loss_weights keys = opt_weights_dict.keys() opt_weights = [ dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None)) ] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) # load indices of the head of smpl-x model with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp: head_indx = np.array(json.load(fp)) N = body_model.get_num_verts() body_indx = np.setdiff1d(np.arange(N), head_indx) head_mask = np.in1d(np.arange(N), head_indx) 
body_mask = np.in1d(np.arange(N), body_indx) # The indices of the joints used for the initialization of the camera init_joints_idxs = torch.tensor(init_joints_idxs, device=device) edge_indices = kwargs.get('body_tri_idxs') # which initialization mode to choose: similar traingles, mean of the scan or the average of both if init_mode == 'scan': init_t = init_trans elif init_mode == 'both': init_t = (init_trans.to(device) + fitting.guess_init( body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get('model_type', 'smpl'), focal_length=focal_length_x, dtype=dtype)) / 2.0 else: init_t = fitting.guess_init(body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get( 'model_type', 'smpl'), focal_length=focal_length_x, dtype=dtype) camera_loss = fitting.create_loss('camera_init', trans_estimation=init_t, init_joints_idxs=init_joints_idxs, depth_loss_weight=depth_loss_weight, camera_mode=camera_mode, dtype=dtype).to(device=device) camera_loss.trans_estimation[:] = init_t loss = fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, use_face=use_face, use_hands=use_hands, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=body_pose_prior, shape_prior=shape_prior, angle_prior=angle_prior, expr_prior=expr_prior, left_hand_prior=left_hand_prior, right_hand_prior=right_hand_prior, jaw_prior=jaw_prior, interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, s2m=s2m, m2s=m2s, rho_s2m=rho_s2m, rho_m2s=rho_m2s, head_mask=head_mask, body_mask=body_mask, sdf_penetration=sdf_penetration, voxel_size=voxel_size, grid_min=grid_min, grid_max=grid_max, sdf=sdf, sdf_normals=sdf_normals, R=R, t=t, contact=contact, contact_verts_ids=contact_verts_ids, rho_contact=rho_contact, contact_angle=contact_angle, dtype=dtype, **kwargs) loss = loss.to(device=device) with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, viz_mode=viz_mode, **kwargs) as monitor: img = torch.tensor(img, dtype=dtype) H, W, _ = img.shape # Reset the parameters to estimate the initial translation of the # body model if camera_mode == 'moving': body_model.reset_params(body_pose=body_mean_pose) # Update the value of the translation of the camera as well as # the image center. 
with torch.no_grad(): camera.translation[:] = init_t.view_as(camera.translation) camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5 # Re-enable gradient calculation for the camera translation camera.translation.requires_grad = True camera_opt_params = [camera.translation, body_model.global_orient] elif camera_mode == 'fixed': body_model.reset_params(body_pose=body_mean_pose, transl=init_t) camera_opt_params = [body_model.transl, body_model.global_orient] # If the distance between the 2D shoulders is smaller than a # predefined threshold then try 2 fits, the initial one and a 180 # degree rotation shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx], gt_joints[:, right_shoulder_idx]) try_both_orient = shoulder_dist.item() < side_view_thsh camera_optimizer, camera_create_graph = optim_factory.create_optimizer( camera_opt_params, **kwargs) # The closure passed to the optimizer fit_camera = monitor.create_fitting_closure( camera_optimizer, body_model, camera, gt_joints, camera_loss, create_graph=camera_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, scan_tensor=scan_tensor, return_full_pose=False, return_verts=False) # Step 1: Optimize over the torso joints the camera translation # Initialize the computational graph by feeding the initial translation # of the camera and the initial pose of the body model. camera_init_start = time.time() cam_init_loss_val = monitor.run_fitting(camera_optimizer, fit_camera, camera_opt_params, body_model, use_vposer=use_vposer, pose_embedding=pose_embedding, vposer=vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() tqdm.write('Camera initialization done after {:.4f}'.format( time.time() - camera_init_start)) tqdm.write('Camera initialization final loss {:.4f}'.format( cam_init_loss_val)) # If the 2D detections/positions of the shoulder joints are too # close the rotate the body by 180 degrees and also fit to that # orientation if try_both_orient: body_orient = body_model.global_orient.detach().cpu().numpy() flipped_orient = cv2.Rodrigues(body_orient)[0].dot( cv2.Rodrigues(np.array([0., np.pi, 0]))[0]) flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel() flipped_orient = torch.tensor(flipped_orient, dtype=dtype, device=device).unsqueeze(dim=0) orientations = [body_orient, flipped_orient] else: orientations = [body_model.global_orient.detach().cpu().numpy()] # store here the final error for both orientations, # and pick the orientation resulting in the lowest error results = [] body_transl = body_model.transl.clone().detach() # Step 2: Optimize the full model final_loss_val = 0 for or_idx, orient in enumerate(tqdm(orientations, desc='Orientation')): opt_start = time.time() new_params = defaultdict(transl=body_transl, global_orient=orient, body_pose=body_mean_pose) body_model.reset_params(**new_params) if use_vposer: with torch.no_grad(): pose_embedding.fill_(0) for opt_idx, curr_weights in enumerate( tqdm(opt_weights, desc='Stage')): if opt_idx not in trans_opt_stages: body_model.transl.requires_grad = False else: body_model.transl.requires_grad = True body_params = list(body_model.parameters()) final_params = list( filter(lambda x: x.requires_grad, body_params)) if use_vposer: final_params.append(pose_embedding) body_optimizer, body_create_graph = optim_factory.create_optimizer( final_params, **kwargs) body_optimizer.zero_grad() curr_weights['bending_prior_weight'] = ( 3.17 * curr_weights['body_pose_weight']) if use_hands: joint_weights[:, 25:76] = 
curr_weights['hand_weight'] if use_face: joint_weights[:, 76:] = curr_weights['face_weight'] loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, body_model, camera=camera, gt_joints=gt_joints, joints_conf=joints_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, scan_tensor=scan_tensor, scene_v=scene_v, scene_vn=scene_vn, scene_f=scene_f, ftov=ftov, return_verts=True, return_full_pose=True) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting( body_optimizer, closure, final_params, body_model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write( 'Stage {:03d} done after {:.4f} seconds'.format( opt_idx, elapsed)) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write( 'Body fitting Orientation {} done after {:.4f} seconds'. format(or_idx, elapsed)) tqdm.write( 'Body final loss val = {:.5f}'.format(final_loss_val)) # Get the result of the fitting process # Store in it the errors list in order to compare multiple # orientations, if they exist result = { 'camera_' + str(key): val.detach().cpu().numpy() for key, val in camera.named_parameters() } result.update({ key: val.detach().cpu().numpy() for key, val in body_model.named_parameters() }) if use_vposer: result['pose_embedding'] = pose_embedding.detach().cpu().numpy( ) body_pose = vposer.decode(pose_embedding, output_type='aa').view( 1, -1) if use_vposer else None result['body_pose'] = body_pose.detach().cpu().numpy() results.append({'loss': final_loss_val, 'result': result}) with open(result_fn, 'wb') as result_file: if len(results) > 1: min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1) else: min_idx = 0 pickle.dump(results[min_idx]['result'], result_file, protocol=2) if save_meshes or visualize: body_pose = vposer.decode(pose_embedding, output_type='aa').view( 1, -1) if use_vposer else None model_type = kwargs.get('model_type', 'smpl') append_wrists = model_type == 'smpl' and use_vposer if append_wrists: wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device) body_pose = torch.cat([body_pose, wrist_pose], dim=1) model_output = body_model(return_verts=True, body_pose=body_pose) vertices = model_output.vertices.detach().cpu().numpy().squeeze() import trimesh out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False) out_mesh.export(mesh_fn) if render_results: import pyrender # common H, W = 1080, 1920 camera_center = np.array([951.30, 536.77]) camera_pose = np.eye(4) camera_pose = np.array([1.0, -1.0, -1.0, 1.0]).reshape(-1, 1) * camera_pose camera = pyrender.camera.IntrinsicsCamera(fx=1060.53, fy=1060.38, cx=camera_center[0], cy=camera_center[1]) light = pyrender.DirectionalLight(color=np.ones(3), intensity=2.0) material = pyrender.MetallicRoughnessMaterial( metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0)) body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material) ## rendering body img = img.detach().cpu().numpy() H, W, _ = img.shape scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.3, 0.3, 0.3)) scene.add(camera, pose=camera_pose) scene.add(light, 
pose=camera_pose) # for node in light_nodes: # scene.add_node(node) scene.add(body_mesh, 'mesh') r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H, point_size=1.0) color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA) color = color.astype(np.float32) / 255.0 valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis] input_img = img output_img = (color[:, :, :-1] * valid_mask + (1 - valid_mask) * input_img) img = pil_img.fromarray((output_img * 255).astype(np.uint8)) img.save(out_img_fn) ##redering body+scene body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material) static_scene = trimesh.load(osp.join(scene_dir, scene_name + '.ply')) trans = np.linalg.inv(cam2world) static_scene.apply_transform(trans) static_scene_mesh = pyrender.Mesh.from_trimesh(static_scene) scene = pyrender.Scene() scene.add(camera, pose=camera_pose) scene.add(light, pose=camera_pose) scene.add(static_scene_mesh, 'mesh') scene.add(body_mesh, 'mesh') r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H) color, _ = r.render(scene) color = color.astype(np.float32) / 255.0 img = pil_img.fromarray((color * 255).astype(np.uint8)) img.save(body_scene_rendering_fn)
def fit_SMPLXD(scans, smplx_pkl, gender='male', save_path=None, scale_file=None, interpenetration=True): search_tree = None pen_distance = None tri_filtering_module = None max_collisions = 128 df_cone_height = 0.0001 point2plane = False penalize_outside = True part_segm_fn = '/home/chen/IPNet_SMPLX/assets/smplx_parts_segm.pkl' ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2", "12,22"] if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces search_tree = BVH(max_collisions=max_collisions) pen_distance = collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pkl.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] tri_filtering_module = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).cuda() # Get SMPLX faces # spx = SmplPaths(gender=gender) spx = SMPLX(model_path="/home/chen/SMPLX/models/smplx", batch_size=1, gender=gender) smplx_faces = spx.faces th_faces = torch.tensor(smplx_faces.astype('float32'), dtype=torch.long).cuda() # Batch size batch_sz = len(scans) # Init SMPLX global_pose, body_pose, left_hand_pose, right_hand_pose = [], [], [], [] expression, jaw_pose, leye_pose, reye_pose = [], [], [], [] betas, trans = [], [] for spkl in smplx_pkl: smplx_dict = pkl.load(open(spkl, 'rb')) g, bp, lh, rh, e, j, le, re, b, t = ( smplx_dict['global_pose'], smplx_dict['body_pose'], smplx_dict['left_hand_pose'], smplx_dict['right_hand_pose'], smplx_dict['expression'], smplx_dict['jaw_pose'], smplx_dict['leye_pose'], smplx_dict['reye_pose'], smplx_dict['betas'], smplx_dict['trans']) global_pose.append(g) body_pose.append(bp) left_hand_pose.append(lh) right_hand_pose.append(rh) expression.append(e) jaw_pose.append(j) leye_pose.append(le) reye_pose.append(re) if len(b) == 10: # temp = np.zeros((300,)) temp = np.zeros((10, )) temp[:10] = b b = temp.astype('float32') betas.append(b) trans.append(t) global_pose, body_pose, left_hand_pose, right_hand_pose = np.array(global_pose), np.array(body_pose), \ np.array(left_hand_pose), np.array(right_hand_pose) expression, jaw_pose, leye_pose, reye_pose = np.array(expression), np.array(jaw_pose), \ np.array(leye_pose), np.array(reye_pose) betas, trans = np.array(betas), np.array(trans) global_pose, body_pose, left_hand_pose, right_hand_pose = torch.tensor(global_pose), torch.tensor(body_pose), \ torch.tensor(left_hand_pose), torch.tensor(right_hand_pose) expression, jaw_pose, leye_pose, reye_pose = torch.tensor(expression), torch.tensor(jaw_pose), \ torch.tensor(leye_pose), torch.tensor(reye_pose) betas, trans = torch.tensor(betas), torch.tensor(trans) # smplx = th_batch_SMPLX(batch_sz, betas, pose, trans, faces=th_faces, gender=gender).cuda() smplx = th_batch_SMPLX(batch_sz, betas, global_pose, body_pose, left_hand_pose, right_hand_pose, trans, expression, jaw_pose, leye_pose, reye_pose, faces=th_faces, gender=gender).to(DEVICE) # verts, _, _, _ = smplx() verts = smplx() init_smplx_meshes = [ tm.from_tensors(vertices=v.clone().detach(), faces=smplx.faces) for v in verts ] # Load scans th_scan_meshes = [] for scan in scans: print('scan path ...', scan) temp = Mesh(filename=scan) th_scan = tm.from_tensors( 
torch.tensor(temp.v.astype('float32'), requires_grad=False, device=DEVICE), torch.tensor(temp.f.astype('int32'), requires_grad=False, device=DEVICE).long()) th_scan_meshes.append(th_scan) if scale_file is not None: for n, sc in enumerate(scale_file): dat = np.load(sc, allow_pickle=True) th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE) th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE) # Optimize optimize_offsets(th_scan_meshes, smplx, init_smplx_meshes, 5, 10, search_tree, pen_distance, tri_filtering_module) # optimize_offsets_only(th_scan_meshes, smplx, init_smplx_meshes, 5, 8, search_tree, pen_distance, tri_filtering_module) print('Done') # verts, _, _, _ = smplx() verts = smplx.get_vertices_clean_hand() th_smplx_meshes = [ tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts ] if save_path is not None: if not exists(save_path): os.makedirs(save_path) names = ['full.ply'] # [split(s)[1] for s in scans] # Save meshes save_meshes( th_smplx_meshes, [join(save_path, n.replace('.ply', '_smplxd.obj')) for n in names]) save_meshes(th_scan_meshes, [join(save_path, n) for n in names]) # Save params for g, bp, lh, rh, e, j, le, re, b, t, d, n in zip( smplx.global_pose.cpu().detach().numpy(), smplx.body_pose.cpu().detach().numpy(), smplx.left_hand_pose.cpu().detach().numpy(), smplx.right_hand_pose.cpu().detach().numpy(), smplx.expression.cpu().detach().numpy(), smplx.jaw_pose.cpu().detach().numpy(), smplx.leye_pose.cpu().detach().numpy(), smplx.reye_pose.cpu().detach().numpy(), smplx.betas.cpu().detach().numpy(), smplx.trans.cpu().detach().numpy(), smplx.offsets_clean_hand.cpu().detach().numpy(), names): smplx_dict = { 'global_pose': g, 'body_pose': bp, 'left_hand_pose': lh, 'right_hand_pose': rh, 'expression': e, 'jaw_pose': j, 'leye_pose': le, 'reye_pose': re, 'betas': b, 'trans': t, 'offsets': d } pkl.dump( smplx_dict, open(join(save_path, n.replace('.ply', '_smplxd.pkl')), 'wb')) return (smplx.global_pose.cpu().detach().numpy(), smplx.body_pose.cpu().detach().numpy(), smplx.left_hand_pose.cpu().detach().numpy(), smplx.right_hand_pose.cpu().detach().numpy(), smplx.expression.cpu().detach().numpy(), smplx.jaw_pose.cpu().detach().numpy(), smplx.leye_pose.cpu().detach().numpy(), smplx.reye_pose.cpu().detach().numpy(), smplx.betas.cpu().detach().numpy(), smplx.trans.cpu().detach().numpy(), smplx.offsets_clean_hand.cpu().detach().numpy())
def fit_SMPLX(scans, scan_labels, gender='male', save_path=None, scale_file=None, display=None, interpenetration=True): """ :param save_path: :param scans: list of scan paths :param pose_files: :return: """ search_tree = None pen_distance = None tri_filtering_module = None max_collisions = 128 df_cone_height = 0.0001 point2plane = False penalize_outside = True part_segm_fn = '/home/chen/IPNet_SMPLX/assets/smplx_parts_segm.pkl' ign_part_pairs = ["9,16", "9,17", "6,16", "6,17", "1,2", "12,22"] if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces search_tree = BVH(max_collisions=max_collisions) pen_distance = collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pkl.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] tri_filtering_module = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).cuda() # Get SMPLX faces # spx = SmplPaths(gender=gender) spx = SMPLX(model_path="/home/chen/SMPLX/models/smplx", batch_size=1, gender=gender) smplx_faces = spx.faces th_faces = torch.tensor(smplx_faces.astype('float32'), dtype=torch.long).to(DEVICE) # Load SMPLX parts part_labels = pkl.load( open('/home/chen/IPNet_SMPLX/assets/smplx_parts_dense.pkl', 'rb')) labels = np.zeros((10475, ), dtype='int32') for n, k in enumerate(part_labels): labels[part_labels[k]] = n labels = torch.tensor(labels).unsqueeze(0).to(DEVICE) # Load scan parts scan_part_labels = [] for sc_l in scan_labels: temp = torch.tensor(np.load(sc_l).astype('int32')).to(DEVICE) scan_part_labels.append(temp) # Batch size batch_sz = len(scans) # Set optimization hyper parameters iterations, pose_iterations, steps_per_iter, pose_steps_per_iter = 3, 2, 30, 30 # prior = get_prior(gender=gender, precomputed=True) if gender == 'male': temp_model = pkl.load(open( '/home/chen/SMPLX/models/smplx/SMPLX_MALE.pkl', 'rb'), encoding='latin1') elif gender == 'female': temp_model = pkl.load(open( '/home/chen/SMPLX/models/smplx/SMPLX_FEMALE.pkl', 'rb'), encoding='latin1') else: print('Wrong gender input!') exit() left_hand_mean = torch.tensor(temp_model['hands_meanl']).unsqueeze(0) right_hand_mean = torch.tensor(temp_model['hands_meanr']).unsqueeze(0) # pose_init = torch.zeros((batch_sz, 69)) # TODO consider to add the prior for smplx # pose_init[:, 3:] = prior.mean # betas, pose, trans = torch.zeros((batch_sz, 300)), pose_init, torch.zeros((batch_sz, 3)) betas, global_pose, body_pose, trans = torch.zeros( (batch_sz, 10)), torch.zeros((batch_sz, 3)), torch.zeros( (batch_sz, 63)), torch.zeros((batch_sz, 3)) left_hand_pose, right_hand_pose, expression, jaw_pose = left_hand_mean, right_hand_mean, torch.zeros( (batch_sz, 10)), torch.zeros((batch_sz, 3)) leye_pose, reye_pose = torch.zeros((batch_sz, 3)), torch.zeros( (batch_sz, 3)) # Init SMPLX, pose with mean smplx pose, as in ch.registration smplx = th_batch_SMPLX(batch_sz, betas, global_pose, body_pose, left_hand_pose, right_hand_pose, trans, expression, jaw_pose, leye_pose, reye_pose, faces=th_faces, gender=gender).to(DEVICE) smplx_part_labels = torch.cat([labels] * batch_sz, axis=0) th_scan_meshes, centers = [], [] for scan in scans: print('scan path ...', 
scan) temp = Mesh(filename=scan) th_scan = tm.from_tensors( torch.tensor(temp.v.astype('float32'), requires_grad=False, device=DEVICE), torch.tensor(temp.f.astype('int32'), requires_grad=False, device=DEVICE).long()) th_scan_meshes.append(th_scan) if scale_file is not None: for n, sc in enumerate(scale_file): dat = np.load(sc, allow_pickle=True) th_scan_meshes[n].vertices += torch.tensor(dat[1]).to(DEVICE) th_scan_meshes[n].vertices *= torch.tensor(dat[0]).to(DEVICE) # Optimize pose first optimize_pose_only(th_scan_meshes, smplx, pose_iterations, pose_steps_per_iter, scan_part_labels, smplx_part_labels, search_tree, pen_distance, tri_filtering_module, display=None if display is None else 0) # Optimize pose and shape optimize_pose_shape(th_scan_meshes, smplx, iterations, steps_per_iter, scan_part_labels, smplx_part_labels, search_tree, pen_distance, tri_filtering_module, display=None if display is None else 0) # verts, _, _, _ = smplx() verts = smplx() th_smplx_meshes = [ tm.from_tensors(vertices=v, faces=smplx.faces) for v in verts ] if save_path is not None: if not exists(save_path): os.makedirs(save_path) names = [split(s)[1] for s in scans] # Save meshes save_meshes( th_smplx_meshes, [join(save_path, n.replace('.ply', '_smplx.obj')) for n in names]) save_meshes(th_scan_meshes, [join(save_path, n) for n in names]) # Save params for g, bp, lh, rh, e, j, le, re, b, t, n in zip( smplx.global_pose.cpu().detach().numpy(), smplx.body_pose.cpu().detach().numpy(), smplx.left_hand_pose.cpu().detach().numpy(), smplx.right_hand_pose.cpu().detach().numpy(), smplx.expression.cpu().detach().numpy(), smplx.jaw_pose.cpu().detach().numpy(), smplx.leye_pose.cpu().detach().numpy(), smplx.reye_pose.cpu().detach().numpy(), smplx.betas.cpu().detach().numpy(), smplx.trans.cpu().detach().numpy(), names): smplx_dict = { 'global_pose': g, 'body_pose': bp, 'left_hand_pose': lh, 'right_hand_pose': rh, 'expression': e, 'jaw_pose': j, 'leye_pose': le, 'reye_pose': re, 'betas': b, 'trans': t } pkl.dump( smplx_dict, open(join(save_path, n.replace('.ply', '_smplx.pkl')), 'wb')) return (smplx.global_pose.cpu().detach().numpy(), smplx.body_pose.cpu().detach().numpy(), smplx.left_hand_pose.cpu().detach().numpy(), smplx.right_hand_pose.cpu().detach().numpy(), smplx.expression.cpu().detach().numpy(), smplx.jaw_pose.cpu().detach().numpy(), smplx.leye_pose.cpu().detach().numpy(), smplx.reye_pose.cpu().detach().numpy(), smplx.betas.cpu().detach().numpy(), smplx.trans.cpu().detach().numpy())
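# A hedged sketch of how fit_SMPLX and fit_SMPLXD above could be chained:
# first fit SMPL-X pose and shape to the scans, then add per-vertex offsets.
# The argument names follow the two functions above; the smplx_pkl files are
# assumed to be the parameter pickles written by the first stage, and all
# paths are placeholders.
def fit_scan_pipeline(scans, scan_labels, smplx_pkl, save_path,
                      gender='male', scale_file=None):
    # Stage 1: SMPL-X pose and shape only.
    fit_SMPLX(scans, scan_labels, gender=gender,
              save_path=save_path, scale_file=scale_file,
              display=None, interpenetration=True)
    # Stage 2: SMPL-X + D, i.e. free-form offsets on top of the fitted body.
    fit_SMPLXD(scans, smplx_pkl, gender=gender,
               save_path=save_path, scale_file=scale_file,
               interpenetration=True)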
def fit_frames(img, keypoints, body_model, camera=None, joint_weights=None, body_pose_prior=None, jaw_prior=None, left_hand_prior=None, right_hand_prior=None, shape_prior=None, expr_prior=None, angle_prior=None, loss_type="smplify", use_cuda=True, init_joints_idxs=(9, 12, 2, 5), use_face=False, use_hands=True, data_weights=None, body_pose_prior_weights=None, hand_pose_prior_weights=None, jaw_pose_prior_weights=None, shape_weights=None, expr_weights=None, hand_joints_weights=None, face_joints_weights=None, depth_loss_weight=1e2, interpenetration=False, coll_loss_weights=None, df_cone_height=0.5, penalize_outside=True, max_collisions=8, point2plane=False, part_segm_fn="", focal_length=5000.0, side_view_thsh=25.0, rho=100, vposer_latent_dim=32, vposer_ckpt="", use_joints_conf=False, interactive=True, visualize=False, batch_size=1, dtype=torch.float32, ign_part_pairs=None, left_shoulder_idx=2, right_shoulder_idx=5, freeze_camera=True, **kwargs): # assert batch_size == 1, "PyTorch L-BFGS only supports batch_size == 1" device = torch.device("cuda") if use_cuda else torch.device("cpu") if body_pose_prior_weights is None: body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78] if data_weights is None: data_weights = [1] * len(body_pose_prior_weights) msg = "Number of Body pose prior weights {}".format( len(body_pose_prior_weights) ) + " does not match the number of data term weights {}".format( len(data_weights)) assert len(data_weights) == len(body_pose_prior_weights), msg if use_hands: if hand_pose_prior_weights is None: hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1] msg = ("Number of Body pose prior weights does not match the" + " number of hand pose prior weights") assert len(hand_pose_prior_weights) == len( body_pose_prior_weights), msg if hand_joints_weights is None: hand_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ("Number of Body pose prior weights does not match the" + " number of hand joint distance weights") assert len(hand_joints_weights) == len( body_pose_prior_weights), msg if shape_weights is None: shape_weights = [1e2, 5 * 1e1, 1e1, 0.5 * 1e1] msg = ("Number of Body pose prior weights = {} does not match the" + " number of Shape prior weights = {}") assert len(shape_weights) == len(body_pose_prior_weights), msg.format( len(shape_weights), len(body_pose_prior_weights)) if coll_loss_weights is None: coll_loss_weights = [0.0] * len(body_pose_prior_weights) msg = ("Number of Body pose prior weights does not match the" + " number of collision loss weights") assert len(coll_loss_weights) == len(body_pose_prior_weights), msg use_vposer = kwargs.get("use_vposer", True) vposer, pose_embedding = [None] * 2 if use_vposer: pose_embedding = torch.zeros( [batch_size, vposer_latent_dim], dtype=dtype, device=device, requires_grad=True, ) vposer_ckpt = osp.expandvars(vposer_ckpt) vposer, _ = load_vposer(vposer_ckpt, vp_model="snapshot") vposer = vposer.to(device=device) vposer.eval() if use_vposer: body_mean_pose = ( vposeutils.get_vposer_mean_pose().detach().cpu().numpy()) else: body_mean_pose = body_pose_prior.get_mean().detach().cpu() keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :2] if use_joints_conf: joints_conf = keypoint_data[:, :, 2].reshape(keypoint_data.shape[0], -1) joints_conf = joints_conf.to(device=device, dtype=dtype) # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) # Create the search tree search_tree = None pen_distance = None filter_faces = None if interpenetration: from 
mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, "Interpenetration term can only be used with CUDA" assert torch.cuda.is_available(), ( "No CUDA Device! Interpenetration term can only be used" + " with CUDA") search_tree = BVH(max_collisions=max_collisions) pen_distance = collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside, ) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, "rb") as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding="latin1") faces_segm = face_segm_data["segm"] faces_parents = face_segm_data["parents"] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs, ).to(device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = { "data_weight": data_weights, "body_pose_weight": body_pose_prior_weights, "shape_weight": shape_weights, } if use_face: opt_weights_dict["face_weight"] = face_joints_weights opt_weights_dict["expr_prior_weight"] = expr_weights opt_weights_dict["jaw_prior_weight"] = jaw_pose_prior_weights if use_hands: opt_weights_dict["hand_weight"] = hand_joints_weights opt_weights_dict["hand_prior_weight"] = hand_pose_prior_weights if interpenetration: opt_weights_dict["coll_loss_weight"] = coll_loss_weights keys = opt_weights_dict.keys() opt_weights = [ dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None)) ] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) # The indices of the joints used for the initialization of the camera init_joints_idxs = torch.tensor(init_joints_idxs, device=device) # Hand joints start at 25 (before body) loss = fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, use_face=use_face, use_hands=use_hands, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=body_pose_prior, shape_prior=shape_prior, angle_prior=angle_prior, expr_prior=expr_prior, left_hand_prior=left_hand_prior, right_hand_prior=right_hand_prior, jaw_prior=jaw_prior, interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, dtype=dtype, **kwargs) loss = loss.to(device=device) with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, **kwargs) as monitor: img = torch.tensor(img, dtype=dtype) H, W, _ = img.shape data_weight = 1000 / H orientations = [body_model.global_orient.detach().cpu().numpy()] # # Step 2: Optimize the full model final_loss_val = 0 for or_idx, orient in enumerate(tqdm(orientations, desc="Orientation")): opt_start = time.time() new_params = defaultdict(global_orient=orient, body_pose=body_mean_pose) body_model.reset_params(**new_params) if use_vposer: with torch.no_grad(): pose_embedding.fill_(0) for opt_idx, curr_weights in enumerate( tqdm(opt_weights, desc="Stage")): body_params = list(body_model.parameters()) final_params = list( filter(lambda x: x.requires_grad, body_params)) if use_vposer: final_params.append(pose_embedding) ( body_optimizer, body_create_graph, ) = optim_factory.create_optimizer(final_params, **kwargs) 
body_optimizer.zero_grad() curr_weights["data_weight"] = data_weight curr_weights["bending_prior_weight"] = ( 3.17 * curr_weights["body_pose_weight"]) if use_hands: # joint_weights[:, 25:67] = curr_weights['hand_weight'] pass if use_face: joint_weights[:, 67:] = curr_weights["face_weight"] loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, body_model, camera=camera, gt_joints=gt_joints, joints_conf=joints_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, return_verts=True, return_full_pose=True, ) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting( body_optimizer, closure, final_params, body_model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer, ) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write( "Stage {:03d} done after {:.4f} seconds".format( opt_idx, elapsed)) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write( "Body fitting Orientation {} done after {:.4f} seconds". format(or_idx, elapsed)) tqdm.write( "Body final loss val = {:.5f}".format(final_loss_val)) # Get the result of the fitting process # Store in it the errors list in order to compare multiple # orientations, if they exist result = { "camera_" + str(key): val.detach().cpu().numpy() for key, val in camera.named_parameters() } result.update({ key: val.detach().cpu().numpy() for key, val in body_model.named_parameters() }) if use_vposer: result["pose_embedding"] = ( pose_embedding.detach().cpu().numpy()) body_pose = (vposer.decode( pose_embedding, output_type="aa").reshape( pose_embedding.shape[0], -1) if use_vposer else None) result["body_pose"] = body_pose.detach().cpu().numpy() model_output = body_model(return_verts=True, body_pose=body_pose) return model_output, result
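# A minimal sketch of how the collision objects built above are typically
# consumed inside the fitting loss. The triangle gather and the no-grad BVH
# query mirror the untangling script later in this file; the call signature
# of the FilterFaces module (filtering the collision index pairs) is an
# assumption and should be checked against the installed torch-mesh-isect.
import torch


def collision_penalty(vertices, faces_idx, search_tree, pen_distance,
                      tri_filtering_module=None):
    # vertices: (B, V, 3); faces_idx: (B, F, 3) with batch-offset indices
    # into the flattened vertex list, as constructed elsewhere in this file.
    triangles = vertices.view([-1, 3])[faces_idx]
    with torch.no_grad():
        collision_idxs = search_tree(triangles)
    if tri_filtering_module is not None:
        collision_idxs = tri_filtering_module(collision_idxs)
    return pen_distance(triangles, collision_idxs)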
def fit_single_frame(keypoints, body_model, joint_weights, body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior, shape_prior, expr_prior, angle_prior, result_fn='out.pkl', mesh_fn='out.obj', out_img_fn='overlay.png', loss_type='smplify', use_cuda=True, init_joints_idxs=(9, 12, 2, 5), use_face=True, use_hands=True, data_weights=None, body_pose_prior_weights=None, hand_pose_prior_weights=None, jaw_pose_prior_weights=None, shape_weights=None, expr_weights=None, hand_joints_weights=None, face_joints_weights=None, depth_loss_weight=1e2, interpenetration=True, coll_loss_weights=None, df_cone_height=0.5, penalize_outside=True, max_collisions=8, point2plane=False, part_segm_fn='', focal_length=5000., side_view_thsh=25., rho=100, vposer_latent_dim=32, vposer_ckpt='', use_joints_conf=False, interactive=True, visualize=False, save_meshes=True, degrees=None, batch_size=1, dtype=torch.float32, ign_part_pairs=None, left_shoulder_idx=2, right_shoulder_idx=5, **kwargs): assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1' device = torch.device('cuda') if use_cuda else torch.device('cpu') if data_weights is None: data_weights = [1, ] * 5 if body_pose_prior_weights is None: body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78] msg = ( 'Number of Body pose prior weights {}'.format( len(body_pose_prior_weights)) + ' does not match the number of data term weights {}'.format( len(data_weights))) assert (len(data_weights) == len(body_pose_prior_weights)), msg if use_hands: if hand_pose_prior_weights is None: hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights does not match the' + ' number of hand pose prior weights') assert (len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg if hand_joints_weights is None: hand_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of hand joint distance weights') assert (len(hand_joints_weights) == len(body_pose_prior_weights)), msg if shape_weights is None: shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Shape prior weights = {}') assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format( len(shape_weights), len(body_pose_prior_weights)) if use_face: if jaw_pose_prior_weights is None: jaw_pose_prior_weights = [[x] * 3 for x in shape_weights] else: jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')), jaw_pose_prior_weights) jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights] msg = ('Number of Body pose prior weights does not match the' + ' number of jaw pose prior weights') assert (len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg if expr_weights is None: expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Expression prior weights = {}') assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format( len(body_pose_prior_weights), len(expr_weights)) if face_joints_weights is None: face_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of face joint distance weights') assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg if coll_loss_weights is None: coll_loss_weights = [0.0] * len(body_pose_prior_weights) msg = ('Number of Body pose prior weights does not match the' + ' number of collision loss weights') assert (len(coll_loss_weights) 
== len(body_pose_prior_weights)), msg use_vposer = kwargs.get('use_vposer', True) vposer, pose_embedding = [None, ] * 2 if use_vposer: pose_embedding = torch.zeros([batch_size, 32], dtype=dtype, device=device, requires_grad=True) vposer_ckpt = osp.expandvars(vposer_ckpt) vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot') vposer = vposer.to(device=device) vposer.eval() if use_vposer: body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype) else: body_mean_pose = body_pose_prior.get_mean().detach().cpu() keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :-1] if use_joints_conf: joints_conf = keypoint_data[:, :, -1].reshape(1, -1) # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) if use_joints_conf: joints_conf = joints_conf.to(device=device, dtype=dtype) # Create the search tree search_tree = None pen_distance = None filter_faces = None if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, 'Interpenetration term can only be used with CUDA' assert torch.cuda.is_available(), \ 'No CUDA Device! Interpenetration term can only be used' + \ ' with CUDA' search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).to(device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = {'data_weight': data_weights, 'body_pose_weight': body_pose_prior_weights, 'shape_weight': shape_weights} if use_face: opt_weights_dict['face_weight'] = face_joints_weights opt_weights_dict['expr_prior_weight'] = expr_weights opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights if use_hands: opt_weights_dict['hand_weight'] = hand_joints_weights opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights if interpenetration: opt_weights_dict['coll_loss_weight'] = coll_loss_weights keys = opt_weights_dict.keys() opt_weights = [dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None))] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) loss = fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, use_face=use_face, use_hands=use_hands, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=body_pose_prior, shape_prior=shape_prior, angle_prior=angle_prior, expr_prior=expr_prior, left_hand_prior=left_hand_prior, right_hand_prior=right_hand_prior, jaw_prior=jaw_prior, interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, dtype=dtype, **kwargs) loss = loss.to(device=device) with fitting.FittingMonitor( batch_size=batch_size, visualize=visualize, **kwargs) as monitor: data_weight = 0.7 # 
Reset the parameters to estimate the initial translation of the # body model body_model.reset_params(body_pose=body_mean_pose) orientations = [body_model.global_orient.detach().cpu().numpy()] # store here the final error for both orientations, # and pick the orientation resulting in the lowest error results = [] # Optimize the full model final_loss_val = 0 opt_start = time.time() new_params = defaultdict(body_pose=body_mean_pose) body_model.reset_params(**new_params) if use_vposer: with torch.no_grad(): pose_embedding.fill_(0) for opt_idx, curr_weights in enumerate(tqdm(opt_weights, desc='Stage')): body_params = list(body_model.parameters()) final_params = list( filter(lambda x: x.requires_grad, body_params)) if use_vposer: final_params.append(pose_embedding) body_optimizer, body_create_graph = optim_factory.create_optimizer( final_params, **kwargs) body_optimizer.zero_grad() curr_weights['data_weight'] = data_weight curr_weights['bending_prior_weight'] = ( 3.17 * curr_weights['body_pose_weight']) if use_hands: joint_weights[:, 25:67] = curr_weights['hand_weight'] if use_face: joint_weights[:, 67:] = curr_weights['face_weight'] loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, body_model, gt_joints=gt_joints, joints_conf=joints_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, return_verts=True, return_full_pose=True) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting( body_optimizer, closure, final_params, body_model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write('Stage {:03d} done after {:.4f} seconds'.format( opt_idx, elapsed)) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write( 'Body fitting Orientation done after {:.4f} seconds'.format(elapsed)) tqdm.write('Body final loss val = {:.5f}'.format( final_loss_val)) # Get the result of the fitting process # Store in it the errors list in order to compare multiple # orientations, if they exist result = {} result.update({key: val.detach().cpu().numpy() for key, val in body_model.named_parameters()}) if use_vposer: result['body_pose'] = pose_embedding.detach().cpu().numpy() results = {'loss': final_loss_val, 'result': result} with open(result_fn, 'wb') as result_file: pickle.dump(results['result'], result_file, protocol=2) if save_meshes or visualize: body_pose = vposer.decode( pose_embedding, output_type='aa').view(1, -1) if use_vposer else None model_type = kwargs.get('model_type', 'smpl') append_wrists = model_type == 'smpl' and use_vposer if append_wrists: wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device) body_pose = torch.cat([body_pose, wrist_pose], dim=1) model_output = body_model(return_verts=True, body_pose=body_pose) vertices = model_output.vertices.detach().cpu().numpy().squeeze() import trimesh out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False) rot = trimesh.transformations.rotation_matrix( np.radians(180), [1, 0, 0]) out_mesh.apply_transform(rot) out_mesh.export(mesh_fn)
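# A small, self-contained illustration of the per-stage weight schedule used
# in fit_frames and fit_single_frame above: every entry of opt_weights_dict
# is a list with one value per annealing stage, and the zip(*) expression
# regroups them into one dict per stage. The values are just the defaults
# from the surrounding code.
def build_stage_weights():
    opt_weights_dict = {
        'data_weight': [1, 1, 1, 1],
        'body_pose_weight': [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78],
        'shape_weight': [1e2, 5 * 1e1, 1e1, 0.5 * 1e1],
        'coll_loss_weight': [0.0, 0.0, 0.0, 0.0],
    }
    keys = opt_weights_dict.keys()
    opt_weights = [dict(zip(keys, vals))
                   for vals in zip(*(opt_weights_dict[k] for k in keys
                                     if opt_weights_dict[k] is not None))]
    # opt_weights[0] == {'data_weight': 1, 'body_pose_weight': 404.0,
    #                    'shape_weight': 100.0, 'coll_loss_weight': 0.0}
    # Later stages keep the data term fixed while relaxing the priors.
    return opt_weights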
def __call__(self, init_pose, init_betas, init_cam_t, j3d, fit_beta=False, conf_3d=1.0): """Perform body fitting. Input: init_pose: SMPL pose estimate init_betas: SMPL betas estimate init_cam_t: Camera translation estimate j3d: joints 3d aka keypoints conf_3d: confidence for 3d joints Returns: vertices: Vertices of optimized shape joints: 3D joints of optimized shape pose: SMPL pose parameters of optimized shape betas: SMPL beta parameters of optimized shape camera_translation: Camera translation """ # add the mesh inter-section to avoid search_tree = None pen_distance = None filter_faces = None if self.use_collision: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces search_tree = BVH(max_collisions=8) pen_distance = collisions_loss.DistanceFieldPenetrationLoss( sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True) if self.part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(self.part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=None).to(device=self.device) # Split SMPL pose to body pose and global orientation body_pose = init_pose[:, 3:].detach().clone() global_orient = init_pose[:, :3].detach().clone() betas = init_betas.detach().clone() # use guess 3d to get the initial smpl_output = self.smpl(global_orient=global_orient, body_pose=body_pose, betas=betas) model_joints = smpl_output.joints init_cam_t = self.guess_init_3d(model_joints, j3d).detach() camera_translation = init_cam_t.clone() preserve_pose = init_pose[:, 3:].detach().clone() # -------------Step 1: Optimize camera translation and body orientation-------- # Optimize only camera translation and body orientation body_pose.requires_grad = False betas.requires_grad = False global_orient.requires_grad = True camera_translation.requires_grad = True camera_opt_params = [global_orient, camera_translation] if self.use_lbfgs: camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters, lr=self.step_size, line_search_fn='strong_wolfe') for i in range(10): def closure(): camera_optimizer.zero_grad() smpl_output = self.smpl(global_orient=global_orient, body_pose=body_pose, betas=betas) model_joints = smpl_output.joints loss = self.camera_fitting_loss_3d(model_joints, camera_translation, init_cam_t, j3d) loss.backward() return loss camera_optimizer.step(closure) else: camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999)) for i in range(20): smpl_output = self.smpl(global_orient=global_orient, body_pose=body_pose, betas=betas) model_joints = smpl_output.joints loss = self.camera_fitting_loss_3d( model_joints[:, self.smpl_index], camera_translation, init_cam_t, j3d[:, self.corr_index]) camera_optimizer.zero_grad() loss.backward() camera_optimizer.step() # Fix camera translation after optimizing camera # --------Step 2: Optimize body joints -------------------------- # Optimize only the body pose and global orientation of the body body_pose.requires_grad = True global_orient.requires_grad = True camera_translation.requires_grad = True # --- if we use the sequence, fix the shape if fit_beta: betas.requires_grad = True body_opt_params 
= [ body_pose, betas, global_orient, camera_translation ] else: betas.requires_grad = False body_opt_params = [body_pose, global_orient, camera_translation] if self.use_lbfgs: body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters, lr=self.step_size, line_search_fn='strong_wolfe') for i in range(self.num_iters): def closure(): body_optimizer.zero_grad() smpl_output = self.smpl(global_orient=global_orient, body_pose=body_pose, betas=betas) model_joints = smpl_output.joints model_vertices = smpl_output.vertices loss = self.body_fitting_loss_3d( body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, j3d[:, self.corr_index], self.pose_prior, joints3d_conf=conf_3d, joint_loss_weight=600.0, pose_preserve_weight=5.0, use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) loss.backward() return loss body_optimizer.step(closure) else: body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999)) for i in range(self.num_iters): smpl_output = self.smpl(global_orient=global_orient, body_pose=body_pose, betas=betas) model_joints = smpl_output.joints model_vertices = smpl_output.vertices loss = self.body_fitting_loss_3d( body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, j3d[:, self.corr_index], self.pose_prior, joints3d_conf=conf_3d, joint_loss_weight=600.0, use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) body_optimizer.zero_grad() loss.backward() body_optimizer.step() # Get final loss value with torch.no_grad(): smpl_output = self.smpl(global_orient=global_orient, body_pose=body_pose, betas=betas, return_full_pose=True) model_joints = smpl_output.joints model_vertices = smpl_output.vertices final_loss = self.body_fitting_loss_3d( body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation, j3d[:, self.corr_index], self.pose_prior, joints3d_conf=conf_3d, joint_loss_weight=600.0, use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces, search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces) vertices = smpl_output.vertices.detach() joints = smpl_output.joints.detach() pose = torch.cat([global_orient, body_pose], dim=-1).detach() betas = betas.detach() camera_translation = camera_translation.detach() return vertices, joints, pose, betas, camera_translation, final_loss
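# A hedged usage sketch for the 3D-keypoint fitter defined by __call__ above.
# `smplify3d` stands for an instance of the (unnamed here) class that owns
# this __call__; tensor shapes follow the SMPL conventions used in the
# method: (B, 72) axis-angle pose, (B, 10) betas, (B, 3) camera translation,
# (B, J, 3) target joints.
import torch


def refine_with_smplify3d(smplify3d, pred_pose, pred_betas, pred_cam_t, j3d):
    with torch.enable_grad():
        vertices, joints, pose, betas, cam_t, final_loss = smplify3d(
            init_pose=pred_pose.detach(),
            init_betas=pred_betas.detach(),
            init_cam_t=pred_cam_t.detach(),
            j3d=j3d,
            fit_beta=False,   # keep shape fixed, e.g. when fitting a sequence
            conf_3d=1.0)
    return pose, betas, cam_t, final_loss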
def fit_single_frame(img, keypoints, init_trans, scan, scene_name, body_model, camera, joint_weights, body_pose_prior, jaw_prior, left_hand_prior, right_hand_prior, shape_prior, expr_prior, angle_prior, result_fn='out.pkl', mesh_fn='out.obj', body_scene_rendering_fn='body_scene.png', out_img_fn='overlay.png', loss_type='smplify', use_cuda=True, init_joints_idxs=(9, 12, 2, 5), use_face=True, use_hands=True, data_weights=None, body_pose_prior_weights=None, hand_pose_prior_weights=None, jaw_pose_prior_weights=None, shape_weights=None, expr_weights=None, hand_joints_weights=None, face_joints_weights=None, depth_loss_weight=1e2, interpenetration=True, coll_loss_weights=None, df_cone_height=0.5, penalize_outside=True, max_collisions=8, point2plane=False, part_segm_fn='', focal_length_x=5000., focal_length_y=5000., side_view_thsh=25., rho=100, vposer_latent_dim=32, vposer_ckpt='', use_joints_conf=False, interactive=True, visualize=False, save_meshes=True, degrees=None, batch_size=1, dtype=torch.float32, ign_part_pairs=None, left_shoulder_idx=2, right_shoulder_idx=5, #################### ### PROX render_results=True, camera_mode='moving', ## Depth s2m=False, s2m_weights=None, m2s=False, m2s_weights=None, rho_s2m=1, rho_m2s=1, init_mode=None, trans_opt_stages=None, viz_mode='mv', #penetration sdf_penetration=False, sdf_penetration_weights=0.0, sdf_dir=None, cam2world_dir=None, #contact contact=False, rho_contact=1.0, contact_loss_weights=None, contact_angle=15, contact_body_parts=None, body_segments_dir=None, load_scene=False, scene_dir=None, height=None, weight=None, gender='male', weight_w=0, height_w=0, **kwargs): if kwargs['optim_type'] == 'lbfgsls': assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1' batch_size = keypoints.shape[0] body_model.reset_params() body_model.transl.requires_grad = True device = torch.device('cuda') if use_cuda else torch.device('cpu') # if visualize: # pil_img.fromarray((img * 255).astype(np.uint8)).show() if degrees is None: degrees = [0, 90, 180, 270] if data_weights is None: data_weights = [1, ] * 5 if body_pose_prior_weights is None: body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78] msg = ( 'Number of Body pose prior weights {}'.format( len(body_pose_prior_weights)) + ' does not match the number of data term weights {}'.format( len(data_weights))) assert (len(data_weights) == len(body_pose_prior_weights)), msg if use_hands: if hand_pose_prior_weights is None: hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights does not match the' + ' number of hand pose prior weights') assert (len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg if hand_joints_weights is None: hand_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of hand joint distance weights') assert (len(hand_joints_weights) == len(body_pose_prior_weights)), msg if shape_weights is None: shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Shape prior weights = {}') assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format( len(shape_weights), len(body_pose_prior_weights)) if use_face: if jaw_pose_prior_weights is None: jaw_pose_prior_weights = [[x] * 3 for x in shape_weights] else: jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')), jaw_pose_prior_weights) jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights] msg = ('Number of Body pose prior 
weights does not match the' + ' number of jaw pose prior weights') assert (len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg if expr_weights is None: expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1] msg = ('Number of Body pose prior weights = {} does not match the' + ' number of Expression prior weights = {}') assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format( len(body_pose_prior_weights), len(expr_weights)) if face_joints_weights is None: face_joints_weights = [0.0, 0.0, 0.0, 1.0] msg = ('Number of Body pose prior weights does not match the' + ' number of face joint distance weights') assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg if coll_loss_weights is None: coll_loss_weights = [0.0] * len(body_pose_prior_weights) msg = ('Number of Body pose prior weights does not match the' + ' number of collision loss weights') assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg use_vposer = kwargs.get('use_vposer', True) vposer, pose_embedding = [None, ] * 2 if use_vposer: # pose_embedding = torch.zeros([batch_size, 32], # dtype=dtype, device=device, # requires_grad=True) # Patrick: hack to set default body pose to something more sleep-y mean_body = np.array([[ 0.19463745, 1.6240447, 0.6890624, 0.19186097, 0.08003145, -0.04189298, 3.450903, -0.29570094, 0.25072002, -1.1879578, 0.33350763, 0.23568614, 0.38122794, -2.1258948, 0.2910664, 2.2407222, -0.5400814, -0.95984083, -1.2880017, 1.1122228, 0.7411389, -0.2265636, -4.8202057, -1.950323, -0.28771818, -1.9282387, 0.9928907, -0.27183488, -0.55805033, 0.04047768, -0.537362, 0.65770334]]) pose_embedding = torch.tensor(mean_body, dtype=dtype, device=device, requires_grad=True) vposer_ckpt = osp.expandvars(vposer_ckpt) vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot') vposer = vposer.to(device=device) vposer.eval() if use_vposer: body_mean_pose = torch.zeros([batch_size, vposer_latent_dim], dtype=dtype) else: # body_mean_pose = body_pose_prior.get_mean().detach().cpu() # body_mean_pose = torch.zeros([batch_size, 69], dtype=dtype) # mean_body = np.array([[-2.33263850e-01, 1.35460928e-01, 2.94471830e-01, -3.22930813e-01, # -4.73931670e-01, -2.67531037e-01, 7.12558180e-02, 7.89440796e-03, # 8.67700949e-03, 1.05982251e-01, 2.79584467e-01, -7.04243258e-02, # 3.61106455e-01, -5.87305248e-01, 1.10897996e-01, -1.68918714e-01, # -4.60174456e-02, 3.28684039e-02, 5.80525696e-01, -5.11317095e-03, # -1.57546505e-01, 5.85777402e-01, -8.94948393e-02, 2.24680841e-01, # 1.55473784e-01, 5.38146123e-04, 4.30279821e-02, -4.68525589e-02, # 7.75185153e-02, 7.82282930e-03, 6.74356073e-02, 4.09710407e-02, # -3.60425897e-02, -4.71813440e-01, 5.02379127e-02, 2.02309843e-02, # 5.29680364e-02, 1.68510173e-02, 2.25090146e-01, -4.52307612e-02, # 7.72185996e-02, -2.17333943e-01, 3.30020368e-01, 4.21866514e-02, # 7.15153441e-02, 3.05950731e-01, -3.63454908e-01, -1.28235269e+00, # 5.09610713e-01, 4.65482563e-01, 1.20263052e+00, 5.56594551e-01, # -2.24000740e+00, 3.83565158e-01, 5.31355202e-01, 2.21637583e+00, # -5.63146770e-01, -3.01193684e-01, -4.31942672e-01, 6.85038209e-01, # 3.61178756e-01, 2.76136428e-01, -2.64388829e-01, 0.00000000e+00, # 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, # 0.00000000e+00]]) mean_body = np.array(joint_limits.axang_limits_patrick / 180 * np.pi).mean(1) body_mean_pose = torch.tensor(mean_body, dtype=dtype).unsqueeze(0) betanet = None if height is not None: betanet = torch.load('models/betanet_old_pytorch.pt') betanet = betanet.to(device=device) 
betanet.eval() keypoint_data = torch.tensor(keypoints, dtype=dtype) gt_joints = keypoint_data[:, :, :2] if use_joints_conf: joints_conf = keypoint_data[:, :, 2] # Transfer the data to the correct device gt_joints = gt_joints.to(device=device, dtype=dtype) if use_joints_conf: joints_conf = joints_conf.to(device=device, dtype=dtype) scan_tensor = None if scan is not None: scan_tensor = scan.to(device=device) # load pre-computed signed distance field sdf = None sdf_normals = None grid_min = None grid_max = None voxel_size = None # if sdf_penetration: # with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f: # sdf_data = json.load(f) # grid_min = torch.tensor(np.array(sdf_data['min']), dtype=dtype, device=device) # grid_max = torch.tensor(np.array(sdf_data['max']), dtype=dtype, device=device) # grid_dim = sdf_data['dim'] # voxel_size = (grid_max - grid_min) / grid_dim # sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(grid_dim, grid_dim, grid_dim) # sdf = torch.tensor(sdf, dtype=dtype, device=device) # if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')): # sdf_normals = np.load(osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(grid_dim, grid_dim, grid_dim, 3) # sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device) # else: # print("Normals not found...") with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f: cam2world = np.array(json.load(f)) R = torch.tensor(cam2world[:3, :3].reshape(3, 3), dtype=dtype, device=device) t = torch.tensor(cam2world[:3, 3].reshape(1, 3), dtype=dtype, device=device) # Create the search tree search_tree = None pen_distance = None filter_faces = None if interpenetration: from mesh_intersection.bvh_search_tree import BVH import mesh_intersection.loss as collisions_loss from mesh_intersection.filter_faces import FilterFaces assert use_cuda, 'Interpenetration term can only be used with CUDA' assert torch.cuda.is_available(), \ 'No CUDA Device! 
Interpenetration term can only be used' + \ ' with CUDA' search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss( sigma=df_cone_height, point2plane=point2plane, vectorized=True, penalize_outside=penalize_outside) if part_segm_fn: # Read the part segmentation part_segm_fn = os.path.expandvars(part_segm_fn) with open(part_segm_fn, 'rb') as faces_parents_file: face_segm_data = pickle.load(faces_parents_file, encoding='latin1') faces_segm = face_segm_data['segm'] faces_parents = face_segm_data['parents'] # Create the module used to filter invalid collision pairs filter_faces = FilterFaces( faces_segm=faces_segm, faces_parents=faces_parents, ign_part_pairs=ign_part_pairs).to(device=device) # load vertix ids of contact parts contact_verts_ids = ftov = None if contact: contact_verts_ids = [] for part in contact_body_parts: with open(os.path.join(body_segments_dir, part + '.json'), 'r') as f: data = json.load(f) contact_verts_ids.append(list(set(data["verts_ind"]))) contact_verts_ids = np.concatenate(contact_verts_ids) vertices = body_model(return_verts=True, body_pose= torch.zeros((batch_size, 63), dtype=dtype, device=device)).vertices vertices_np = vertices.detach().cpu().numpy().squeeze() body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape(-1, 3) m = Mesh(v=vertices_np, f=body_faces_np) ftov = m.faces_by_vertex(as_sparse_matrix=True) ftov = sparse.coo_matrix(ftov) indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device) values = torch.FloatTensor(ftov.data).to(device) shape = ftov.shape ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape)) # Read the scene scan if any scene_v = scene_vn = scene_f = None if scene_name is not None: if load_scene: scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply')) scene.vn = scene.estimate_vertex_normals() scene_v = torch.tensor(scene.v[np.newaxis, :], dtype=dtype, device=device).contiguous() scene_vn = torch.tensor(scene.vn[np.newaxis, :], dtype=dtype, device=device) scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :], dtype=torch.long, device=device) # Weights used for the pose prior and the shape prior opt_weights_dict = {'data_weight': data_weights, 'body_pose_weight': body_pose_prior_weights, 'shape_weight': shape_weights} if use_face: opt_weights_dict['face_weight'] = face_joints_weights opt_weights_dict['expr_prior_weight'] = expr_weights opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights if use_hands: opt_weights_dict['hand_weight'] = hand_joints_weights opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights if interpenetration: opt_weights_dict['coll_loss_weight'] = coll_loss_weights if s2m: opt_weights_dict['s2m_weight'] = s2m_weights if m2s: opt_weights_dict['m2s_weight'] = m2s_weights if sdf_penetration: opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights if contact: opt_weights_dict['contact_loss_weight'] = contact_loss_weights keys = opt_weights_dict.keys() opt_weights = [dict(zip(keys, vals)) for vals in zip(*(opt_weights_dict[k] for k in keys if opt_weights_dict[k] is not None))] for weight_list in opt_weights: for key in weight_list: weight_list[key] = torch.tensor(weight_list[key], device=device, dtype=dtype) # load indices of the head of smpl-x model with open( osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp: head_indx = np.array(json.load(fp)) N = body_model.get_num_verts() body_indx = np.setdiff1d(np.arange(N), head_indx) head_mask = np.in1d(np.arange(N), head_indx) 
body_mask = np.in1d(np.arange(N), body_indx) # The indices of the joints used for the initialization of the camera init_joints_idxs = torch.tensor(init_joints_idxs, device=device) edge_indices = kwargs.get('body_tri_idxs') # which initialization mode to choose: similar traingles, mean of the scan or the average of both if init_mode == 'scan': init_t = init_trans elif init_mode == 'both': init_t = (init_trans.to(device) + fitting.guess_init(body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get('model_type', 'smpl'), focal_length=focal_length_x, dtype=dtype) ) /2.0 else: init_t = fitting.guess_init(body_model, gt_joints, edge_indices, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, model_type=kwargs.get('model_type', 'smpl'), focal_length=focal_length_x, dtype=dtype) camera_loss = fitting.create_loss('camera_init', trans_estimation=init_t, init_joints_idxs=init_joints_idxs, depth_loss_weight=depth_loss_weight, camera_mode=camera_mode, dtype=dtype).to(device=device) camera_loss.trans_estimation[:] = init_t loss = fitting.create_loss(loss_type=loss_type, joint_weights=joint_weights, rho=rho, use_joints_conf=use_joints_conf, use_face=use_face, use_hands=use_hands, vposer=vposer, pose_embedding=pose_embedding, body_pose_prior=body_pose_prior, shape_prior=shape_prior, angle_prior=angle_prior, expr_prior=expr_prior, left_hand_prior=left_hand_prior, right_hand_prior=right_hand_prior, jaw_prior=jaw_prior, interpenetration=interpenetration, pen_distance=pen_distance, search_tree=search_tree, tri_filtering_module=filter_faces, s2m=s2m, m2s=m2s, rho_s2m=rho_s2m, rho_m2s=rho_m2s, head_mask=head_mask, body_mask=body_mask, sdf_penetration=sdf_penetration, voxel_size=voxel_size, grid_min=grid_min, grid_max=grid_max, sdf=sdf, sdf_normals=sdf_normals, R=R, t=t, contact=contact, contact_verts_ids=contact_verts_ids, rho_contact=rho_contact, contact_angle=contact_angle, dtype=dtype, betanet=betanet, height=height, weight=weight, gender=gender, weight_w=weight_w, height_w=height_w, **kwargs) loss = loss.to(device=device) with fitting.FittingMonitor(batch_size=batch_size, visualize=visualize, viz_mode=viz_mode, **kwargs) as monitor: img = torch.tensor(img, dtype=dtype) _, H, W, _ = img.shape # Reset the parameters to estimate the initial translation of the # body model if camera_mode == 'moving': body_model.reset_params(body_pose=body_mean_pose) # Update the value of the translation of the camera as well as # the image center. 
with torch.no_grad(): camera.translation[:] = init_t.view_as(camera.translation) camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5 # Re-enable gradient calculation for the camera translation camera.translation.requires_grad = True camera_opt_params = [camera.translation, body_model.global_orient] elif camera_mode == 'fixed': # body_model.reset_params() # body_model.transl[:] = torch.tensor(init_t) # body_model.body_pose[:] = torch.tensor(body_mean_pose) body_model.reset_params(body_pose=body_mean_pose, transl=init_t) camera_opt_params = [body_model.transl, body_model.global_orient] # If the distance between the 2D shoulders is smaller than a # predefined threshold then try 2 fits, the initial one and a 180 # degree rotation shoulder_dist = torch.norm(gt_joints[:, left_shoulder_idx, :] - gt_joints[:, right_shoulder_idx, :], dim=1) try_both_orient = shoulder_dist.min() < side_view_thsh kwargs['lr'] *= 10 camera_optimizer, camera_create_graph = optim_factory.create_optimizer(camera_opt_params, **kwargs) kwargs['lr'] /= 10 # The closure passed to the optimizer fit_camera = monitor.create_fitting_closure( camera_optimizer, body_model, camera, gt_joints, camera_loss, create_graph=camera_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, scan_tensor=scan_tensor, return_full_pose=False, return_verts=False) # Step 1: Optimize over the torso joints the camera translation # Initialize the computational graph by feeding the initial translation # of the camera and the initial pose of the body model. camera_init_start = time.time() cam_init_loss_val = monitor.run_fitting(camera_optimizer, fit_camera, camera_opt_params, body_model, use_vposer=use_vposer, pose_embedding=pose_embedding, vposer=vposer) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() tqdm.write('Camera initialization done after {:.4f}'.format( time.time() - camera_init_start)) tqdm.write('Camera initialization final loss {:.4f}'.format( cam_init_loss_val)) # If the 2D detections/positions of the shoulder joints are too # close the rotate the body by 180 degrees and also fit to that # orientation if try_both_orient: with torch.no_grad(): flipped_orient = torch.zeros_like(body_model.global_orient) for i in range(batch_size): body_orient = body_model.global_orient[i, :].detach().cpu().numpy() local_flip = cv2.Rodrigues(body_orient)[0].dot(cv2.Rodrigues(np.array([0., np.pi, 0]))[0]) local_flip = cv2.Rodrigues(local_flip)[0].ravel() flipped_orient[i, :] = torch.Tensor(local_flip).to(device) orientations = [body_model.global_orient, flipped_orient] else: orientations = [body_model.global_orient.detach().cpu().numpy()] # store here the final error for both orientations, # and pick the orientation resulting in the lowest error results = [] body_transl = body_model.transl.clone().detach() # Step 2: Optimize the full model final_loss_val = 0 # for or_idx, orient in enumerate(orientations): or_idx = 0 while or_idx < len(orientations): global_vars.cur_orientation = or_idx orient = orientations[or_idx] print('Trying orientation', or_idx, 'of', len(orientations)) opt_start = time.time() or_idx += 1 new_params = defaultdict(transl=body_transl, global_orient=orient, body_pose=body_mean_pose) body_model.reset_params(**new_params) if use_vposer: with torch.no_grad(): pose_embedding.fill_(0) pose_embedding += torch.tensor(mean_body, dtype=dtype, device=device) for opt_idx, curr_weights in enumerate(opt_weights): global_vars.cur_opt_stage = opt_idx if opt_idx not in trans_opt_stages: 
body_model.transl.requires_grad = False else: body_model.transl.requires_grad = True body_params = list(body_model.parameters()) final_params = list( filter(lambda x: x.requires_grad, body_params)) if use_vposer: final_params.append(pose_embedding) body_optimizer, body_create_graph = optim_factory.create_optimizer( final_params, **kwargs) body_optimizer.zero_grad() curr_weights['bending_prior_weight'] = ( 3.17 * curr_weights['body_pose_weight']) if use_hands: joint_weights[:, 25:76] = curr_weights['hand_weight'] if use_face: joint_weights[:, 76:] = curr_weights['face_weight'] loss.reset_loss_weights(curr_weights) closure = monitor.create_fitting_closure( body_optimizer, body_model, camera=camera, gt_joints=gt_joints, joints_conf=joints_conf, joint_weights=joint_weights, loss=loss, create_graph=body_create_graph, use_vposer=use_vposer, vposer=vposer, pose_embedding=pose_embedding, scan_tensor=scan_tensor, scene_v=scene_v, scene_vn=scene_vn, scene_f=scene_f,ftov=ftov, return_verts=True, return_full_pose=True) if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() stage_start = time.time() final_loss_val = monitor.run_fitting( body_optimizer, closure, final_params, body_model, pose_embedding=pose_embedding, vposer=vposer, use_vposer=use_vposer) # print('Final loss val', final_loss_val) # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']): # break if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - stage_start if interactive: tqdm.write('Stage {:03d} done after {:.4f} seconds'.format( opt_idx, elapsed)) # if final_loss_val is None or math.isnan(final_loss_val) or math.isnan(global_vars.cur_loss_dict['total']): # print('Optimization FAILURE, retrying') # orientations.append(orientations[or_idx-1] * 0.9) # continue if interactive: if use_cuda and torch.cuda.is_available(): torch.cuda.synchronize() elapsed = time.time() - opt_start tqdm.write('Body fitting Orientation {} done after {:.4f} seconds'.format(or_idx, elapsed)) tqdm.write('Body final loss val = {:.5f}'.format(final_loss_val)) # Get the result of the fitting process # Store in it the errors list in order to compare multiple # orientations, if they exist result = {'camera_' + str(key): val.detach().cpu().numpy() for key, val in camera.named_parameters()} result['camera_focal_length_x'] = camera.focal_length_x.detach().cpu().numpy() result['camera_focal_length_y'] = camera.focal_length_y.detach().cpu().numpy() result['camera_center'] = camera.center.detach().cpu().numpy() result.update({key: val.detach().cpu().numpy() for key, val in body_model.named_parameters()}) if use_vposer: result['pose_embedding'] = pose_embedding.detach().cpu().numpy() body_pose = vposer.decode(pose_embedding, output_type='aa').view(1, -1) if use_vposer else None if "smplx.body_models.SMPL'" in str(type(body_model)): wrist_pose = torch.zeros([body_pose.shape[0], 6], dtype=body_pose.dtype, device=body_pose.device) body_pose = torch.cat([body_pose, wrist_pose], dim=1) result['body_pose'] = body_pose.detach().cpu().numpy() result['final_loss_val'] = final_loss_val result['loss_dict'] = global_vars.cur_loss_dict result['betanet_weight'] = global_vars.cur_weight result['betanet_height'] = global_vars.cur_height result['gt_joints'] = gt_joints.detach().cpu().numpy() result['max_joint'] = global_vars.cur_max_joint results.append(result) for idx, res_folder in enumerate(result_fn): # Iterate over batch pkl_data = {} min_loss = np.inf 
all_results = [] for result in results: # Iterate over orientations sel_res = misc_utils.get_data_from_batched_dict(result, idx, len(result_fn)) all_results.append(sel_res) cost = sel_res['loss_dict']['total'] + sel_res['loss_dict']['pprior'] * 60 if cost < min_loss: min_loss = cost pkl_data.update(sel_res) pkl_data['all_results'] = all_results with open(res_folder, 'wb') as result_file: pickle.dump(pkl_data, result_file, protocol=2) img_s = img[idx, :].detach().cpu().numpy() img_s = pil_img.fromarray((img_s * 255).astype(np.uint8)) img_s.save(out_img_fn[idx])
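# A minimal standalone version of the 180-degree flip used above when the 2D
# shoulder distance falls below side_view_thsh: compose each current global
# orientation with a rotation of pi around the y axis via cv2.Rodrigues.
# The function name is illustrative only.
import cv2
import numpy as np
import torch


def flip_global_orient(global_orient):
    # global_orient: (B, 3) axis-angle tensor on any device.
    flipped = torch.zeros_like(global_orient)
    for i in range(global_orient.shape[0]):
        aa = global_orient[i].detach().cpu().numpy()
        # Rotation matrix of the current orientation times a y-axis pi flip.
        R = cv2.Rodrigues(aa)[0].dot(
            cv2.Rodrigues(np.array([0., np.pi, 0.]))[0])
        flipped[i] = torch.as_tensor(cv2.Rodrigues(R)[0].ravel(),
                                     dtype=global_orient.dtype,
                                     device=global_orient.device)
    return flipped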
def main(): description = 'Example script for untangling Mesh self intersections' parser = argparse.ArgumentParser(description=description, prog='Batch Mesh Untangle') parser.add_argument('--data_folder', type=str, default='data', help='The path to data') parser.add_argument('--point2plane', default=False, type=lambda arg: arg.lower() in ['true', '1'], help='Use point to distance') parser.add_argument('--sigma', default=0.5, type=float, help='The height of the cone used to calculate the' + ' distance field loss') parser.add_argument('--lr', default=1, type=float, help='The learning rate for SGD') parser.add_argument('--coll_loss_weight', default=1e-4, type=float, help='The weight for the collision loss') parser.add_argument('--verts_reg_weight', default=1e-5, type=float, help='The weight for the verts regularizer') parser.add_argument('--max_collisions', default=8, type=int, help='The maximum number of bounding box collisions') parser.add_argument('--iterations', default=100, type=int, help='Number of optimization iterations') args = parser.parse_args() data_folder = args.data_folder point2plane = args.point2plane lr = args.lr coll_loss_weight = args.coll_loss_weight verts_reg_weight = args.verts_reg_weight max_collisions = args.max_collisions sigma = args.sigma iterations = args.iterations device = torch.device('cuda') obj_paths = sorted(glob(data_folder+'/*.obj')) verts_list,faces_list = [],[] for obj_path in obj_paths: cur_verts, _, cur_faces = readObj(obj_path) cur_faces -= 1 verts_list.append(cur_verts) faces_list.append(cur_faces) all_verts = np.array(verts_list).astype(np.float32) all_faces = np.array(faces_list).astype(np.int64) verts_tensor = torch.tensor(all_verts.copy(), dtype=torch.float32, device=device) face_tensor = torch.tensor(all_faces, dtype=torch.long, device=device) param = torch.tensor(all_verts.copy(), dtype=torch.float32, device=device).requires_grad_() bs, nv = verts_tensor.shape[:2] bs, nf = face_tensor.shape[:2] faces_idx = face_tensor + \ (torch.arange(bs, dtype=torch.long).to(device) * nv)[:, None, None] # Create the search tree search_tree = BVH(max_collisions=max_collisions) pen_distance = \ collisions_loss.DistanceFieldPenetrationLoss(sigma=sigma, point2plane=point2plane, vectorized=True) mse_loss = nn.MSELoss(reduction='sum').to(device=device) optimizer = torch.optim.SGD([param], lr=lr) step = 0 for i in range(iterations): optimizer.zero_grad() triangles = param.view([-1, 3])[faces_idx] with torch.no_grad(): collision_idxs = search_tree(triangles) pen_loss = coll_loss_weight * \ pen_distance(triangles, collision_idxs) verts_reg_loss = torch.tensor(0, device=device, dtype=torch.float32) if verts_reg_weight > 0: verts_reg_loss = verts_reg_weight * \ mse_loss(param, verts_tensor) loss = pen_loss + verts_reg_loss np_loss = loss.detach().cpu().squeeze().tolist() if type(np_loss) != list: np_loss = [np_loss] msg = '{:.5f} ' * len(np_loss) print('Loss per model:', msg.format(*np_loss)) loss.backward(torch.ones_like(loss)) optimizer.step() step += 1 optimized_verts = param.detach().cpu().numpy() for i in range(optimized_verts.shape[0]):
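# readObj is called by main() above but not defined in this file; below is a
# hedged, minimal stand-in that matches how main() uses it. It returns
# (vertices, texcoords, faces) as numpy arrays, keeps faces 1-based so that
# the `cur_faces -= 1` in main() still applies, and only handles plain `v`
# lines and triangular `f` lines. The real helper may do more.
import numpy as np


def readObj(path):
    verts, faces = [], []
    with open(path, 'r') as f:
        for line in f:
            tokens = line.strip().split()
            if not tokens:
                continue
            if tokens[0] == 'v':
                verts.append([float(x) for x in tokens[1:4]])
            elif tokens[0] == 'f':
                # Keep only the vertex index of each `v/vt/vn` token.
                faces.append([int(t.split('/')[0]) for t in tokens[1:4]])
    return (np.array(verts, dtype=np.float32), None,
            np.array(faces, dtype=np.int64))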
def fit_single_frame(img, keypoints, body_model, camera, joint_weights,
                     body_pose_prior, jaw_prior, left_hand_prior,
                     right_hand_prior, shape_prior, expr_prior, angle_prior,
                     result_fn='out.pkl', mesh_fn='out.obj',
                     out_img_fn='overlay.png', loss_type='smplify',
                     use_cuda=True, init_joints_idxs=(9, 12, 2, 5),
                     use_face=True, use_hands=True,
                     data_weights=None, body_pose_prior_weights=None,
                     hand_pose_prior_weights=None, jaw_pose_prior_weights=None,
                     shape_weights=None, expr_weights=None,
                     hand_joints_weights=None, face_joints_weights=None,
                     depth_loss_weight=1e2, interpenetration=True,
                     coll_loss_weights=None, df_cone_height=0.5,
                     penalize_outside=True, max_collisions=8,
                     point2plane=False, part_segm_fn='',
                     focal_length=5000., side_view_thsh=25., rho=100,
                     vposer_latent_dim=32, vposer_ckpt='',
                     use_joints_conf=False, interactive=True, visualize=False,
                     save_meshes=True, degrees=None, batch_size=1,
                     dtype=torch.float32, ign_part_pairs=None,
                     left_shoulder_idx=2, right_shoulder_idx=5,
                     **kwargs):
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [1, ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
        ' does not match the number of data term weights {}'.format(
            len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (len(hand_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand joint distance weights')
        assert (len(hand_joints_weights) ==
                len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) ==
            len(body_pose_prior_weights)), msg.format(
                len(shape_weights), len(body_pose_prior_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (len(jaw_pose_prior_weights) ==
                len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) ==
                len(body_pose_prior_weights)), msg.format(
                    len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) ==
                len(body_pose_prior_weights)), msg
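    # Note on the weight lists above: entry k of every list is used in
    # optimization stage k, so the defaults define four annealing stages in
    # which the pose and shape regularizers are relaxed (e.g. the body pose
    # prior weight goes 404 -> 404 -> 57.4 -> 4.78) while the hand and face
    # joint terms are enabled only in the last stage (0, 0, 0, 1).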
    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [None, ] * 2
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype, device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    joints_conf = None  # placeholder when joint confidences are not used
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(1, -1)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm, faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {'data_weight': data_weights,
                        'body_pose_weight': body_pose_prior_weights,
                        'shape_weight': shape_weights}
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [dict(zip(keys, vals)) for vals in
                   zip(*(opt_weights_dict[k] for k in keys
                         if opt_weights_dict[k] is not None))]
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')
    init_t = fitting.guess_init(body_model, gt_joints, edge_indices,
                                use_vposer=use_vposer, vposer=vposer,
                                pose_embedding=pose_embedding,
                                model_type=kwargs.get('model_type', 'smpl'),
                                focal_length=focal_length, dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t
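    # How the collision modules set up above are typically combined per
    # optimization step (sketch in comments; the actual wiring lives inside
    # fitting.create_loss below):
    #     triangles = vertices.view([-1, 3])[faces_idx]
    #     with torch.no_grad():
    #         collision_idxs = search_tree(triangles)
    #     if filter_faces is not None:
    #         collision_idxs = filter_faces(collision_idxs)
    #     pen_loss = coll_loss_weight * pen_distance(triangles, collision_idxs)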
    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face, use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize, **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)
        H, W, _ = img.shape

        data_weight = 1000 / H
        camera_loss.reset_loss_weights({'data_weight': data_weight})

        # Reset the parameters to estimate the initial translation of the
        # body model
        body_model.reset_params(body_pose=body_mean_pose)

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold then try 2 fits, the initial one and a 180
        # degree rotation
        shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx],
                                   gt_joints[:, right_shoulder_idx])
        try_both_orient = shoulder_dist.item() < side_view_thsh

        # Update the value of the translation of the camera as well as
        # the image center.
        with torch.no_grad():
            camera.translation[:] = init_t.view_as(camera.translation)
            camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

        # Re-enable gradient calculation for the camera translation
        camera.translation.requires_grad = True

        camera_opt_params = [camera.translation, body_model.global_orient]

        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(
            camera_opt_params, **kwargs)

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer, body_model, camera, gt_joints,
            camera_loss, create_graph=camera_create_graph,
            use_vposer=use_vposer, vposer=vposer,
            pose_embedding=pose_embedding,
            return_full_pose=False, return_verts=False)

        # Step 1: Optimize the camera translation over the torso joints.
        # Initialize the computational graph by feeding the initial
        # translation of the camera and the initial pose of the body model.
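        # For reference (a generic sketch, not the library's exact code): a
        # fitting closure for torch.optim.LBFGS follows the standard pattern
        #     def closure():
        #         optimizer.zero_grad()
        #         loss_val = compute_loss(...)
        #         loss_val.backward(create_graph=camera_create_graph)
        #         return loss_val
        # and monitor.run_fitting below drives the optimizer with this
        # closure until convergence or an iteration limit is reached.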
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params, body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close, rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            body_orient = body_model.global_orient.detach().cpu().numpy()
            flipped_orient = cv2.Rodrigues(body_orient)[0].dot(
                cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
            flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel()

            flipped_orient = torch.tensor(flipped_orient,
                                          dtype=dtype,
                                          device=device).unsqueeze(dim=0)
            orientations = [body_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # Store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []

        # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations, desc='Orientation')):
            opt_start = time.time()

            new_params = defaultdict(global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc='Stage')):

                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = \
                    optim_factory.create_optimizer(final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['data_weight'] = data_weight
                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer, body_model,
                    camera=camera, gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss, create_graph=body_create_graph,
                    use_vposer=use_vposer, vposer=vposer,
                    pose_embedding=pose_embedding,
                    return_verts=True, return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()

                final_loss_val = monitor.run_fitting(
                    body_optimizer, closure, final_params, body_model,
                    pose_embedding=pose_embedding, vposer=vposer,
                    use_vposer=use_vposer)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    tqdm.write('Stage {:03d} done after {:.4f} seconds'.format(
                        opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'.format(
                        or_idx, elapsed))
                tqdm.write('Body final loss val = {:.5f}'.format(
                    final_loss_val))

            # Get the result of the fitting process and store it in the
            # results list so that multiple orientations can be compared,
            # if more than one was tried
            result = {'camera_' + str(key): val.detach().cpu().numpy()
                      for key, val in camera.named_parameters()}
            result.update({key: val.detach().cpu().numpy()
                           for key, val in body_model.named_parameters()})
            if use_vposer:
                result['body_pose'] = pose_embedding.detach().cpu().numpy()

            results.append({'loss': final_loss_val,
                            'result': result})

        with open(result_fn, 'wb') as result_file:
            if len(results) > 1:
                min_idx = (0 if results[0]['loss'] < results[1]['loss']
                           else 1)
            else:
                min_idx = 0
            pickle.dump(results[min_idx]['result'], result_file, protocol=2)

    if save_meshes or visualize:
        body_pose = vposer.decode(
            pose_embedding,
            output_type='aa').view(1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
            wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                     dtype=body_pose.dtype,
                                     device=body_pose.device)
            body_pose = torch.cat([body_pose, wrist_pose], dim=1)

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        rot = trimesh.transformations.rotation_matrix(
            np.radians(180), [1, 0, 0])
        out_mesh.apply_transform(rot)
        out_mesh.export(mesh_fn)

    if visualize:
        import pyrender

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(1.0, 1.0, 0.9, 1.0))
        mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(mesh, 'mesh')

        camera_center = camera.center.detach().cpu().numpy().squeeze()
        camera_transl = camera.translation.detach().cpu().numpy().squeeze()
        # Equivalent to 180 degrees around the y-axis. Transforms the fit to
        # OpenGL compatible coordinate system.
        camera_transl[0] *= -1.0

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_transl

        camera = pyrender.camera.IntrinsicsCamera(
            fx=focal_length, fy=focal_length,
            cx=camera_center[0], cy=camera_center[1])
        scene.add(camera, pose=camera_pose)

        # Get the lights from the viewer
        light_nodes = monitor.mv.viewer._create_raymond_lights()
        for node in light_nodes:
            scene.add_node(node)

        r = pyrender.OffscreenRenderer(viewport_width=W,
                                       viewport_height=H,
                                       point_size=1.0)
        color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0

        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        input_img = img.detach().cpu().numpy()
        output_img = (color[:, :, :-1] * valid_mask +
                      (1 - valid_mask) * input_img)

        img = pil_img.fromarray((output_img * 255).astype(np.uint8))
        img.save(out_img_fn)
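# Sketch: reloading the artifacts written by fit_single_frame for later
# inspection; the default file names ('out.pkl', 'out.obj') come from the
# function arguments above.
def load_fitting_outputs(result_fn='out.pkl', mesh_fn='out.obj'):
    import pickle
    import trimesh
    with open(result_fn, 'rb') as f:
        params = pickle.load(f)  # camera and body model parameters as numpy arrays
    mesh = trimesh.load(mesh_fn, process=False)
    return params, mesh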