def test04_batch_search_by_vectors(self):
    """Check that batch search agrees with one-at-a-time search.

    Loads the index stored at ``self.model_fname``, draws 100 random
    Gaussian query vectors of length ``self.dim``, and asserts that
    ``batch_search_by_vectors`` (12 threads) returns exactly the same
    top-10 results, distances included, as calling ``search_by_vector``
    once per query.
    """
    index = HnswIndex(self.dim)
    index.load(self.model_fname)

    # ``range`` instead of ``xrange``: ``xrange`` does not exist on
    # Python 3, while ``range`` works on both major versions and the
    # sizes involved here are tiny.
    queries = [
        [random.gauss(0, 1) for _ in range(self.dim)]
        for _ in range(100)
    ]

    batch_res = index.batch_search_by_vectors(
        queries, 10, num_threads=12, include_distances=True)
    normal_res = [
        index.search_by_vector(query, 10, include_distances=True)
        for query in queries
    ]
    self.assertEqual(batch_res, normal_res)
class N2(BaseANN):
    """``BaseANN`` implementation backed by an ``HnswIndex``.

    The built index is cached on disk under ``CACHE_DIR``; a later ``fit``
    with the same dataset / build parameters loads that file instead of
    rebuilding.
    """

    def __init__(self, m, ef_construction, n_threads, ef_search, metric, batch):
        """Record build/search parameters and derive the name and cache path.

        Args:
            m: Forwarded to ``HnswIndex.build`` as ``m``; ``max_m0`` is
                set to twice this value.
            ef_construction: Forwarded to ``HnswIndex.build``.
            n_threads: Thread count used for building and for batch search.
            ef_search: Forwarded to the search calls.
            metric: ``'euclidean'`` selects ``'L2'``, ``'dot'`` selects
                ``'dot'``; any other value uses the ``HnswIndex`` default.
            batch: Only affects ``self.name`` (adds a ``_batch`` suffix
                when truthy).
        """
        batch_suffix = '_batch' if batch else ''
        self.name = "N2_M%d_efCon%d_n_thread%s_efSearch%d%s" % (
            m, ef_construction, n_threads, ef_search, batch_suffix)

        self._metric = metric
        self._m = m
        self._m0 = m * 2
        self._ef_construction = ef_construction
        self._ef_search = ef_search
        self._n_threads = n_threads

        # NOTE(review): the cache file name encodes dataset, m,
        # ef_construction and n_threads but not the metric, so an index
        # built with another metric for the same dataset would be reused.
        # Confirm that each dataset is only ever run with a single metric.
        cache_file = "index_n2_%s_M%d_efCon%d_n_thread%s" % (
            args.dataset, m, ef_construction, n_threads)
        self._index_name = os.path.join(CACHE_DIR, cache_file)

    def fit(self, X):
        """Create the index for ``X``, loading it from cache when available.

        ``X.shape[1]`` is used as the index dimension. If the cache file
        already exists it is loaded (without mmap) and nothing is built;
        otherwise every row of ``X`` is added, the index is built with the
        stored parameters, and the result is saved to the cache path.
        """
        dim = X.shape[1]
        if self._metric == 'euclidean':
            index_args = (dim, 'L2')
        elif self._metric == 'dot':
            index_args = (dim, 'dot')
        else:
            # Any other metric falls back to the HnswIndex default.
            index_args = (dim,)
        index = HnswIndex(*index_args)
        self._n2 = index

        if os.path.exists(self._index_name):
            n2_logger.info("Loading index from file")
            index.load(self._index_name, use_mmap=False)
            return

        n2_logger.info("Create Index")
        for row in X:
            index.add_data(row)
        index.build(
            m=self._m,
            max_m0=self._m0,
            ef_construction=self._ef_construction,
            n_threads=self._n_threads,
        )
        index.save(self._index_name)

    def query(self, v, n):
        """Return the result of searching the index for ``n`` neighbours of ``v``."""
        neighbours = self._n2.search_by_vector(v, n, self._ef_search)
        return neighbours

    def batch_query(self, X, n):
        """Search for ``n`` neighbours of every vector in ``X``.

        The result is stored on the instance rather than returned; read it
        back with ``get_batch_results``.
        """
        self.b_res = self._n2.batch_search_by_vectors(
            X, n, self._ef_search, self._n_threads)

    def get_batch_results(self):
        """Return the result stored by the most recent ``batch_query`` call."""
        return self.b_res

    def __str__(self):
        """Return the descriptive name assembled in ``__init__``."""
        return self.name