def viewFromError(nCam, gtImage, predImage, predPoints, projPoints, splatter, offset=None):
    """Create ``nCam`` cameras that all look at the region with the largest image error.

    The per-pixel absolute difference between ``gtImage`` and ``predImage`` is
    summed over the last axis, box-filtered (kernel 9, stride 4, padding 4) and
    the location of the maximum filtered error is mapped back to pixel
    coordinates ``(w, h)``.  The 5 entries of ``projPoints`` nearest to
    ``[w, h, 1]`` are found with ``operations.group_knn``; the mean of the
    matching ``predPoints`` becomes the look-at target.  Camera positions are
    the (first three columns of the) points returned by
    ``read_ply("example_data/pointclouds/sphere_300.ply", nCam)`` scaled by
    ``offset``.

    Args:
        nCam: number of cameras to create; also forwarded to ``read_ply``
            (presumably the number of sphere points to read — TODO confirm).
        gtImage, predImage: images compared per pixel; the last axis is summed.
        predPoints: points whose KNN subset is averaged into the look-at center.
        projPoints: projections of ``predPoints`` (compared against
            ``[w, h, 1]``, so presumably pixel coordinates plus a 1 — confirm).
        splatter: object whose ``camera`` provides device and intrinsics.
        offset: scale applied to the sphere positions; when ``None`` it
            defaults to half of ``splatter.camera.focalLength``.

    Returns:
        ``(max_filtered_error, cameras)`` with ``cameras`` a list of
        ``PinholeCamera`` objects sharing the intrinsics of ``splatter.camera``.
    """
    camera = splatter.camera
    device = camera.device
    focalLength = camera.focalLength
    width = camera.width
    height = camera.height
    sv = camera.sv

    allPositions = torch.from_numpy(
        read_ply("example_data/pointclouds/sphere_300.ply", nCam)
    ).to(device=device)[:, :3].unsqueeze(0)
    # `offset or default` would silently replace an explicit 0 and raises for
    # multi-element tensors; only fall back when nothing was passed.
    if offset is None:
        offset = focalLength * 0.5
    fromP = allPositions * offset

    # Per-pixel L1 error, box-filtered so a single noisy pixel does not win.
    poolStride = 4
    diff = torch.sum((gtImage - predImage).abs(), dim=-1)
    diff = torch.nn.functional.avg_pool2d(
        diff.unsqueeze(0), 9, stride=poolStride, padding=4,
        ceil_mode=False, count_include_pad=False).squeeze(0)

    # argmax() indexes the flattened, row-major tensor: the column is the
    # remainder w.r.t. the LAST axis (width), not w.r.t. shape[0].  The extra
    # modulo by the height keeps the row inside the map if a leading batch
    # axis is present.
    flatIdx = int(diff.argmax())
    mapWidth = diff.shape[-1]
    mapHeight = diff.shape[-2]
    w = (flatIdx % mapWidth) * poolStride
    h = ((flatIdx // mapWidth) % mapHeight) * poolStride

    # B, 1, K: indices of the projected points closest to the selected pixel.
    query = torch.tensor([w, h, 1], dtype=projPoints.dtype, device=projPoints.device)
    query = query.view(1, 1, 3).expand(projPoints.shape[0], -1, -1)
    _, knn_idx, _ = operations.group_knn(5, query, projPoints, unique=False, NCHW=False)
    # B, 1, K, C: the corresponding points; their mean is the look-at target.
    knn_points = torch.gather(
        predPoints.unsqueeze(1), 2,
        knn_idx.unsqueeze(-1).expand(-1, -1, -1, predPoints.shape[-1]))
    center = torch.mean(knn_points, dim=-2).to(device=device)

    # Up vector +z with a tiny random jitter (presumably to avoid a degenerate
    # look-at when the viewing direction is parallel to +z).
    ups = torch.tensor([0, 0, 1], dtype=center.dtype, device=device).view(1, 1, 3).expand_as(fromP)
    ups = ups + torch.randn_like(ups) * 0.0001
    rotation, position = batchLookAt(fromP, center, ups)

    cameras = []
    for i in range(nCam):
        cam = PinholeCamera(device=device, focalLength=focalLength, width=width, height=height, sv=sv)
        cam.rotation = rotation[:, i, :, :]
        cam.position = position[:, i, :]
        cameras.append(cam)
    return diff.max(), cameras
def _computeDensity(points, knn_k=33, radius=0.1):
    """Estimate a per-point density as a sum of truncated Gaussian weights.

    For each point the ``knn_k`` nearest neighbors within ``points`` are
    searched (``operations.group_knn`` on CUDA, ``operations.faiss_knn``
    otherwise).  The first hit is dropped (assumed to be the point itself).
    Each remaining neighbor contributes ``exp(-d2 / radius**2 / 4)``, where
    ``d2`` is the squared distance reported by the KNN search; neighbors with
    ``d2 > radius**2`` contribute 0.

    Args:
        points: point cloud used as both query and reference set
            (presumably ``(B, N, 3)`` — TODO confirm).
        knn_k: number of neighbors requested, including the point itself.
        radius: ball radius; also sets the Gaussian bandwidth.

    Returns:
        The summed weights, one value per point (shape of ``distance2``
        without its last axis, presumably ``(B, N)``).
    """
    radius2 = radius * radius
    # Only the distances are needed; neighbor coordinates/indices are unused.
    if points.is_cuda:
        _, _, distance2 = operations.group_knn(knn_k, points, points, unique=False, NCHW=False)
    else:
        _, _, distance2 = operations.faiss_knn(knn_k, points, points, NCHW=False)
    # Drop the first neighbor: the query point itself.
    distance2 = distance2[:, :, 1:]

    weight = torch.exp(-distance2 / radius2 / 4)
    # Ball query: neighbors outside the radius do not count.
    weight = torch.where(distance2 > radius2, torch.zeros_like(weight), weight)
    return torch.sum(weight, dim=-1)
def applyAverageTerm(self, points_data, normals_data, original_points, idxList=None, original_density=None):
    """Pull points, along their normals, towards a weighted average of nearby original points.

    For each point (or each point selected by ``idxList``) the 16 nearest
    ``original_points`` are weighted by ``exp(-d2 / r**2 / 4)`` with
    ``r = self.repulsion_radius``.  The weight is set to 0 for ``d2 > r**2``
    and optionally multiplied by ``original_density``.  The normalized weights
    average the neighbors.  The offset to that average, projected onto the
    point's normal and scaled by ``self.average_weight``, is added to
    ``points_data`` IN PLACE.  Points without any supporting neighbor are left
    unchanged.

    points B,N,3
    original_points B,N,3
    original_density B,N,1

    Args:
        points_data: points to update in place.
        normals_data: per-point normals aligned with ``points_data``.
        original_points: reference points the average is taken over.
        idxList: optional index tensor selecting which points to update
            (expanded along the last axis, so presumably ``(B, M, 1)``).
        original_density: optional per-original-point weights; a trailing
            singleton axis is squeezed.

    Returns:
        None; ``points_data`` is modified in place.
    """
    points = points_data
    # Bind normals for BOTH paths; previously it only existed when idxList
    # was given, so idxList=None raised NameError below.
    normals = normals_data
    if idxList is not None:
        points = torch.gather(points_data, 1, idxList.expand(-1, -1, points_data.shape[-1]))
        normals = torch.gather(normals_data, 1, idxList.expand(-1, -1, normals_data.shape[-1]))
    PN = points.shape[1]
    knn_k = 16
    if points.is_cuda:
        knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, original_points, unique=False, NCHW=False)
    else:
        knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, original_points, NCHW=False)

    radius2 = self.repulsion_radius * self.repulsion_radius
    outside = distance2 > radius2
    # Ball query: neighbors outside the radius do not contribute.
    knn_points = torch.where(outside.unsqueeze(-1), torch.zeros_like(knn_points), knn_points)
    weight = torch.exp(-distance2 / radius2 / 4)
    weight = torch.where(outside, torch.zeros_like(weight), weight)

    # Optional density term from the original point cloud.
    if original_density is not None:
        if original_density.dim() == 3:
            original_density = original_density.squeeze(-1)
        original_density_weight = torch.gather(original_density.unsqueeze(1).expand(-1, PN, -1), 2, knn_idx)
        original_density_weight = torch.where(outside, torch.zeros_like(original_density_weight), original_density_weight)
        weight = weight * original_density_weight

    rawWeightSum = torch.sum(weight, dim=-1, keepdim=True)
    # A point with no supporting neighbor would otherwise get an average of 0
    # and be dragged towards the plane through the origin.
    hasSupport = rawWeightSum > 0
    weight = weight / (rawWeightSum + 1e-8)

    # Weighted average of the neighbors.
    originalAverage = torch.sum(knn_points * weight.unsqueeze(-1), dim=-2)
    # Keep only the component along the normal.
    update = dot(originalAverage - points, normals, dim=-1).unsqueeze(-1) * normals * self.average_weight
    update = torch.where(hasSupport, update, torch.zeros_like(update))

    if idxList is not None:
        points_data.scatter_add_(1, idxList.expand(-1, -1, points_data.shape[-1]), update)
        return
    points += update
def forward(self, xyz1, xyz2, **kwargs):
    """Loss on the change of local point-to-plane offsets between two point sets.

    Neighborhoods of ``self.nn_size`` points are found once on ``xyz1`` with
    ``group_knn`` and reused (by index) for ``xyz2``.  For each neighborhood
    the neighbors are centered and passed to ``batch_svd``.  ``V[:, :, -1]``
    (detached) is used as the normal, and the signed offset of the point from
    the neighborhood centroid along that normal is computed.

    The result is ``self.metric(|off1|, |off2|)`` plus
    ``5 * self.metric(max(off2 - off1, 0), 0)``.  The second term penalizes
    "flat -> curved" changes.

    NOTE(review): the two normals are fitted independently and their signs are
    not aligned, so the signed difference ``off2 - off1`` inherits that
    ambiguity (the original code carries a FIXME about it).

    Args:
        xyz1: reference points, unpacked as ``(B, N, C)``.
        xyz2: second point set, indexed with the neighbor ids found on ``xyz1``.
        **kwargs: accepted and ignored.

    Returns:
        The combined loss value.
    """
    xyz1 = xyz1.contiguous()
    xyz2 = xyz2.contiguous()
    batch, npoints, dims = xyz1.shape
    k = self.nn_size

    # Reference set: neighborhoods, fitted plane, signed offset to the plane.
    neighbors, nn_idx, _ = group_knn(k, xyz1, xyz1, unique=True, NCHW=False)
    centroid = torch.mean(neighbors, dim=2, keepdim=True)
    # (B*N, k, C) centered neighborhoods
    _, _, basis = batch_svd((neighbors - centroid).view(-1, k, dims).contiguous())
    normal = basis[:, :, -1].view(batch, npoints, dims).detach()
    offset_ref = dot_product(xyz1 - centroid.squeeze(2), normal, dim=-1)

    # Second set: reuse the neighbor indices of the reference set.
    gather_idx = nn_idx.unsqueeze(-1).expand(-1, -1, -1, dims)
    neighbors = torch.gather(xyz2.unsqueeze(1).expand(-1, npoints, -1, -1), 2, gather_idx)
    centroid = torch.mean(neighbors, dim=2, keepdim=True)
    _, _, basis = batch_svd((neighbors - centroid).view(-1, k, dims).contiguous())
    normal = basis[:, :, -1].view(batch, npoints, dims).detach()
    offset_def = dot_product(xyz2 - centroid.squeeze(2), normal, dim=-1)

    # Magnitudes only tell how strongly a point is bent, not the direction.
    loss = self.metric(offset_ref.abs(), offset_def.abs())
    # Additionally penalize offsets that grew (flat -> curved).
    bend = offset_def - offset_ref
    bend = torch.where(bend < 0, torch.zeros_like(bend), bend)
    loss += 5 * self.metric(bend, torch.zeros_like(bend))
    return loss
def applyProjection(self, points_data, normals_data, nonvisibility_data, idxList=None, decay=1.0):
    """Move points towards the surface implied by their neighbors' tangent planes (in place).

    For each point (or each point selected by ``idxList``) the 32 nearest points
    of ``points_data`` (the 33-NN result minus the point itself) within
    ``r = self.projection_radius`` are weighted by:

    * ``theta = exp(-d2 / r**2 / 4)``: spatial Gaussian;
    * ``phi = 1 / (1 + nonvisibility)**2``: down-weights invisible neighbors;
    * ``psi``: similarity of the point normal and the neighbor normal, with a
      bandwidth derived from ``self.sharpness_sigma``.

    The point is displaced along the normalized weighted mean neighbor normal
    by the weighted mean signed distance to the neighbors' tangent planes.
    The displacement is scaled by ``self.projection_weight * decay`` and
    clamped per component to ``[-0.02, 0.02]``.  Points with fewer than 3
    neighbors inside the radius, or with a negligible weight sum, are not
    moved.

    Args:
        points_data: points, unpacked as ``(B, N, C)``; modified in place.
        normals_data: per-point normals (normalized here, not in place).
        nonvisibility_data: per-point non-visibility values (moved to the
            points' device); presumably ``(B, N, 1)`` — TODO confirm.
        idxList: optional index tensor selecting which points to update.
        decay: extra multiplier on the step size.

    Returns:
        None.  Returns early when ``self.projection_weight <= 0``, when there
        are at most 3 points, or when no point has a neighbor within the
        radius.

    Raises:
        RuntimeError: if ``_check_values`` rejects the computed update.
    """
    if self.projection_weight <= 0:
        return
    batchSize, PN, _ = points_data.shape
    if PN <= 3:
        return

    knn_k = 33
    rradius2 = self.projection_radius ** 2
    iradius = 1 / rradius2 / 4

    normals_data = normalize(normals_data)
    nonvisibility_data = nonvisibility_data.to(device=points_data.device)
    points = points_data
    normals = normals_data
    if idxList is not None:
        points = torch.gather(points, 1, idxList.expand(-1, -1, points.shape[-1]))
        normals = torch.gather(normals, 1, idxList.expand(-1, -1, normals.shape[-1]))
    PN = points.shape[1]

    # First KNN (B, N, k, c)
    if points.is_cuda:
        knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, points_data, unique=False, NCHW=False)
    else:
        knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, points_data, NCHW=False)
    # Drop the first neighbor: the query point itself.
    knn_points = knn_points[:, :, 1:, :].contiguous()
    knn_idx = knn_idx[:, :, 1:].contiguous()
    distance2 = distance2[:, :, 1:].contiguous()
    # Nothing to do if not even the closest neighbor is inside the radius.
    if torch.all(distance2[:, :, 0] > rradius2):
        return

    knn_normals = torch.gather(normals_data.unsqueeze(1).expand(-1, PN, -1, -1), 2,
                               knn_idx.unsqueeze(-1).expand(-1, -1, -1, normals.shape[-1]))
    # phi (B, N, k): give invisible neighbors a small weight.
    phi = torch.gather(nonvisibility_data.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1)).squeeze(-1)
    phi = 1 / (1 + phi) ** 2
    # theta (B, N, k): spatial Gaussian.
    theta = torch.exp(-distance2 * iradius)
    # psi (B, N, k): normal similarity.  np.cos expects radians; the previous
    # `sigma*180/pi` converted in the wrong direction.  Assumes
    # self.sharpness_sigma is an angle in degrees — NOTE(review): confirm
    # against the option definition.
    sharpness_bandwidth = max(1e-5, 1 - np.cos(np.deg2rad(self.sharpness_sigma), dtype=np.float32))
    sharpness_bandwidth *= sharpness_bandwidth
    psi = torch.exp(-torch.pow(1 - torch.sum(normals.unsqueeze(2) * knn_normals, dim=-1), 2) / sharpness_bandwidth)

    weight = psi * theta * phi
    weight = torch.where(distance2 > rradius2, torch.zeros_like(weight), weight)

    # Weighted signed distance of the point to each neighbor's tangent plane.
    project_dist_sum = torch.sum((points.unsqueeze(2) - knn_points) * knn_normals, dim=-1) * weight
    project_dist_sum = torch.sum(project_dist_sum, dim=-1, keepdim=True) + 1e-10
    project_weight_sum = torch.sum(weight, dim=-1, keepdim=True) + 1e-10
    # Weighted mean neighbor normal: the projection direction.
    normal_sum = torch.sum(knn_normals * weight.unsqueeze(-1), dim=2)
    update_normal = normalize(normal_sum / project_weight_sum)
    # Too few neighbors or project_weight_sum too small: do not move the point.
    unreliable = (torch.sum(distance2 <= rradius2, dim=-1) < 3).unsqueeze(-1) | (project_weight_sum < 1e-7)
    update_normal = torch.where(unreliable, torch.zeros_like(update_normal), update_normal)

    point_update = -(update_normal * (project_dist_sum / project_weight_sum))
    point_update *= (self.projection_weight * decay)
    point_update = torch.clamp(point_update, -0.02, 0.02)
    # Fail loudly instead of dropping into an interactive debugger.
    if not _check_values(point_update):
        raise RuntimeError("applyProjection: invalid values in point_update")

    # Record debug tensors on both update paths.
    if self.verbose:
        saved_variables["projection"] = point_update.cpu()
        saved_variables["pweight"] = weight.cpu()

    # Apply this force.
    if idxList is not None:
        points_data.scatter_add_(1, idxList.expand(-1, -1, points_data.shape[-1]), point_update)
        return
    points_data += point_update
def pointRegularizerLoss(self, points_data, normals_data, nonvisibility_data, idxList=None, include_projection=False, use_density=False):
    """Repulsion (and optional projection) regularizer over local tangent planes.

    For every point (or every point selected by ``idxList``) the 32 nearest
    points of ``points_data`` (the 33-NN result minus the point itself) are
    weighted by three factors:

    * ``phi = 1 / (nonvisibility + 1)``;
    * ``psi = exp(-d2 / (2 r**2))`` with ``r = self.repulsion_radius``;
    * ``theta``: normal similarity, with a bandwidth derived from
      ``self.sharpness_sigma``.

    The weights are normalized per point.  A local plane is fitted with
    ``operations.batch_svd`` on the weighted, centered neighbors; the result
    is detached.

    * Projection term (only if ``include_projection`` and
      ``self.projection_weight > 0``): mean of the weighted squared component
      of each neighbor offset along ``V[..., 2]``, times
      ``self.projection_weight``.
    * Repulsion term (if ``self.repulsion_weight > 0``): ``1/sqrt(dt2 + 1e-4)``
      of the squared tangential offset ``dt2`` (normal column of ``V``
      zeroed).  Neighbors with ``dt2 > r**2`` are dropped.  The rest are
      weighted by ``psi`` (times ``_computeDensity(points)`` when
      ``use_density``), normalized per point, averaged and multiplied by
      ``self.repulsion_weight``.  Gradients w.r.t. the neighbor offsets of
      this term are clamped to ``[-0.02, 0.02]``.

    Args:
        points_data: points, unpacked as ``(B, N, C)``.
        normals_data: per-point normals (normalized here, not in place).
        nonvisibility_data: per-point non-visibility values; presumably
            ``(B, N, 1)`` — TODO confirm.
        idxList: optional index tensor selecting the query points.
        include_projection: also compute the projection term and return the
            two terms separately.
        use_density: multiply repulsion weights by a local density estimate.

    Returns:
        ``(ploss, rloss)`` if ``include_projection`` else ``ploss + rloss``.
        A disabled term is the integer 0.  Returns ``None`` early when both
        weights are non-positive or there are at most 3 points.
    """
    if self.repulsion_weight <= 0 and self.projection_weight <= 0:
        return
    batchSize, PN, _ = points_data.shape
    if PN <= 3:
        return

    knn_k = 33
    normals_data = normalize(normals_data)
    nonvisibility_data = nonvisibility_data.to(device=points_data.device)
    points = points_data
    normals = normals_data
    if idxList is not None:
        points = torch.gather(points, 1, idxList.expand(-1, -1, points.shape[-1]))
        normals = torch.gather(normals, 1, idxList.expand(-1, -1, normals.shape[-1]))
    PN = points.shape[1]

    rradius2 = self.repulsion_radius ** 2
    iradius = 1 / rradius2 / 2

    # First KNN (B, N, k, c)
    if points.is_cuda:
        knn_points, knn_idx, distance2 = operations.group_knn(knn_k, points, points_data, unique=False, NCHW=False)
    else:
        knn_points, knn_idx, distance2 = operations.faiss_knn(knn_k, points, points_data, NCHW=False)
    # Drop the first neighbor: the query point itself.
    knn_points = knn_points[:, :, 1:, :].contiguous().detach()
    knn_idx = knn_idx[:, :, 1:].contiguous()
    distance2 = distance2[:, :, 1:].contiguous()
    knn_normals = torch.gather(normals_data.unsqueeze(1).expand(-1, PN, -1, -1), 2,
                               knn_idx.unsqueeze(-1).expand(-1, -1, -1, normals.shape[-1]))
    # Offsets from each point to its neighbors; gradient flows through `points`.
    knn_v = knn_points - points.unsqueeze(dim=2)

    # phi, psi and theta (B, N, k) weight the plane fit; only psi weights repulsion.
    phi = torch.gather(nonvisibility_data.unsqueeze(1).expand(-1, PN, -1, -1), 2, knn_idx.unsqueeze(-1)).squeeze(-1)
    # visibility = 1 / (nonvisibility + 1)
    phi = 1 / (phi + 1)
    psi = torch.exp(-distance2 * iradius)
    # np.cos expects radians; the previous `sigma*180/pi` converted in the
    # wrong direction.  Assumes self.sharpness_sigma is an angle in degrees —
    # NOTE(review): confirm against the option definition.
    sharpness_bandwidth = max(1e-5, 1 - np.cos(np.deg2rad(self.sharpness_sigma), dtype=np.float32))
    sharpness_bandwidth *= sharpness_bandwidth
    theta = torch.exp(-torch.pow(1 - torch.sum(normals.unsqueeze(2) * knn_normals, dim=-1), 2) / sharpness_bandwidth)
    weight = phi * psi * theta
    weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-10)

    # Weighted, centered neighbors used to fit the local plane.
    var = weight.unsqueeze(-1) * (knn_points - torch.sum(weight.unsqueeze(-1) * knn_points, dim=2, keepdim=True))
    # The previous step introduces small numeric error due to weighting.
    var = torch.where(var.abs() / torch.max(var.abs(), dim=-1, keepdim=True)[0] < 1e-2, torch.zeros_like(var), var)
    # V: BN, 3, 3
    _, _, V = operations.batch_svd(var.view(-1, knn_k - 1, 3))
    V = V.detach()

    totalLoss = 0
    ploss = 0
    rloss = 0
    if include_projection and self.projection_weight > 0:
        # Clone: the repulsion branch zeroes a column of V in place, which
        # would invalidate the tensor matmul saves for backward.
        Vp = V.clone()
        # BN, 1, 3, 1: third column of V, used as the plane normal.
        Vn = Vp.unsqueeze(1)[:, :, :, 2:3]
        # x @ Vn @ Vn^T: component of each neighbor offset along the normal (BN, k, 3).
        projection_v = torch.matmul(torch.matmul(knn_v.view(-1, knn_k - 1, 1, 3), Vn), Vn.transpose(-2, -1)).squeeze(-2)
        # B, N, k squared distance to the plane.
        proj_dist2 = torch.sum(projection_v * projection_v, dim=-1).view(batchSize, -1, knn_k - 1)
        # Weight with visibility, angular and distance similarity.
        ploss = torch.mean(proj_dist2 * weight.detach()) * self.projection_weight
        totalLoss += ploss

    if self.repulsion_weight > 0:
        # Keep only the first two principal axes: zero the last column of V.
        V[:, :, -1] = 0
        # BN, k, 3, 3
        V = V.unsqueeze(-3).expand(-1, knn_k - 1, -1, -1)
        knn_v_r = knn_v.clone()
        # register_hook raises on tensors that do not require grad.
        if knn_v_r.requires_grad:
            knn_v_r.register_hook(lambda grad: grad.clamp(-0.02, 0.02))
        # BN, k, 3: tangential component of each neighbor offset.
        repulsion_v = torch.matmul(torch.matmul(knn_v_r.view(-1, knn_k - 1, 1, 3), V), V.transpose(-2, -1)).squeeze(-2)
        rep_dist2 = torch.sum(repulsion_v * repulsion_v, dim=-1).view(batchSize, -1, knn_k - 1)
        rloss = 1 / torch.sqrt(rep_dist2 + 1e-4)
        rloss = torch.where(rep_dist2 > rradius2, torch.zeros_like(rloss), rloss)
        # B, N, k
        rweight = torch.where(rep_dist2 > rradius2, torch.zeros_like(psi), psi)
        if use_density:
            densityWeights = _computeDensity(points)
            rweight = rweight * densityWeights.unsqueeze(-1)
        rweightSum = torch.sum(rweight, dim=-1, keepdim=True) + 1e-8
        rloss = rloss * rweight.detach() / rweightSum
        rloss = torch.mean(rloss) * self.repulsion_weight
        totalLoss += rloss

    if include_projection:
        return ploss, rloss
    return totalLoss