def knn_graphE(x, k, istrain=False):
    """Transforms the given point set to a directed graph, whose coordinates
    are given as a matrix. The predecessors of each point are its k-nearest
    neighbors.

    If a 3D tensor is given instead, then each row would be transformed into
    a separate graph. The graphs will be unioned.

    When ``istrain`` is true, then with probability 0.5 (drawn from NumPy's
    global random state) the neighborhood is randomly perturbed: the
    ``round(1.5 * k)`` nearest candidates are computed (capped at the number
    of points per sample), the closest candidate is always kept, and
    ``k - 1`` of the remaining candidates are sampled without replacement.

    Parameters
    ----------
    x : Tensor
        The input tensor.

        If 2D, each row of ``x`` corresponds to a node.

        If 3D, a k-NN graph would be constructed for each row. Then
        the graphs are unioned.
    k : int
        The number of neighbors
    istrain : bool, optional
        Enables the random neighbor-sampling augmentation described above.
        Default: False.

    Returns
    -------
    DGLGraph
        The graph. The node IDs are in the same order as ``x``.
    """
    if F.ndim(x) == 2:
        x = F.unsqueeze(x, 0)
    n_samples, n_points, _ = F.shape(x)

    dist = pairwise_squared_distance(x)
    if istrain and np.random.rand() > 0.5:
        # Never ask for more candidates than there are points per sample;
        # otherwise the top-k selection cannot be satisfied.
        n_cand = min(round(1.5 * k), n_points)
        k_indices = F.argtopk(dist, n_cand, 2, descending=False)
        # Pick k - 1 random candidate positions from 1..n_cand-1, then always
        # keep position 0 (the smallest distance -- presumably the point
        # itself, whose squared distance to itself is 0).
        rand_k = np.random.permutation(n_cand - 1)[0:k - 1] + 1
        rand_k = np.append(rand_k, 0)
        k_indices = k_indices[:, :, rand_k]
    else:
        k_indices = F.argtopk(dist, k, 2, descending=False)
    dst = F.copy_to(k_indices, F.cpu())

    # src[s, i, j] == i: each point is paired with each of its neighbors.
    src = F.zeros_like(dst) + F.reshape(F.arange(0, n_points), (1, -1, 1))

    # Shift every sample's local indices into one global ID space so that the
    # per-sample graphs are unioned into a single graph.
    per_sample_offset = F.reshape(
        F.arange(0, n_samples) * n_points, (-1, 1, 1))
    dst += per_sample_offset
    src += per_sample_offset
    dst = F.reshape(dst, (-1, ))
    src = F.reshape(src, (-1, ))

    # Give the shape explicitly. Without it, scipy infers the row count from
    # max(dst) + 1, so a trailing node that never appears among the k-NN
    # indices would yield a non-square adjacency matrix.
    n_total = n_samples * n_points
    adj = sparse.csr_matrix(
        (F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))),
        shape=(n_total, n_total))

    g = DGLGraph(adj, readonly=True)
    return g
def segmented_knn_graph(x, k, segs):
    """Transforms the given point set to a directed graph, whose coordinates
    are given as a matrix. The predecessors of each point are its k-nearest
    neighbors.

    The matrices are concatenated along the first axis, and are segmented by
    ``segs``. Each block would be transformed into a separate graph. The
    graphs will be unioned. Neighbors are only searched within a point's own
    segment.

    Parameters
    ----------
    x : Tensor
        The input tensor.
    k : int
        The number of neighbors. Each segment presumably needs at least ``k``
        points for the per-segment top-k selection to succeed.
    segs : iterable of int
        Number of points of each point set. Must sum up to the number of rows
        in ``x``.

    Returns
    -------
    DGLGraph
        The graph. The node IDs are in the same order as ``x``.
    """
    n_total_points, _ = F.shape(x)
    # offset[i] is the global node ID of the first point of segment i.
    offset = np.insert(np.cumsum(segs), 0, 0)

    h_list = F.split(x, segs, 0)
    # k nearest neighbors of every point within its own segment, shifted from
    # segment-local indices to global node IDs.
    dst = [
        F.argtopk(pairwise_squared_distance(h_g), k, 1, descending=False)
        + start
        for start, h_g in zip(offset, h_list)
    ]
    dst = F.copy_to(F.cat(dst, 0), F.cpu())

    # src[i, j] == i: each point is paired with each of its k neighbors.
    # Built with backend-neutral F ops instead of torch-only tensor methods.
    src = F.zeros_like(dst) + F.reshape(F.arange(0, n_total_points), (-1, 1))

    dst = F.reshape(dst, (-1, ))
    src = F.reshape(src, (-1, ))

    # Explicit shape keeps the adjacency square even if some node never
    # appears among the neighbor indices.
    adj = sparse.csr_matrix(
        (F.asnumpy(F.zeros_like(dst) + 1), (F.asnumpy(dst), F.asnumpy(src))),
        shape=(n_total_points, n_total_points))

    g = DGLGraph(adj, readonly=True)
    return g