Beispiel #1
0
    def subtest(self, d, metric, topk, search_L, threshold):
        """Check IndexNNDescentFlat search recall against brute-force search.

        Generates a dataset with ``get_dataset_2``, computes the ground-truth
        ``topk`` neighbors of every query with ``faiss.IndexFlat``, builds an
        ``IndexNNDescentFlat`` over the same base vectors and measures the
        fraction of returned labels that also appear in the ground truth.

        Args:
            d: vector dimensionality.
            metric: faiss metric constant; must be a key of ``metric_names``
                (``METRIC_L1``, ``METRIC_L2`` or ``METRIC_INNER_PRODUCT``).
            topk: number of neighbors requested per query.
            search_L: value assigned to ``index.nndescent.search_L``.
            threshold: recall value the index must strictly exceed.

        Raises:
            AssertionError: if the measured recall is not above ``threshold``.
        """
        metric_names = {faiss.METRIC_L1: 'L1',
                        faiss.METRIC_L2: 'L2',
                        faiss.METRIC_INNER_PRODUCT: 'IP'}
        # NOTE: ``topk`` is taken from the caller; it used to be overwritten
        # with a hard-coded 10 here, which made the parameter meaningless.
        nt, nb, nq = 2000, 1000, 200
        _, xb, xq = get_dataset_2(d, nt, nb, nq)

        # Exact (brute-force) neighbors serve as the reference result.
        gt_index = faiss.IndexFlat(d, metric)
        gt_index.add(xb)
        _, gt_I = gt_index.search(xq, topk)

        K = 16
        index = faiss.IndexNNDescentFlat(d, K, metric)
        # Graph construction settings, fixed for every sub-test.
        index.nndescent.S = 10
        index.nndescent.R = 32
        index.nndescent.L = K + 20
        index.nndescent.iter = 5
        index.verbose = False

        # The setting under test.
        index.nndescent.search_L = search_L

        index.add(xb)
        _, I = index.search(xq, topk)

        # Count, per query, how many returned labels occur anywhere in that
        # query's ground-truth row (same count as a pairwise comparison).
        recalls = 0
        for found, expected in zip(I, gt_I):
            expected_ids = set(expected.tolist())
            recalls += sum(
                1 for label in found.tolist() if label in expected_ids)
        recall = 1.0 * recalls / (nq * topk)
        print('Metric: {}, L: {}, Recall@{}: {}'.format(
            metric_names[metric], search_L, topk, recall))
        assert recall > threshold, '{} <= {}'.format(recall, threshold)
Beispiel #2
0
    def subtest(self, d, K, metric):
        """Compare the NN-Descent graph with the brute-force k-NN graph.

        Builds an ``IndexNNDescentFlat`` over a dataset from
        ``get_dataset_2``, reads back ``index.nndescent.final_graph`` and
        asserts that more than 99% of its edges also occur in the
        neighbor lists computed by ``faiss.knn``.

        Args:
            d: vector dimensionality.
            K: number of neighbors per node in the graph.
            metric: faiss metric constant; must be a key of ``metric_names``.

        Raises:
            AssertionError: if the measured accuracy is not above 0.99.
        """
        metric_names = {
            faiss.METRIC_L1: 'L1',
            faiss.METRIC_L2: 'L2',
            faiss.METRIC_INNER_PRODUCT: 'IP',
        }

        nb = 1000
        _train, xb, _queries = get_dataset_2(d, 0, nb, 0)

        # Ask for K + 1 neighbors and discard the first column.
        # NOTE(review): presumably the first hit is the point itself -- confirm
        # this holds for every metric used by callers.
        _dist, exact_ids = faiss.knn(xb, xb, K + 1, metric)
        exact_ids = exact_ids[:, 1:]

        index = faiss.IndexNNDescentFlat(d, K, metric)
        builder = index.nndescent
        builder.S = 10
        builder.R = 32
        builder.L = K + 20
        builder.iter = 5
        index.verbose = True

        index.add(xb)
        # The graph is exposed as a flat vector; view it as one row per node.
        approx_ids = faiss.vector_to_array(
            index.nndescent.final_graph).reshape(nb, K)

        # An edge is a hit when its target appears anywhere in the exact
        # neighbor list of the same node.
        hits = 0
        for node in range(nb):
            wanted = set(exact_ids[node].tolist())
            for neighbor in approx_ids[node].tolist():
                if neighbor in wanted:
                    hits += 1
        recall = float(hits) / (nb * K)
        print('Metric: {}, knng accuracy: {}'.format(metric_names[metric], recall))
        assert recall > 0.99