def forward(self, ref_points, points, *args, **kwargs):
    """Compare per-point patch conditioning of two corresponding point clouds.

    Neighbourhoods are defined on ``ref_points``: the ``self.nn_size`` nearest
    neighbours of every point (via ``faiss_knn``), of which only those whose
    returned distance is below ``self.ball_size2`` are kept (an approximate
    ball query). The same neighbour indices are used to build patches from
    ``points``. Each patch is centred on the mean of its in-ball points and
    decomposed with ``batch_svd``; the per-point score is
    ``S[:, 0] / (S[:, 0] + S[:, -1])`` (first and last singular values).

    Args:
        ref_points: (B, N, C) reference point cloud; defines neighbourhoods.
        points: (B, N', C) point cloud indexed with the neighbour indices of
            ``ref_points`` (presumably in one-to-one correspondence with it).
        *args, **kwargs: ignored.

    Returns:
        ``self.metric(cond, ref_cond)`` where both scores are (B, N) tensors.
    """
    B, N, C = ref_points.shape
    # TODO replace with ball query
    # (B,N,K,C), (B,N,K), (B,N,K)
    ref_grouped_points, ref_group_idx, ref_group_dist = faiss_knn(
        self.nn_size, ref_points, ref_points, NCHW=False)
    # (B,N,K,1): True for neighbours that fall outside the ball
    outside = ~(ref_group_dist < self.ball_size2).unsqueeze(-1)
    # number of points inside the ball (B,N,1,1)
    nball = torch.sum((~outside).to(torch.float), dim=2, keepdim=True)

    def patch_score(grouped):
        """Return the (B, N) singular-value ratio of (B,N,K,C) patches."""
        grouped = grouped.masked_fill(outside, 0.0)
        center = torch.sum(grouped, dim=2, keepdim=True) / nball
        # Zero the excluded neighbours again after centring: zero rows leave
        # the singular values unchanged, whereas leaving them at -center
        # would inject a spurious point into every partially-filled patch.
        centered = (grouped - center).masked_fill(outside, 0.0)
        # S (BN, k)
        _, S, _ = batch_svd(centered.reshape(-1, self.nn_size, C).contiguous())
        # clamp keeps all-zero (single-point) patches at 0 instead of NaN
        score = S[:, 0] / (S[:, -1] + S[:, 0]).clamp_min(1e-12)
        return score.view(B, N).contiguous()

    ref_cond = patch_score(ref_grouped_points)
    # patches of `points` gathered with the reference neighbour indices
    grouped_points = torch.gather(
        points.unsqueeze(1).expand(-1, N, -1, -1), 2,
        ref_group_idx.unsqueeze(-1).expand(-1, -1, -1, C))
    cond = patch_score(grouped_points)
    return self.metric(cond, ref_cond)
def forward(self, cage_v, cage_f, shape, shape_vn, epsilon=0.01, interpolate=True):
    """Penalize points sampled on the cage faces that lie inside ``shape``.

    Points are sampled on every cage triangle with the barycentric weights in
    ``self.sample_weights``. Each sample q is matched to its nearest shape
    vertex p with normal n, and is penalized by ``-<q - p - epsilon*n, n>``
    whenever that value is positive (the sample is behind the offset normal).

    Args:
        cage_v: (B, M, D) cage vertices.
        cage_f: (B, F, 3) triangle vertex indices into ``cage_v``.
        shape: (B, N, D) shape vertices.
        shape_vn: (B, N, D) per-vertex normals of ``shape``.
        epsilon: margin applied along the normal before the inside test.
        interpolate: unused; kept for interface compatibility.

    Returns:
        Loss reduced according to ``self.reduction``: ``"mean"`` (mean over
        everything), ``"max"`` / ``"sum"`` (per-batch max / sum over the
        samples, then averaged over the batch) or ``"none"`` (the per-sample
        loss, presumably (B, F*S)).

    Raises:
        NotImplementedError: if ``self.reduction`` is not one of the above.
    """
    B, _, D = cage_v.shape
    F = cage_f.shape[1]
    # keep the module's sample weights on the shape's device
    self.sample_weights = self.sample_weights.to(device=shape.device)

    # (B,F,1,3,D) corner positions of every cage triangle
    cage_face_vertices = torch.gather(
        cage_v, 1, cage_f.reshape(B, F * 3, 1).expand(-1, -1, D)).reshape(B, F, 1, 3, D)
    # (1,1,S,3,1) so the weights broadcast against the triangle corners
    sample_weights = self.sample_weights.unsqueeze(0).unsqueeze(0).unsqueeze(-1).to(device=cage_v.device)
    # (B,F*S,D) barycentric samples on the cage faces
    cage_sampled_points = torch.sum(sample_weights * cage_face_vertices, dim=-2).reshape(B, -1, D)

    # find the closest point on the shape: (B,F*S,1,D), (B,F*S,1)
    nn_point, nn_index, _ = faiss_knn(1, cage_sampled_points, shape, NCHW=False)
    nn_point = nn_point.squeeze(2)
    # (B,F*S,D) normal of the nearest shape vertex
    nn_normal = torch.gather(shape_vn, 1, nn_index.expand(-1, -1, shape_vn.shape[-1]))

    # if <(q-p), n> is negative, then this point is inside the shape, gradient is along the normal direction
    dot = dot_product(cage_sampled_points - nn_point - epsilon * nn_normal, nn_normal, dim=-1)
    loss = torch.where(dot < 0, -dot, torch.zeros_like(dot))
    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "max":
        return torch.mean(torch.max(loss, dim=-1)[0])
    if self.reduction == "sum":
        # per-batch sum over the samples, averaged over the batch
        return torch.mean(torch.sum(loss, dim=-1))
    if self.reduction == "none":
        return loss
    raise NotImplementedError
def forward(self, cage, shape, shape_normals, epsilon=0.01, interpolate=True):
    """
    Penalize polygon cage that is inside the given shape

    The closed polygon is densified with 10 evenly spaced samples per edge
    (from vertex i to vertex (i+1) % M, both end points included). Each sample
    q is matched to its nearest shape vertex p with normal n and penalized by
    ``-<q - p - epsilon*n, n>`` whenever that value is positive.

    Args:
        cage: (B,M,D) polygon vertices, in order
        shape: (B,N,D)
        shape_normals: (B,N,D)
        epsilon: margin applied along the normal before the inside test
        interpolate: unused; kept for interface compatibility

    Returns:
        Loss reduced according to ``self.reduction``: ``"mean"``, ``"max"`` /
        ``"sum"`` (per-batch max / sum over the samples, averaged over the
        batch) or ``"none"`` (the per-sample loss).

    Raises:
        NotImplementedError: if ``self.reduction`` is not one of the above.
    """
    B, _, D = cage.shape
    interpolate_n = 10
    # next vertex along the closed polygon: cage_next[:, i] == cage[:, (i+1) % M]
    cage_next = torch.roll(cage, shifts=-1, dims=1)
    # (1,1,K,1) interpolation parameters in [0, 1]
    t = torch.linspace(0, 1, interpolate_n, device=cage.device).reshape(1, 1, interpolate_n, 1)
    # (B,M,K,D) -> (B,M*K,D) points along every edge
    cage_itp = (t * cage_next.unsqueeze(2) + (1 - t) * cage.unsqueeze(2)).reshape(B, -1, D)

    # find the closest point on the shape: (B,M*K,1,D), (B,M*K,1)
    nn_point, nn_index, _ = faiss_knn(1, cage_itp, shape, NCHW=False)
    nn_point = nn_point.squeeze(2)
    # (B,M*K,D) normal of the nearest shape vertex
    nn_normal = torch.gather(shape_normals, 1, nn_index.expand(-1, -1, shape_normals.shape[-1]))

    # if <(q-p), n> is negative, then this point is inside the shape, gradient is along the normal direction
    dot = dot_product(cage_itp - nn_point - epsilon * nn_normal, nn_normal, dim=-1)
    loss = torch.where(dot < 0, -dot, torch.zeros_like(dot))
    if self.reduction == "mean":
        return loss.mean()
    if self.reduction == "max":
        return torch.mean(torch.max(loss, dim=-1)[0])
    if self.reduction == "sum":
        # per-batch sum over the samples, averaged over the batch
        return torch.mean(torch.sum(loss, dim=-1))
    if self.reduction == "none":
        return loss
    raise NotImplementedError
target_shape.unsqueeze_(0) orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None) orig_label_name = orig_label.iloc[:, 5] source_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype( np.float32)) source_points = source_points.unsqueeze(0) # find the closest point on the original meshes source_mesh = om.read_polymesh(source_model) # source_mesh = om.read_trimesh(source_model) source_shape_arr = source_mesh.points() source_shape = source_shape_arr.copy() source_shape = torch.from_numpy(source_shape[None, :, :3]).float() _, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False) target_points = torch.gather( target_shape.unsqueeze(1).expand(-1, source_points.shape[1], -1, -1), 2, idx.unsqueeze(-1).expand(-1, -1, -1, 3)) # save to pd again orig_label[9] = idx.squeeze(0).squeeze(-1) ncol = orig_label.shape[1] orig_label.to_csv(orig_label_path, sep=" ", header=[str(orig_label.shape[0])] + [""] * (ncol - 1), index=False) orig_label.iloc[:, 6:9] = target_points.squeeze().numpy() orig_label.to_csv(new_lable, sep=" ", header=[str(orig_label.shape[0])] + [""] * (ncol - 1),
def optimize(opt):
    """Optimize the cage so the target points get the source points' MVC weights.

    The template cage and template shape are loaded from ``opt.ckpt``; the new
    target mesh is read from ``opt.model``. Pairs of corresponding points
    (``source_points`` / ``target_points``) come from one of three sources:

    1. Both ``opt.model`` and ``opt.source_model`` have a ``.picked`` label
       file: labels are matched by name (column 5, only names occurring
       exactly once in the source). Column 9, when present, holds the vertex
       index; otherwise the picked positions (columns 6-8) are snapped to the
       nearest mesh vertex and the index is written back as column 9.
    2. Only the source has a ``.picked`` file: the target is assumed to share
       the source's vertex order, so the same vertex indices are used.
    3. Neither, ``opt.corres_idx`` is None and the target has as many vertices
       as the template shape: up to 4800 random shared vertex indices.

    The target is then centred and rescaled with ``center_bounding_box``,
    written to ``<opt.log_dir>/<opt.subdir>/<name>_normalized.obj`` (and
    ``opt.model`` is redirected to it). The cage vertices are optimized with
    Adam so that ``mean_value_coordinates_3D`` of the target points w.r.t. the
    current cage match those of the source points w.r.t. the initial cage,
    plus optional ``MeshLaplacianLoss`` (``opt.clap_weight``) and
    ``MVCRegularizer`` (``opt.mvc_weight``) terms.

    Args:
        opt: options namespace; reads ``is_poly``, ``model``, ``source_model``,
            ``ckpt``, ``corres_idx``, ``log_dir``, ``subdir``, ``lr``,
            ``nepochs``, ``dim``, ``clap_weight`` and ``mvc_weight``, and
            overwrites ``model``.

    Returns:
        ``(cage_v, cage_f)``: optimized cage vertices (``requires_grad`` set)
        and cage faces, both on CUDA.

    Raises:
        ValueError: if no correspondences can be established.
        NotImplementedError: if ``opt.dim != 3``.
    """
    # load new target
    if opt.is_poly:
        target_mesh = om.read_polymesh(opt.model)
    else:
        target_mesh = om.read_trimesh(opt.model)
    target_shape_arr = target_mesh.points()
    target_shape = target_shape_arr.copy()
    target_shape = torch.from_numpy(target_shape[:, :3].astype(np.float32)).cuda()
    target_shape.unsqueeze_(0)

    states = torch.load(opt.ckpt)
    if "states" in states:
        states = states["states"]
    cage_v = states["template_vertices"].transpose(1, 2).cuda()
    cage_f = states["template_faces"].cuda()
    shape_v = states["source_vertices"].transpose(1, 2).cuda()
    shape_f = states["source_faces"].cuda()

    new_label_path = opt.model.replace(os.path.splitext(opt.model)[1], ".picked")
    orig_label_path = opt.source_model.replace(os.path.splitext(opt.source_model)[1], ".picked")
    has_new_labels = os.path.isfile(new_label_path)
    has_orig_labels = os.path.isfile(orig_label_path)

    if has_new_labels and has_orig_labels:
        logger.info("Loading picked labels {} and {}".format(orig_label_path, new_label_path))
        import pandas as pd
        new_label = pd.read_csv(new_label_path, delimiter=" ", skiprows=1, header=None)
        orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None)
        orig_label_name = orig_label.iloc[:, 5]
        new_label_name = new_label.iloc[:, 5].tolist()
        # (new_row, orig_row) pairs for names that occur exactly once in the source
        new_to_orig_idx = []
        for i, name in enumerate(new_label_name):
            matched_idx = orig_label_name[orig_label_name == name].index
            if matched_idx.size == 1:
                new_to_orig_idx.append((i, matched_idx[0]))
        new_to_orig_idx = np.array(new_to_orig_idx)
        if new_to_orig_idx.size == 0:
            raise ValueError("No picked label names shared by {} and {}".format(
                orig_label_path, new_label_path))

        if new_label.shape[1] == 10:
            # column 9 already stores the target vertex index of every label
            new_vidx = new_label.iloc[:, 9].to_numpy()[new_to_orig_idx[:, 0]]
            target_points = target_shape[:, new_vidx, :]
        else:
            # snap picked positions to the closest target vertex and cache the
            # vertex index as column 9 of the .picked file
            new_label_points = torch.from_numpy(new_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            target_points = new_label_points.unsqueeze(0).cuda()
            target_points, new_vidx, _ = faiss_knn(1, target_points, target_shape, NCHW=False)
            target_points = target_points.squeeze(2)  # B,N,3
            new_label[9] = new_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            new_label.to_csv(new_label_path, sep=" ",
                             header=[str(new_label.shape[0])] + [""] * (new_label.shape[1] - 1),
                             index=False)
            target_points = target_points[:, new_to_orig_idx[:, 0], :]
        target_points = target_points.cuda()

        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :, :3]).float()
        if orig_label.shape[1] == 10:
            orig_vidx = orig_label.iloc[:, 9].to_numpy()[new_to_orig_idx[:, 1]]
            source_points = source_shape[:, orig_vidx, :]
        else:
            orig_label_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            source_points = orig_label_points.unsqueeze(0)
            # find the closest point on the original meshes
            source_points, orig_vidx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2)  # B,N,3
            orig_label[9] = orig_vidx.squeeze(0).squeeze(-1).cpu().numpy()
            orig_label.to_csv(orig_label_path, sep=" ",
                              header=[str(orig_label.shape[0])] + [""] * (orig_label.shape[1] - 1),
                              index=False)
            source_points = source_points[:, new_to_orig_idx[:, 1], :]
        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
        source_points = source_points.cuda()
    elif has_orig_labels and not has_new_labels:
        logger.info("Assuming Faust model")
        logger.info("Loading picked labels {}".format(orig_label_path))
        import pandas as pd
        orig_label = pd.read_csv(orig_label_path, delimiter=" ", skiprows=1, header=None)
        source_shape, _ = read_trimesh(opt.source_model)
        source_shape = torch.from_numpy(source_shape[None, :, :3]).cuda().float()
        if orig_label.shape[1] == 10:
            idx = torch.from_numpy(orig_label.iloc[:, 9].to_numpy()).long()
            source_points = source_shape[:, idx, :]
        else:
            source_points = torch.from_numpy(orig_label.iloc[:, 6:9].to_numpy().astype(np.float32))
            source_points = source_points.unsqueeze(0).cuda()
            # find the closest point on the original meshes
            source_points, idx, _ = faiss_knn(1, source_points, source_shape, NCHW=False)
            source_points = source_points.squeeze(2)  # B,N,3
            # (1,N,1) -> (N,): a 2-D index would give target_points an extra dim
            idx = idx[0, :, 0]
        # registered meshes share vertex order, so the same indices apply
        target_points = target_shape[:, idx, :]
        _, source_center, _ = center_bounding_box(source_shape[0])
        source_points -= source_center
    elif opt.corres_idx is None and target_shape.shape[1] == shape_v.shape[1]:
        logger.info("No correspondence provided, assuming registered Faust models")
        corresp_v = torch.unique(torch.randint(0, shape_v.shape[1], (4800,))).cuda()
        target_points = torch.index_select(target_shape, 1, corresp_v)
        source_points = torch.index_select(shape_v, 1, corresp_v)
    else:
        raise ValueError("Cannot establish correspondences between {} and {}".format(
            opt.source_model, opt.model))

    # normalize the target: centre it and rescale by the ratio (axis 1) of the
    # scales returned by center_bounding_box for the template and the target
    target_shape[0], target_center, target_scale = center_bounding_box(target_shape[0])
    _, _, source_scale = center_bounding_box(shape_v[0])
    target_scale_factor = (source_scale / target_scale)[1]
    target_shape *= target_scale_factor
    target_points -= target_center
    target_points = (target_points * target_scale_factor).detach()
    # make sure test use the normalized
    target_shape_arr[:] = target_shape[0].cpu().numpy()
    normalized_path = os.path.join(opt.log_dir, opt.subdir, os.path.splitext(
        os.path.basename(opt.model))[0] + "_normalized.obj")
    om.write_mesh(normalized_path, target_mesh)
    opt.model = normalized_path

    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "template-initial.obj"),
                         shape_v[0].cpu().numpy(), shape_f[0].cpu().numpy())
    pymesh.save_mesh_raw(os.path.join(opt.log_dir, opt.subdir, "cage-initial.obj"),
                         cage_v[0].cpu().numpy(), cage_f[0].cpu().numpy())
    save_ply(target_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "target_points.ply"))
    save_ply(source_points[0].cpu().numpy(), os.path.join(
        opt.log_dir, opt.subdir, "source_points.ply"))
    logger.info("Optimizing for {} corresponding vertices".format(target_points.shape[1]))

    cage_init = cage_v.clone().detach()
    lap_loss = MeshLaplacianLoss(torch.nn.MSELoss(reduction="none"), use_cot=True, use_norm=True,
                                 consistent_topology=True, precompute_L=True)
    mvc_reg_loss = MVCRegularizer(threshold=50, beta=1.0, alpha=0.0)
    cage_v.requires_grad_(True)
    optimizer = torch.optim.Adam([cage_v], lr=opt.lr, betas=(0.5, 0.9))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(opt.nepochs * 0.4), gamma=0.5, last_epoch=-1)

    if opt.dim == 3:
        # reference weights: source points w.r.t. the initial cage
        weights_ref = mean_value_coordinates_3D(source_points, cage_init, cage_f, verbose=False)
    else:
        raise NotImplementedError

    for t in range(opt.nepochs):
        optimizer.zero_grad()
        weights = mean_value_coordinates_3D(target_points, cage_v, cage_f, verbose=False)
        loss_mvc = torch.mean((weights - weights_ref) ** 2)
        reg = 0
        if opt.clap_weight > 0:
            reg = lap_loss(cage_init, cage_v, face=cage_f) * opt.clap_weight
            reg = reg.mean()
        if opt.mvc_weight > 0:
            reg += mvc_reg_loss(weights) * opt.mvc_weight
        loss = loss_mvc + reg
        if (t + 1) % 50 == 0:
            # float() also covers reg == 0 when both regularizers are disabled
            print("t {}/{} mvc_loss: {} reg: {}".format(t, opt.nepochs, loss_mvc.item(), float(reg)))
        if loss_mvc.item() < 5e-6:
            break
        loss.backward()
        optimizer.step()
        scheduler.step()
    return cage_v, cage_f
def __init__(self, nCam, offset, focalLength, device=None, points=None, normals=None, camWidth=256, camHeight=256,
             filename="../example_data/pointclouds/sphere_300.ply", closer=True):
    """
    Create look-at camera poses placed around a shape.

    Candidate camera positions come from one of two sources:
      * ``filename`` is not None: points read with ``read_ply(filename, nCam)``
        (xyz columns only), with a batch dimension added.
      * otherwise: ``nCam`` points picked from ``points`` by
        ``operations.furthest_point_sample``; if ``normals`` is also given,
        each pick gets the mean normal of its 100 nearest input points.

    The positions are multiplied by ``offset`` and shifted by a look-at
    center. When ``points`` is given, ``offset`` is first enlarged by the
    per-axis bounding-box extent of ``points``. The center is chosen as:
      * ``points`` and ``normals`` given: the candidate positions themselves,
        and the positions are replaced by ``normals`` plus their mean normal,
        jittered with Gaussian noise of std 0.01;
      * only ``points`` given: the centroid of ``points``;
      * no ``points``: the origin.
    Every camera looks at its center (``self.to``) with up vector +z (tiny
    jitter added); ``batchLookAt`` produces ``self.rotation`` and
    ``self.position``.

    Args:
        nCam: total number of cameras.
        offset: distance scale applied to the positions (a number, or anything
            broadcastable with the (B,1,D) bounding-box extent).
        focalLength: a number, stored on the instance.
        device: target device; defaults to ``points.device`` or, without
            points, ``torch.cuda.current_device()``.
        points: optional (B,N,3or4) point cloud; a 2-D tensor gets a batch dim
            (only after the sampling step below).
        normals: optional normals of ``points`` (presumably (B,N,3) -- TODO
            confirm).
        camWidth, camHeight: image size, stored on the instance.
        filename: point-cloud file with candidate camera positions; None to
            sample them from ``points`` instead.
        closer: stored as ``self.closer``; not used in this method.

    Sets: device, closer, allPositions (B,C,3), to, ups, rotation, position,
    idx (0), length (``rotation.shape[1]``), focalLength, camWidth, camHeight.
    """
    # device: explicit argument, else the points' device, else current CUDA device
    if device is None:
        if points is not None:
            self.device = points.device
        else:
            self.device = torch.cuda.current_device()
    else:
        self.device = device
    # stored only; not read anywhere in this method
    self.closer = closer
    if filename is not None:
        # candidate positions read from file; keep xyz and add a batch dim
        self.allPositions = torch.from_numpy(read_ply(
            filename, nCam)).to(device=self.device)[:, :3]
        self.allPositions = self.allPositions.unsqueeze(0)
    else:
        # candidate positions sampled from the input points themselves
        sampleIdx, self.allPositions = operations.furthest_point_sample(
            points.cuda(), nCam, NCHW=False)
        self.allPositions = self.allPositions.to(self.device)
        if normals is not None:
            # smooth the normal at each sample: average the normals of its 100
            # nearest input points (neighbour search runs on CPU tensors)
            _, idx, _ = operations.faiss_knn(100, self.allPositions.cpu(), points.cpu(), NCHW=False)
            knn_normals = torch.gather(
                normals.unsqueeze(1).expand(-1, self.allPositions.shape[1], -1, -1), 2,
                idx.unsqueeze(-1).expand(-1, -1, -1, normals.shape[-1]))
            normals = torch.mean(knn_normals, dim=2).to(self.device)
    if points is not None:
        if points.dim() == 2:
            points = points.unsqueeze(0)
        # per-axis bounding-box extent (B,1,D), added to the camera distance
        maxP = torch.max(points, dim=1, keepdim=True)[0]
        minP = torch.min(points, dim=1, keepdim=True)[0]
        bb = maxP - minP
        offset = offset + bb
        if normals is not None:
            # look at the candidate positions; move the cameras along the
            # normals (biased by the mean normal) with a small random jitter
            center = self.allPositions
            # self.allPositions = (torch.mean(normals, dim=1, keepdim=True))
            self.allPositions = normals + (torch.mean(
                normals, dim=1, keepdim=True))
            self.allPositions += torch.randn_like(self.allPositions) * 0.01
        else:
            # look at the centroid of the points
            center = torch.mean(points, dim=1, keepdim=True)
    else:
        # no geometry given: look at the origin
        center = torch.zeros([1, 1, 3], dtype=self.allPositions.dtype, device=self.allPositions.device)
    # scale the directions by the distance, then place them around the center
    self.allPositions = self.allPositions * offset
    self.allPositions = center + self.allPositions  # center: Bx1x3 (BxCx3 when it is the candidate positions)
    # one look-at target per camera
    self.to = center.expand_as(self.allPositions)  # BxNx3
    # alternative +y up vector:
    # self.ups = torch.tensor([0, 1, 0], dtype=self.to.dtype, device=self.to.device).view(1, 1, 3).expand_as(self.allPositions)
    # for sketchfab
    self.ups = torch.tensor([0, 0, 1], dtype=self.to.dtype, device=self.to.device).view(1, 1, 3).expand_as(
        self.allPositions)
    # tiny jitter on the up vectors -- presumably to avoid a degenerate
    # look-at when a viewing direction is parallel to +z (TODO confirm)
    self.ups = self.ups + torch.randn_like(self.ups) * 0.0001
    self.rotation, self.position = batchLookAt(self.allPositions, self.to, self.ups)
    self.idx = 0
    self.length = self.rotation.shape[1]
    self.focalLength = focalLength
    self.camWidth = camWidth
    self.camHeight = camHeight