예제 #1
0
파일: test_n2.py 프로젝트: RossSong/n2
 def test04_batch_search_by_vectors(self):
     """Check that batched vector search agrees with one-by-one search.

     Loads the index stored at ``self.model_fname``, generates 100 random
     Gaussian query vectors of length ``self.dim``, and asserts that
     ``batch_search_by_vectors`` returns exactly the same results as
     calling ``search_by_vector`` once per query (10 results each, with
     distances included).
     """
     hnsw = HnswIndex(self.dim)
     hnsw.load(self.model_fname)

     # Build the query set row by row: 100 vectors of self.dim samples.
     queries = []
     for _row in xrange(100):
         queries.append([random.gauss(0, 1) for _col in xrange(self.dim)])

     # Batched lookup first, then the per-vector reference results.
     from_batch = hnsw.batch_search_by_vectors(
         queries, 10, num_threads=12, include_distances=True)
     one_by_one = []
     for query in queries:
         one_by_one.append(
             hnsw.search_by_vector(query, 10, include_distances=True))

     self.assertEqual(from_batch, one_by_one)
예제 #2
0
class N2(BaseANN):
    """Benchmark wrapper that drives an ``HnswIndex`` through the BaseANN API.

    The wrapper builds the index (or loads it from a cache file when one
    already exists), answers single and batched queries, and reports a
    descriptive name via ``str()``.
    """

    def __init__(self, m, ef_construction, n_threads, ef_search, metric, batch):
        """Record the index parameters and derive the display/cache names.

        ``batch`` only influences the display name (a ``_batch`` suffix is
        appended when it is truthy); it is not stored on the instance.
        ``max_m0`` for the build is fixed at twice ``m``.
        """
        batch_suffix = '_batch' if batch else ''
        self.name = "N2_M%d_efCon%d_n_thread%s_efSearch%d%s" % (
            m, ef_construction, n_threads, ef_search, batch_suffix)

        self._metric = metric
        self._ef_search = ef_search
        self._n_threads = n_threads
        self._ef_construction = ef_construction
        self._m = m
        self._m0 = m * 2

        # `args` and `CACHE_DIR` are looked up in the enclosing module; the
        # cache file name depends on the dataset and the build parameters.
        cache_file = "index_n2_%s_M%d_efCon%d_n_thread%s" % (
            args.dataset, m, ef_construction, n_threads)
        self._index_name = os.path.join(CACHE_DIR, cache_file)

    def fit(self, X):
        """Create the index for ``X``, reusing the cached file if present.

        The index dimension is taken from ``X.shape[1]``. When the cache
        file exists it is loaded (without mmap) and nothing is built;
        otherwise every row of ``X`` is added, the index is built with the
        stored parameters, and the result is saved to the cache path.
        """
        # Map the metric name onto the optional second constructor argument;
        # any other metric falls back to the constructor's default.
        if self._metric == 'euclidean':
            metric_args = ('L2',)
        elif self._metric == 'dot':
            metric_args = ('dot',)
        else:
            metric_args = ()
        self._n2 = HnswIndex(X.shape[1], *metric_args)

        if os.path.exists(self._index_name):
            n2_logger.info("Loading index from file")
            self._n2.load(self._index_name, use_mmap=False)
        else:
            n2_logger.info("Create Index")
            for row in X:
                self._n2.add_data(row)
            self._n2.build(
                m=self._m,
                max_m0=self._m0,
                ef_construction=self._ef_construction,
                n_threads=self._n_threads,
            )
            self._n2.save(self._index_name)

    def query(self, v, n):
        """Return the index's ``n`` search results for the single vector ``v``."""
        found = self._n2.search_by_vector(v, n, self._ef_search)
        return found

    def batch_query(self, X, n):
        """Search for all vectors in ``X`` at once and keep the results.

        Nothing is returned; the results are stored on ``self.b_res`` and
        retrieved through ``get_batch_results``.
        """
        index = self._n2
        self.b_res = index.batch_search_by_vectors(
            X, n, self._ef_search, self._n_threads)

    def get_batch_results(self):
        """Return the results stored by the most recent ``batch_query`` call."""
        return self.b_res

    def __str__(self):
        """Return the descriptive name assembled in ``__init__``."""
        return self.name