def get_matches(feat_source, feat_target, sym=False):
    """Nearest-neighbour matches from source features to target features.

    Args:
        feat_source: source feature matrix, one row per source item.
        feat_target: target feature matrix, one row per target item.
        sym: if True, keep only mutual nearest neighbours (the target's nearest
            source is the source that matched it).

    Returns:
        Tensor of shape [num_matches, 2]; each row is (source_index, target_index).
    """
    # knn(x, y) returns [index into y, index into x]; transposed, every row of
    # `matches` is (source_index, nearest target_index).
    matches = knn(feat_target, feat_source, k=1).T
    if not sym:
        return matches
    # Rows (target_index, nearest source_index).
    match_inv = knn(feat_source, feat_target, k=1).T
    # Compare against the actual source index column instead of a CPU arange, so the
    # check is device-agnostic and does not rely on row i corresponding to source i.
    mask = match_inv[matches[:, 1], 1] == matches[:, 0]
    return matches[mask]
def embedding_correspondences(self, shape_x: Shape, shape_y: Shape, emb_x, emb_y):
    """Nearest-neighbour correspondences between two shapes' embeddings.

    Both embeddings are copied to ``device_cpu`` for the knn search; the resulting
    assignments are moved to ``device``, put into column form and handed to
    ``self.ass_to_samples`` together with the two shapes.
    """
    cpu_emb_x = emb_x.to(device_cpu)
    cpu_emb_y = emb_y.to(device_cpu)

    # For every x-embedding its nearest y-embedding, and vice versa.
    nn_for_x = knn(cpu_emb_y, cpu_emb_x, k=1).to(device)
    nn_for_y = knn(cpu_emb_x, cpu_emb_y, k=1).to(device)

    # Swap the two rows of the y-assignment so both results share the same row layout.
    nn_for_y = torch.index_select(nn_for_y, 0, my_long_tensor([1, 0]))

    self.ass_to_samples(nn_for_x.transpose(0, 1), nn_for_y.transpose(0, 1), shape_x, shape_y)
def empirical_estimate(points, num_neighbors):
    """Per-point weights from the spread of each point's local neighbourhood.

    For every point, the centred scatter matrix of its ``num_neighbors`` nearest
    neighbours (same batch element) is eigen-decomposed. ``sqrt(l_max * l_mid)`` of the
    two largest eigenvalues is computed per point, and each point's weight is the
    median of that value over its neighbours.

    Args:
        points: nested point lists; ``permute(1, 0).cat_tensors()`` must yield the
            stacked points (assumed (N, 3), as the reshape below requires) and a
            batch vector.
        num_neighbors: neighbourhood size k. Every batch element is assumed to hold at
            least k points, since the knn result is reshaped to (N, k).

    Returns:
        TensorListList of per-point weight tensors (N_i, 1), nested like ``points``.
    """
    ps, batch = points.permute(1, 0).cat_tensors()
    ps = ps.contiguous()
    N = ps.shape[0]
    val = knn(ps, ps, batch_x=batch, batch_y=batch, k=num_neighbors)
    nbr_idx = val[1, :]

    A = ps[nbr_idx].reshape(N, num_neighbors, 3)
    A = A - A.mean(dim=1, keepdim=True)
    Asqr = A.permute(0, 2, 1).bmm(A)
    # Eigenvalues in ascending order; replaces the deprecated/removed Tensor.symeig.
    sigma = torch.linalg.eigvalsh(Asqr.cpu())
    # Clamp: round-off can make a near-zero eigenvalue slightly negative -> NaN sqrt.
    # Move back to ps.device before indexing with the device-side knn indices.
    w = (sigma[:, 2] * sigma[:, 1]).clamp(min=0).sqrt().to(ps.device)
    w, _ = torch.median(w[nbr_idx.reshape(N, num_neighbors)], dim=1, keepdim=True)

    # NOTE(review): the batch ids come from the *permuted* list while this loop walks
    # `points` unpermuted; confirm the two orders agree for TensorListList.
    weights = TensorListList()
    bi = 0
    for point_list in points:
        ww = TensorList()
        for _ in point_list:
            ww.append(w[batch == bi])
            bi += 1
        weights.append(ww)
    return weights
def forward(self, data, idx):
    """Classify each graph in the batch from its point positions.

    A 32-nearest-neighbour graph is built over ``data.pos`` (within each example), two
    ReLU-activated convolutions and a linear layer are applied per point, features are
    max-pooled per graph (``gmp``) and passed through three linear layers.

    Returns:
        Log-probabilities per graph (``log_softmax`` over the last dim).
    """
    # data.r is read but currently unused; idx is unused as well.
    _unused_r, pos, batch = data.r, data.pos, data.batch

    # knn returns [query index, neighbour index]; edges point neighbour -> query.
    query_idx, nbr_idx = knn(pos, pos, 32, batch, batch)
    edge_index = torch.stack([nbr_idx, query_idx], dim=0).to(device)

    hidden = F.relu(self.conv1(pos, edge_index))
    hidden = F.relu(self.conv2(hidden, edge_index))
    hidden = self.lin1(hidden)

    pooled = gmp(hidden, batch)
    pooled = self.lin2(pooled)
    pooled = self.lin3(pooled)
    logits = self.output(pooled)
    return F.log_softmax(logits, dim=-1)
def find_k_neighbor(rep_pts, pts, K, D):
    """Find K (randomly dilated) nearest neighbours in ``pts`` for every point of ``rep_pts``.

    For each representative point, the K*D nearest points of ``pts`` in the same batch
    element are found. K of those K*D neighbour columns are then chosen at random
    *without replacement* (random dilation with rate D), so no neighbour is duplicated.

    :param rep_pts: representative / query points, (B, N_rep, C)
    :param pts: points searched for neighbours, (B, N, C)
    :param K: number of neighbours kept per query point
    :param D: dilation rate
    :return: group_pts (B, N_rep, K, C);
        knn_indices: flat indices into ``pts.reshape(-1, C)`` of all K*D neighbours;
        rand_col: (K,) neighbour columns that were kept
    """
    B, N, C = pts.shape
    _, N_rep, _ = rep_pts.shape
    device = pts.device

    # Batch id of every flattened point: [0]*N + [1]*N + ...
    batch_pts = torch.arange(B, device=device).repeat_interleave(N)
    batch_rep_pts = torch.arange(B, device=device).repeat_interleave(N_rep)

    flat_pts = pts.reshape(-1, C).contiguous()
    flat_rep_pts = rep_pts.reshape(-1, C).contiguous()

    # knn returns [query index, neighbour index]; keep the neighbour row (row 1).
    # Assumes every batch element has at least K*D points, as the view below requires.
    knn_indices = knn(flat_pts, flat_rep_pts, K * D, batch_pts, batch_rep_pts)[1]
    group_pts = flat_pts[knn_indices].view(B, N_rep, K * D, C).contiguous()

    # Distinct columns (randperm) instead of randint, which could repeat a neighbour.
    rand_col = torch.randperm(K * D)[:K]
    group_pts = group_pts[:, :, rand_col, :]
    return group_pts, knn_indices, rand_col
def fps_pooling(pos, x, edge_attr, batch=None, k=16, r=0.5, reduce='sum'):
    """Pool point features onto a farthest-point-sampled subset of the points.

    Args:
        pos: point positions, one row per point.
        x: point features, one row per point.
        edge_attr: attributes indexed by the sampled point indices below.
        batch: batch vector; ``None`` means a single example.
        k: number of original points aggregated into every sampled point.
        r: fps sampling ratio.
        reduce: aggregation, one of 'max', 'mean', 'add', 'sum'.

    Returns:
        (pooled x, sampled pos, sampled edge_attr, sampled batch)
    """
    assert reduce in ['max', 'mean', 'add', 'sum']
    if batch is None:
        # Single example: use an all-zeros batch so it can be indexed below
        # (the original crashed on batch[idx] when batch was None).
        batch = torch.zeros(pos.size(0), dtype=torch.long, device=pos.device)
    idx = fps(pos, batch, ratio=r)
    # knn(x, y) -> [index into y, index into x]: i = sampled centre, j = original point.
    i, j = knn(pos, pos[idx], k, batch, batch[idx])
    x = scatter(x[j], i, dim=0, dim_size=idx.numel(), reduce=reduce)
    # NOTE(review): edge_attr is indexed with *node* indices; confirm it is per-node here.
    return x, pos[idx], edge_attr[idx], batch[idx]
def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,
                    batch_x: OptTensor = None, batch_y: OptTensor = None,
                    k: int = 3, num_workers: int = 1):
    r"""k-NN interpolation as used in `"PointNet++: Deep Hierarchical Feature Learning
    on Point Sets in a Metric Space" <https://arxiv.org/abs/1706.02413>`_.

    Every target point :math:`y` receives the inverse-squared-distance weighted mean of
    the features of its :math:`k` nearest source points :math:`x_1, \ldots, x_k`:

    .. math::
        \mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)}
        \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2}

    Args:
        x (Tensor): source features :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        pos_x (Tensor): source positions :math:`\in \mathbb{R}^{N \times d}`.
        pos_y (Tensor): target (upsampled) positions :math:`\in \mathbb{R}^{M \times d}`.
        batch_x (LongTensor, optional): example id of every source point,
            :math:`\in {\{ 0, \ldots, B-1\}}^N`. (default: :obj:`None`)
        batch_y (LongTensor, optional): example id of every target point,
            :math:`\in {\{ 0, \ldots, B-1\}}^M`. (default: :obj:`None`)
        k (int, optional): number of neighbours. (default: :obj:`3`)
        num_workers (int): workers for the neighbour search; ignored when a batch
            vector is given or the input is on the GPU. (default: :obj:`1`)
    """
    # Neighbour indices and weights depend only on positions: no gradient needed.
    with torch.no_grad():
        y_idx, x_idx = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y,
                           num_workers=num_workers)
        offset = pos_x[x_idx] - pos_y[y_idx]
        dist_sq = (offset * offset).sum(dim=-1, keepdim=True)
        weights = 1.0 / torch.clamp(dist_sq, min=1e-16)

    num_targets = pos_y.size(0)
    numerator = scatter_add(x[x_idx] * weights, y_idx, dim=0, dim_size=num_targets)
    denominator = scatter_add(weights, y_idx, dim=0, dim_size=num_targets)
    return numerator / denominator
def edge_loss(p):
    """Mean distance from every point to its nearest *other* point in the same cloud.

    Args:
        p: point clouds, B x N x 3.

    Returns:
        Scalar tensor: per-cloud mean nearest-neighbour distance, averaged over clouds.
    """
    num_clouds, num_points = p.size(0), p.size(1)
    flat = p.reshape(-1, 3)
    # knn needs integer batch ids on the same device as the points; the original built
    # a float CPU tensor.
    batch = torch.arange(num_clouds, device=p.device).repeat_interleave(num_points)
    # With k=2 on the cloud itself, the first neighbour of a point is assumed to be the
    # point itself (nearest-first ordering), so every second entry is the nearest other point.
    nearest_idxs = gnn.knn(flat, flat, 2, batch, batch)[1, 1::2]
    edge_len = torch.norm(flat - flat[nearest_idxs], dim=-1).view(num_clouds, num_points)
    return edge_len.mean(dim=-1).mean()
def forward(self, in_layer_data, skip_layer_data):
    """Transfer features from the input level onto the skip-connection points.

    Each of ``in_layer_data`` / ``skip_layer_data`` is an ``(x, pos, batch)`` triple.
    Every skip point is connected to its ``self.knn_num`` nearest input points and
    ``self.point_conv`` aggregates over those edges.

    Returns:
        (features on the skip points, skip positions, skip batch)
    """
    in_x, in_pos, in_batch = in_layer_data
    skip_x, skip_pos, skip_batch = skip_layer_data

    # knn returns [skip-point index, input-point index]; edges go input -> skip.
    skip_idx, in_idx = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)
    edges = torch.stack([in_idx, skip_idx], dim=0)

    out_x = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edges)
    return out_x, skip_pos, skip_batch
def forward(self, x, y):
    """Evaluate the interpolant at the query coordinates ``(x, y)``.

    Queries are matched to their ``self.max_nbrs`` nearest nodes among
    ``self.points`` and ``self.f_points``. The value is ``self.sph.forward`` applied to
    ``self.u``; when ``self.use_pu`` is set, it is divided by the same operator applied
    to the constant 1.0 (partition-of-unity normalisation).
    """
    queries = torch.stack((x, y), dim=1)
    all_nodes = torch.vstack((self.points, self.f_points))
    nbr_index = knn(all_nodes, queries, self.max_nbrs)
    widths = self.widths()

    if self.use_pu:
        denominator = self.sph.forward(x, y, all_nodes, widths, 1.0, nbr_index)
    else:
        denominator = 1.0
    numerator = self.sph.forward(x, y, all_nodes, widths, self.u, nbr_index)
    return numerator / denominator
def knn_graph_3d(self, pos, x, batch, k, direction='source2target'):
    """Build a k-nearest-neighbour graph over 3D positions (``x`` is not used).

    For every point (the *target*), its ``k`` nearest points (the *sources*) within the
    same batch element are found.

    Returns:
        edge_index: [2, M]. With ``direction='source2target'`` row 0 holds the
        neighbour (source) and row 1 the point it neighbours (target). Any other
        ``direction`` value swaps the rows, so row 0 is the target.
    """
    target_idx, source_idx = knn(pos, pos, k, batch, batch)
    if direction == 'source2target':
        rows = [source_idx, target_idx]
    else:
        rows = [target_idx, source_idx]
    return torch.stack(rows, dim=0)
def forward(self, x, pos, batch):
    """Cluster points by farthest point sampling in joint (pos, x) space; average each cluster.

    Returns:
        (cluster features, cluster positions, batch of the cluster seeds, assignment),
        where ``assignment`` is the knn result: row 0 indexes input points, row 1 the
        seed/cluster each point was assigned to.
    """
    joint = torch.cat([pos, x], dim=-1)
    # Presumably over-samples by 2 so that fps yields at least num_clusters seeds.
    ratio = (self.num_clusters + 2) / x.shape[0]
    seeds = fps(joint, batch, ratio=ratio, random_start=False)
    seeds = seeds[:self.num_clusters]
    # NOTE(review): ratio is relative to all points in the batch and truncation keeps the
    # first num_clusters seeds overall; with several examples per batch, later examples
    # may receive no clusters. Confirm this is only used with one example per batch.

    assignment = knn(joint[seeds], joint, 1, batch[seeds], batch)
    point_idx = assignment[0]
    cluster_idx = assignment[1].unsqueeze(1)

    cluster_x = scatter_mean(x.index_select(0, point_idx), cluster_idx, dim=0)
    cluster_pos = scatter_mean(pos.index_select(0, point_idx), cluster_idx, dim=0)
    return cluster_x, cluster_pos, batch[seeds], assignment
def forward(self, support_xyz, batch_index, filtered_index, features):
    """Message passing from all support points to the subset selected by ``filtered_index``.

    Every query point (a support point picked by ``filtered_index``) is connected to its
    ``self.k`` nearest support points of the same batch element. ``self.propagate``
    then runs over those edges with (support, query) features and positions; presumably
    the output has one row per query point — confirm against ``propagate``.
    """
    query_xyz, query_batch_index, query_features = (
        tensor.index_select(0, filtered_index)
        for tensor in (support_xyz, batch_index, features)
    )

    # knn returns [query index, support index]; edges go support -> query.
    query_idx, support_idx = knn(support_xyz, query_xyz, self.k, batch_index, query_batch_index)
    edge_index = torch.stack([support_idx, query_idx], dim=0)

    return self.propagate(edge_index,
                          x=(features, query_features),
                          pos=(support_xyz, query_xyz))
def __call__(self, sample):
    """Build a multi-level point hierarchy on ``sample`` by repeated farthest point sampling.

    The sample's attributes containing ``'edge_index'`` (sorted by name) define the
    levels. For every level except the last:

    - the positions at that level are subsampled with fps towards
      ``sample.num_vertices[level]`` points;
    - the count actually obtained is written back into ``sample.num_vertices[level]``;
    - the subsampled positions are stored as ``pos_{level+1}``;
    - ``hierarchy_trace_index_{level+1}`` stores, for every point of the finer level,
      the index of its nearest point in the new level.

    Afterwards all attributes whose name starts with ``'x_'`` are removed.
    """
    edge_keys = sorted(name for name in dir(sample) if 'edge_index' in name)
    counts = [sample.num_nodes] + sample.num_vertices

    # The last level has nothing coarser below it, so it is skipped.
    for level, edge_key in enumerate(edge_keys[:-1]):
        src_pos_key = edge_key.replace('edge_index', 'pos').replace('hierarchy_', '')
        new_pos_key = 'pos_' + str(level + 1)

        kept = fps(sample[src_pos_key], ratio=sample.num_vertices[level] / counts[level])
        counts[level + 1] = kept.shape[0]
        sample.num_vertices[level] = counts[level + 1]
        sample[new_pos_key] = sample[src_pos_key][kept]
        sample[f"hierarchy_trace_index_{level+1}"] = knn(
            sample[new_pos_key], sample[src_pos_key], 1)[1, :]

    for name in sorted(n for n in dir(sample) if n.startswith('x_')):
        delattr(sample, name)
    return sample
def knn_interpolate(x, pos_x, pos_y, batch_x=None, batch_y=None, k=3):
    r"""The k-NN interpolation from the `"PointNet++: Deep Hierarchical Feature
    Learning on Point Sets in a Metric Space" <https://arxiv.org/abs/1706.02413>`_ paper.

    For each target point :obj:`y`, its interpolated features are given by

    .. math::
        f(y) = \frac{\sum_{i=1}^k w_i(y) f(x_i)}{\sum_{i=1}^k w_i(y)}
        \textrm{, where } w_i(y) = \frac{1}{d(y, x_i)^2}.

    Args:
        x (Tensor): Node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        pos_x (Tensor): Node position matrix :math:`\in \mathbb{R}^{N \times d}`.
        pos_y (Tensor): Upsampled node position matrix :math:`\in \mathbb{R}^{M \times d}`.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N`, assigning each source node
            to an example. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^M`, assigning each target node
            to an example. (default: :obj:`None`)
        k (int, optional): Number of neighbors. (default: :obj:`3`)

    :rtype: :class:`Tensor`
    """
    # Neighbours and weights depend only on positions; no gradient is needed there.
    with torch.no_grad():
        y_idx, x_idx = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y)
        diff = pos_x[x_idx] - pos_y[y_idx]
        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
        # Clamp so coincident points get a large finite weight instead of inf.
        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

    # Outside no_grad so gradients flow back into the features x.
    # (Debug print() calls of the distances and features were removed.)
    y = scatter_add(x[x_idx] * weights, y_idx, dim=0, dim_size=pos_y.size(0))
    y = y / scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0))
    return y
def extract_correspondences_gpu(coords, Rs, ts, device=None):
    """Nearest-neighbour correspondences between every pair of transformed point sets.

    The sets are transformed as ``Rs @ coords + ts`` and then optionally moved to
    ``device``. ``coords[i].permute(1, 0)`` is used as a point matrix, so ``coords[i]``
    appears to be (dim, n_points) — confirm with callers. For every pair (i, j) with
    i < j, each point of set i is matched to its nearest point in set j.

    Returns:
        dict with ``indices`` (knn results, row 0 into set i, row 1 into set j) and
        ``distances`` (Euclidean distance of every match), both TensorLists in pair order.
    """
    coords = Rs @ coords + ts
    num_sets = len(coords)
    if device is not None:
        coords = coords.to(device)

    pair_indices = []
    pair_distances = []
    for first in range(num_sets - 1):
        src = coords[first].permute(1, 0)
        for second in range(first + 1, num_sets):
            dst = coords[second].permute(1, 0)
            nn_pair = knn(dst, src, 1)
            delta = src[nn_pair[0, :]] - dst[nn_pair[1, :]]
            pair_distances.append((delta * delta).sum(dim=1).sqrt())
            pair_indices.append(nn_pair)
    return dict(indices=TensorList(pair_indices), distances=TensorList(pair_distances))
def build_bipartite_graph(self, pos, batch, ratio, method='radius', r=0.1, k=32, dilation=1):
    '''
    Build a bipartite graph from all points to an fps-sampled subset of them.

    :param pos: (torch.Tensor) The position of the points
    :param batch: (torch.Tensor) The batch index for each point
    :param ratio: (float) The sample ratio
    :param method: (str) Specify which method to use, ['radius', 'knn']
    :param r: (float) If 'radius' is adopted, the radius of the support domain
    :param k: (int) The max number of neighbors ('radius') or number of neighbors ('knn')
    :param dilation: (int) If 'knn' is adopted, the dilation rate: k of the
        k * dilation nearest neighbors are kept at random, without repetition
    :return: (torch.Tensor) The edge index (row 0: original point, row 1: sampled point),
        position and batch of the sampled points
    '''
    assert method in ['radius', 'knn']
    idx = fps(pos, batch, ratio=ratio)
    if method == 'radius':
        row, col = radius(pos, pos[idx], r, batch, batch[idx], max_num_neighbors=k)
    else:
        row, col = knn(pos, pos[idx], k * dilation, batch, batch[idx])
        if dilation > 1:
            # Assumes every sampled point got exactly k * dilation neighbors (contiguous
            # groups in row/col), i.e. each example has at least that many points.
            n = idx.shape[0]
            group = k * dilation
            # Random k distinct columns per sampled point; the original used randint,
            # which can pick the same neighbor twice and create duplicate edges.
            pick = torch.rand(n, group, device=row.device).argsort(dim=1)[:, :k]
            offsets = torch.arange(n, dtype=torch.long, device=row.device) * group
            flat = (pick + offsets.view(-1, 1)).view(-1)
            row, col = row[flat], col[flat]
    edge_index = torch.stack([col, row], dim=0)
    return edge_index, pos[idx], batch[idx]
def init_upscale(self, num_knn=3):
    """Precompute inverse-distance interpolation data between consecutive vertex levels.

    For every pair of consecutive levels (``self.vert_array[i]``, ``self.vert_array[i+1]``),
    each vertex of level i+1 is matched to its ``num_knn`` nearest vertices of level i.
    The following are stored, moved to ``device``:

    - ``self.ass_array``: neighbour indices, [n_vert_ip1, num_knn]
    - ``self.ass_vecs``: offset vectors, [n_vert_ip1, num_knn, 3]
    - ``self.ass_weights``: normalised inverse-distance weights, [n_vert_ip1, num_knn, 1]

    The knn search runs on ``device_cpu``.
    """
    self.ass_array = []
    self.ass_vecs = []
    self.ass_weights = []
    for level in range(self.scale_idx_len - 1):
        verts_cur = self.vert_array[level].to(device_cpu)
        verts_next = self.vert_array[level + 1].to(device_cpu)

        nbrs = knn(verts_cur, verts_next, num_knn)[1, :].view(-1, num_knn)
        self.ass_array.append(nbrs.to(device))

        offsets = verts_next.unsqueeze(1) - verts_cur[nbrs, :]
        self.ass_vecs.append(offsets.to(device))

        # 1e-5 avoids division by zero for coincident vertices.
        inv_dist = 1 / (torch.norm(offsets, dim=2, keepdim=True) + 1e-5)
        inv_dist = inv_dist / torch.sum(inv_dist, dim=1, keepdim=True)
        self.ass_weights.append(inv_dist.to(device))
def process_ply(
    ply_path: str,
    n_patch: int = 100,
    k: int = 2048,
    down_sample: Union[None, int, float] = None,
    cuda: bool = False,
    cuda_device: str = "cuda:7",
):
    r"""Load a PLY file and split its point cloud into k-point patches.

    Args:
        ply_path: path of the PLY file (read with ``read_mesh``).
        n_patch: approximate number of patches; patch centres are picked by fps.
        k: number of points per patch (each centre's k nearest points).
        down_sample: optional uniform random down-sampling before patching: an ``int``
            is an absolute point count, a ``float`` a fraction of the points.
        cuda: move positions and colours to ``cuda_device``.
        cuda_device: device used when ``cuda`` is True (default keeps the old "cuda:7").

    Returns:
        list of ``Data(color=[k, C], pos=[k, 3], patch_index=[k])``, one per patch.
    """
    mesh_data = read_mesh(ply_path)
    print(colorama.Fore.GREEN + "Loaded PLY file %s" % ply_path)
    pos, color = mesh_data.pos, mesh_data.color
    if cuda:
        pos = pos.to(torch.device(cuda_device))
        color = color.to(pos)

    # Down-sampling using a uniform random subset.
    n_pts = pos.shape[0]
    if isinstance(down_sample, (int, float)):
        if isinstance(down_sample, int):
            n_keep = down_sample
        else:
            # int(): slicing with the float returned by np.floor raised TypeError.
            n_keep = int(np.floor(down_sample * n_pts))
        idx = torch.randperm(n_pts)[:n_keep]
        pos, color = pos[idx], color[idx]
        # Refresh the count so the fps ratio below refers to the remaining points.
        n_pts = pos.shape[0]

    # Select patch centres by FPS (ratio capped at 1 when n_patch exceeds the points).
    patch_kernel = fps(pos, ratio=min(1.0, n_patch / n_pts), random_start=True)
    # Each centre's k nearest points form its patch: [num_patches, k] indices into pos.
    patches = knn(pos, pos[patch_kernel], k=k, num_workers=32)[1].reshape(-1, k)

    return [Data(color=color[patch], pos=pos[patch], patch_index=patch) for patch in patches]
def precompute(self, query, support):
    """Precompute the k-NN interpolation data from ``query`` points onto ``support`` points.

    For every support point its ``self.k`` nearest query points (same batch element)
    are found. The inverse-squared-distance weights and their per-support-point sums
    are stored so that the transform itself only has to scatter features.

    Returns:
        Data(num_nodes, x_idx, y_idx, weights, normalisation), where x_idx indexes the
        query points and y_idx the support points.
    """
    pos_x, pos_y = query.pos, support.pos

    # Fall back to a single-example batch when no batch vector is present (or it is
    # None). Create it on the positions' device; the original always used the CPU.
    batch_y = getattr(support, "batch", None)
    if batch_y is None:
        batch_y = torch.zeros((support.num_nodes,), dtype=torch.long, device=pos_y.device)
    batch_x = getattr(query, "batch", None)
    if batch_x is None:
        batch_x = torch.zeros((query.num_nodes,), dtype=torch.long, device=pos_x.device)

    with torch.no_grad():
        y_idx, x_idx = knn(pos_x, pos_y, self.k, batch_x=batch_x, batch_y=batch_y)
        diff = pos_x[x_idx] - pos_y[y_idx]
        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
        normalisation = scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0))

    return Data(num_nodes=support.num_nodes, x_idx=x_idx, y_idx=y_idx,
                weights=weights, normalisation=normalisation)
def find_neighbours(self, x, y, batch_x, batch_y):
    """Return the ``self.k`` nearest points of ``x`` for every point of ``y``.

    The knn result has row 0 indexing ``y`` and row 1 indexing ``x``; neighbours are
    restricted to matching batch ids.
    """
    neighbours = knn(x, y, self.k, batch_x, batch_y)
    return neighbours
def find_neighbours(self, x, y, batch_x, batch_y):
    """Return the ``self.k`` nearest points of ``x`` for every point of ``y`` (same batch id).

    Row 0 of the result indexes ``y``, row 1 indexes ``x``.
    """
    return knn(x, y, self.k, batch_x=batch_x, batch_y=batch_y)
def forward(self, data, final, start):
    """Three-branch graph convolution with two cluster-pooled branches.

    The input graph and two max-pooled copies of it (clustered by ``self.clusters1``
    and ``self.clusters2``) each pass through ``self.layers`` convolution blocks. The
    pooled branches are unpooled back onto the input points with
    inverse-squared-distance k-NN interpolation (``self.knn`` neighbours). The three
    results are concatenated (plus the raw input if ``self.skipconv``) and mixed by
    ``self.conv_mix``.

    Args:
        data: graph batch with ``x``, ``edge_index``, ``pos`` and ``batch``.
            Note: ``data.batch`` is overwritten with the per-example category vector.
        final: recompute and cache the interpolation indices/weights for this batch
            (e.g. the last batch, whose size may differ).
        start: recompute and cache the interpolation indices/weights for this batch.
            When both flags are False, the cached values are reused, which assumes the
            same node layout as the batch they were computed for.

    Returns:
        Mixed node features, with dropout applied when ``self.p != 1``.

    Raises:
        NotImplementedError: if ``self.clustering`` is not ``'Street'``.
    """

    def _append_coords(feat, p, cat_batch):
        # Append positions scaled by their maximum and the category (id % 5) scaled to [0, 1].
        cats = ((cat_batch % 5).float() / 4).view(cat_batch.size(0), 1)
        return torch.cat((feat, p / torch.max(p), cats), 1)

    def _run_branch(h, ei, convs, skips):
        # self.layers blocks: conv, optional mid skip (concat with the block input),
        # then skip conv + batch norm (bn, bn2, then bn3 for all later layers) + ReLU.
        for i in range(self.layers):
            h_in = h
            h = convs[i](h, ei)
            if self.midSkip:
                h = torch.cat((h, h_in), 1)
            if i == 0:
                bn = self.bn
            elif i == 1:
                bn = self.bn2
            else:
                bn = self.bn3
            h = F.relu(bn(skips[i](h, ei)))
        return h

    def _interp_weights(src_pos, src_batch, dst_pos, dst_batch):
        # Inverse squared distance from each dst point to its self.knn nearest src points.
        y_idx, x_idx = knn(src_pos, dst_pos, self.knn, batch_x=src_batch, batch_y=dst_batch)
        diff = src_pos[x_idx] - dst_pos[y_idx]
        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
        return 1.0 / torch.clamp(squared_distance, min=1e-16), x_idx, y_idx

    def _unpool(h, x_idx, y_idx, weights, num_nodes):
        # Weighted mean over neighbours, then undo the sort applied to pos.
        out = scatter_add(h[x_idx] * weights, y_idx, dim=0, dim_size=num_nodes)
        out = out / scatter_add(weights, y_idx, dim=0, dim_size=num_nodes)
        return out[self.backSorter]

    if self.clustering != 'Street':
        # Only street based pooling is implemented; the original failed later with an
        # UnboundLocalError for any other value.
        raise NotImplementedError(
            f"clustering={self.clustering!r} is not supported, only 'Street'")

    x, edge_index, pos, batch = data.x, data.edge_index, data.pos, data.batch
    x_start = x

    # Offset cluster and category ids per example so examples never share one.
    batchClusters1 = self.clusters1
    batchCat = self.categories
    batchClusters2 = self.clusters2
    batch_size = int(torch.max(batch)) + 1
    for i in range(1, batch_size):
        batchClusters1 = torch.cat((batchClusters1, self.clusters1 + i * self.maxCluster1))
        batchCat = torch.cat((batchCat, self.categories + i * 5))
        batchClusters2 = torch.cat((batchClusters2, self.clusters2 + i * self.maxCluster2))
    batchCat = batchCat.long()
    data.batch = batchCat

    # Both pooled branches (max pooling of the same input graph).
    pooled1 = max_pool(batchClusters1, data)
    x_t, edge_index_t, pos_t, batchCat_t = pooled1.x, pooled1.edge_index, pooled1.pos, pooled1.batch
    pooled2 = max_pool(batchClusters2, data)
    x_t2, edge_index_t2, pos_t2, batchCat_t2 = pooled2.x, pooled2.edge_index, pooled2.pos, pooled2.batch
    edge_index_t, _ = add_self_loops(edge_index_t)
    edge_index_t2, _ = add_self_loops(edge_index_t2)

    # Add coordinates and categories to the inputs.
    if self.coords:
        x = _append_coords(x, pos, batchCat)
        x_t = _append_coords(x_t, pos_t, batchCat_t)
        x_t2 = _append_coords(x_t2, pos_t2, batchCat_t2)

    # Convolution blocks in all three branches.
    x = _run_branch(x, edge_index, self.moduleList1, self.skipList1)
    x_t = _run_branch(x_t, edge_index_t, self.moduleList2, self.skipList2)
    x_t2 = _run_branch(x_t2, edge_index_t2, self.moduleList3, self.skipList3)

    # Compute the interpolation data once. The original ran an identical block for
    # `start` and again for `final`; with both True, the second pass re-sorted the
    # already-sorted pos and overwrote backSorter with a wrong permutation.
    if start or final:
        sorter = torch.argsort(batchCat)
        self.backSorter = torch.argsort(sorter)
        pos = pos[sorter]
        batchCat = batchCat[sorter]
        self.weights, self.xIdx, self.yIdx = _interp_weights(pos_t, batchCat_t, pos, batchCat)
        self.weights2, self.xIdx2, self.yIdx2 = _interp_weights(pos_t2, batchCat_t2, pos, batchCat)

    # Unpool the pooled branches onto the input points.
    num_nodes = pos.size(0)
    x_t = _unpool(x_t, self.xIdx, self.yIdx, self.weights, num_nodes)
    x_t2 = _unpool(x_t2, self.xIdx2, self.yIdx2, self.weights2, num_nodes)

    if self.skipconv:
        y = torch.cat((x, x_t, x_t2, x_start), 1)
    else:
        y = torch.cat((x, x_t, x_t2), 1)

    y = self.conv_mix(y, edge_index)
    if self.p != 1:
        y = F.dropout(y, training=self.training, p=self.p)
    return y
def fps_max_pooling(pos, x, batch=None, k=16, r=0.5):
    """Max-pool point features onto a farthest-point-sampled subset of the points.

    Args:
        pos: point positions, one row per point.
        x: point features, one row per point.
        batch: batch vector; ``None`` means a single example.
        k: number of original points aggregated into every sampled point.
        r: fps sampling ratio.

    Returns:
        (pooled x, sampled pos, sampled batch)
    """
    if batch is None:
        # Single example: use an all-zeros batch so it can be indexed below
        # (the original crashed on batch[idx] when batch was None).
        batch = torch.zeros(pos.size(0), dtype=torch.long, device=pos.device)
    idx = fps(pos, batch, ratio=r)
    # knn(x, y) -> [index into y, index into x]: i = sampled centre, j = original point.
    i, j = knn(pos, pos[idx], k, batch, batch[idx])
    x = scatter(x[j], i, dim=0, dim_size=idx.numel(), reduce='max')
    return x, pos[idx], batch[idx]
def losses(pred, target, target_norms, device, chamfer=True, edge=False, norm=False,
           t=2 / 32, metrics=True):
    """Reconstruction losses and F-score metrics between predicted and target point clouds.

    Args:
        pred: predicted points, B x N x 3.
        target: target points, B x M x 3.
        target_norms: target normals, reshaped to (-1, 3) in the same order as target.
        device: device on which the batch-index vectors are created.
        chamfer: compute the Chamfer distance (mean NN distance in both directions).
        edge: compute the edge loss (mean distance to the nearest predicted neighbour).
        norm: compute the normal loss (1 - |cos| between the estimated predicted normal
            and the nearest target normal, both directions).
        t: distance threshold for precision / recall.
        metrics: compute precision, recall and fscore.

    Returns:
        (losses_all, metrics_all): ``{'cd', 'el', 'nl'}`` and
        ``{'precision', 'recall', 'fscore'}``; disabled entries stay 0.
    """
    B, N = pred.size(0), pred.size(1)
    Bt, M = target.size(0), target.size(1)
    flat_pred = pred.reshape(-1, 3)
    flat_target = target.reshape(-1, 3)

    # Example id of every flattened point.
    pred_batch = torch.arange(B, dtype=torch.int64, device=device).repeat_interleave(N)
    target_batch = torch.arange(Bt, dtype=torch.int64, device=device).repeat_interleave(M)

    chamfer_loss = edge_loss_val = norm_loss = 0.
    precision = recall = fscore = 0.

    # knn(x, y) rows: [index into y, index into x].
    if chamfer or norm or metrics:
        # For every target point, its nearest predicted point.
        target_pred_nearest = gnn.knn(x=flat_pred, y=flat_target, k=1,
                                      batch_x=pred_batch, batch_y=target_batch)
        # For every predicted point, its nearest target point.
        pred_target_nearest = gnn.knn(x=flat_target, y=flat_pred, k=1,
                                      batch_x=target_batch, batch_y=pred_batch)

    # Distances are needed for the metrics too; the original raised NameError when
    # metrics=True and chamfer=False.
    if chamfer or metrics:
        pred_target_dist = torch.norm(
            flat_pred[pred_target_nearest[0]] - flat_target[pred_target_nearest[1]],
            dim=-1).view(B, N)
        target_pred_dist = torch.norm(
            flat_target[target_pred_nearest[0]] - flat_pred[target_pred_nearest[1]],
            dim=-1).view(Bt, M)

    if chamfer:
        chamfer_loss = (pred_target_dist.mean(dim=-1).mean()
                        + target_pred_dist.mean(dim=-1).mean())

    if edge or norm:
        # flow='target_to_source' makes row 0 the centre and row 1 its neighbours. Under
        # the default flow both rows are swapped, and the original normal estimate used
        # the centre three times (zero vectors -> NaN). Assumes the three neighbours of
        # each centre are consecutive, as the ::3 slicing requires.
        pred_k_3 = gnn.knn_graph(x=flat_pred, k=3, batch=pred_batch, flow='target_to_source')
        centres = pred_k_3[0, 0::3]
        nbr1, nbr2, nbr3 = pred_k_3[1, 0::3], pred_k_3[1, 1::3], pred_k_3[1, 2::3]

    if edge:
        pred_edge_dist = torch.norm(flat_pred[centres] - flat_pred[nbr1], dim=-1).view(B, N)
        edge_loss_val = pred_edge_dist.mean(dim=-1).mean()

    if norm:
        pred_vec1 = flat_pred[nbr2] - flat_pred[nbr1]
        pred_vec2 = flat_pred[nbr3] - flat_pred[nbr1]
        # Explicit dim: without it torch.cross picks the first size-3 dim (wrong if B*N == 3).
        pred_approx_norm = torch.cross(pred_vec1, pred_vec2, dim=-1)
        # Clamp avoids 0/0 for collinear neighbour triples.
        pred_approx_norm = pred_approx_norm / torch.norm(
            pred_approx_norm, dim=-1).clamp(min=1e-12).view(-1, 1)

        flat_target_norms = target_norms.reshape(-1, 3)
        # Cosine dissimilarity of normals (sign-agnostic).
        pred_target_norm_dissim = 1 - (
            pred_approx_norm[pred_target_nearest[0]]
            * flat_target_norms[pred_target_nearest[1]]).sum(dim=-1).abs()
        target_pred_norm_dissim = 1 - (
            flat_target_norms[target_pred_nearest[0]]
            * pred_approx_norm[target_pred_nearest[1]]).sum(dim=-1).abs()
        norm_loss = (pred_target_norm_dissim.view(B, N).mean(dim=-1).mean()
                     + target_pred_norm_dissim.view(Bt, M).mean(dim=-1).mean())

    if metrics:
        # Precision: share of predicted points within t of the target; recall: share of
        # target points within t of the prediction (the original had the two swapped).
        precision = (pred_target_dist <= t).float().mean(dim=-1).mean()
        recall = (target_pred_dist <= t).float().mean(dim=-1).mean()
        fscore = ((2 * precision * recall) / (precision + recall)).mean()
        # Replace NaN (e.g. 0/0 when precision and recall are both zero) with 0.
        precision[precision != precision] = 0.
        recall[recall != recall] = 0.
        fscore[fscore != fscore] = 0.

    losses_all = {'cd': chamfer_loss, 'el': edge_loss_val, 'nl': norm_loss}
    metrics_all = {'precision': precision, 'recall': recall, 'fscore': fscore}
    return losses_all, metrics_all