def search_knn(xq, xb, k, distance_type=faiss.METRIC_L2):
    """Exact k-nearest-neighbour search with faiss, without building an index.

    Parameters
    ----------
    xq : array-like, shape (nq, d)
        Query vectors. Converted to a C-contiguous float32 array if needed,
        because ``faiss.swig_ptr`` only produces a ``float*`` for such buffers.
    xb : array-like, shape (nb, d)
        Database vectors; same conversion as ``xq``.
    k : int
        Number of neighbours returned per query.
    distance_type : int
        ``faiss.METRIC_L2`` (uses ``faiss.knn_L2sqr``; a max-heap keeps the
        k smallest distances) or ``faiss.METRIC_INNER_PRODUCT`` (uses
        ``faiss.knn_inner_product``; a min-heap keeps the k largest scores).

    Returns
    -------
    D : np.ndarray, float32, shape (nq, k)
        Distances / similarities written by faiss.
    I : np.ndarray, int64, shape (nq, k)
        Database row indices of the neighbours. NOTE(review): when nb < k the
        unfilled slots presumably hold faiss's sentinel id (-1) - confirm
        against the faiss version in use.

    Raises
    ------
    ValueError
        If the query and database dimensions differ, or ``distance_type`` is
        not one of the two supported metrics (previously this case silently
        returned uninitialised arrays).
    """
    xq = np.ascontiguousarray(xq, dtype='float32')
    xb = np.ascontiguousarray(xb, dtype='float32')
    nq, d = xq.shape
    nb, d2 = xb.shape
    if d != d2:
        raise ValueError(
            f"dimension mismatch: queries have d={d}, database has d={d2}")

    # Pick the heap type and kernel for the metric before allocating anything,
    # so an unsupported metric fails loudly instead of returning np.empty junk.
    if distance_type == faiss.METRIC_L2:
        heaps = faiss.float_maxheap_array_t()
        knn_fn = faiss.knn_L2sqr
    elif distance_type == faiss.METRIC_INNER_PRODUCT:
        heaps = faiss.float_minheap_array_t()
        knn_fn = faiss.knn_inner_product
    else:
        raise ValueError(f"unsupported distance_type: {distance_type!r}")

    I = np.empty((nq, k), dtype='int64')
    D = np.empty((nq, k), dtype='float32')
    # The heap structure writes its results directly into D and I.
    heaps.k = k
    heaps.nh = nq
    heaps.val = faiss.swig_ptr(D)
    heaps.ids = faiss.swig_ptr(I)
    knn_fn(faiss.swig_ptr(xq), faiss.swig_ptr(xb), d, nq, nb, heaps)
    return D, I
def _knn_search(queries, data, k, return_neighbours=False, res=None):
    """Exact k-nearest-neighbour search under squared L2 distance.

    The backend is selected by ``res``:

    * ``res is None``: CPU search through ``search_knn`` (``faiss.knn_L2sqr``).
      ``queries`` and ``data`` are numpy arrays, and numpy arrays are returned.
    * otherwise: GPU brute-force search with ``faiss.bruteForceKnn`` using the
      faiss GPU resources object ``res``. ``queries`` and ``data`` must be
      float32 torch tensors that are row-major (contiguous) or column-major
      (a transposed view of a contiguous matrix). Outputs are allocated on
      ``queries.device``.

    Args:
        queries: (num_queries, dim) query matrix.
        data: (num_data, dim) database matrix.
        k: number of neighbours per query.
        return_neighbours: if True, also gather the neighbour vectors.
        res: faiss GPU resources, or None for the CPU path.

    Returns:
        ``(dists, idxs)`` of shape (num_queries, k) with float32 / int64
        values. If ``return_neighbours`` is True, returns
        ``(dists, idxs, neighbours)`` where neighbours has shape
        (num_queries, k, dim).

    Raises:
        ValueError: if ``data`` is not 2-D or its dimension differs from
            that of ``queries``.
        TypeError: (GPU path) if a tensor is not float32, or is neither
            row-major nor column-major.
    """
    num_queries, dim = queries.shape
    if data.ndim != 2 or data.shape[1] != dim:
        raise ValueError(
            f"data must have shape (num_data, {dim}), got {tuple(data.shape)}")

    if res is None:
        dists, idxs = search_knn(queries, data, k, faiss.METRIC_L2)
    else:
        data_ptr, data_row_major = _knn_search_torch_matrix(data, "data")
        query_ptr, query_row_major = _knn_search_torch_matrix(queries, "queries")
        dists = torch.empty(num_queries, k, dtype=torch.float32,
                            device=queries.device)
        idxs = torch.empty(num_queries, k, dtype=torch.int64,
                           device=queries.device)
        # NOTE(review): faiss runs this on the stream owned by `res`. If that
        # is not torch's current stream, callers may need to synchronise -
        # confirm for the faiss version in use.
        faiss.bruteForceKnn(
            res, faiss.METRIC_L2,
            data_ptr, data_row_major, data.shape[0],
            query_ptr, query_row_major, num_queries,
            dim, k,
            faiss.cast_integer_to_float_ptr(dists.data_ptr()),
            faiss.cast_integer_to_long_ptr(idxs.data_ptr()))

    if not return_neighbours:
        return dists, idxs
    # NOTE(review): if fewer than k results exist (e.g. num_data < k), faiss
    # presumably reports id -1. This gather would then silently pick the last
    # row of `data` - confirm before relying on neighbours in that case.
    neighbours = data[idxs.reshape(-1)].reshape(-1, k, dim)
    return dists, idxs, neighbours


def _knn_search_torch_matrix(x, name):
    """Return ``(float pointer, row_major)`` describing a 2-D float32 tensor for faiss."""
    if x.dtype != torch.float32:
        raise TypeError(f"{name} must be a float32 tensor, got {x.dtype}")
    if x.is_contiguous():
        row_major = True
    elif x.t().is_contiguous():
        # Transposed view of a contiguous matrix: same memory, column-major layout.
        row_major = False
    else:
        raise TypeError(
            f"{name} must be row-major or column-major, got strides {x.stride()}")
    # data_ptr() already includes the storage offset, so no manual byte math is needed.
    return faiss.cast_integer_to_float_ptr(x.data_ptr()), row_major