Beispiel #1
0
    def test_shards(self):
        """IndexBinaryShards must return the same results as one flat index.

        Checked twice: once with shards filled by hand with contiguous slices
        of the database, once with empty shards filled through the shards
        container's own ``add``.
        """
        dim = 32
        n_query = 100
        n_base = 200
        n_shards = 5
        k = 10

        (_, xb, xq) = make_binary_dataset(dim, 0, n_base, n_query)

        flat = faiss.IndexBinaryFlat(dim)
        flat.add(xb)
        Dref, Iref = flat.search(xq, k)

        # Split the database into contiguous slices, one per shard.
        sharded = faiss.IndexBinaryShards(dim)
        for s in range(n_shards):
            lo = s * n_base // n_shards
            hi = (s + 1) * n_base // n_shards
            shard = faiss.IndexBinaryFlat(dim)
            shard.add(xb[lo:hi])
            sharded.add_shard(shard)

        D, I = sharded.search(xq, k)
        compare_binary_result_lists(Dref, Iref, D, I)

        # Empty shards: the container dispatches the added vectors itself.
        auto_sharded = faiss.IndexBinaryShards(dim)
        for _ in range(n_shards):
            shard = faiss.IndexBinaryFlat(dim)
            auto_sharded.add_shard(shard)
        auto_sharded.add(xb)

        D2, I2 = auto_sharded.search(xq, k)
        compare_binary_result_lists(Dref, Iref, D2, I2)
Beispiel #2
0
    def test_replicas(self):
        """IndexBinaryReplicas returns exactly the single-index results.

        Checked with replicas that already hold the database and with empty
        replicas filled through the container's own ``add``.
        """
        d = 32
        nq = 100
        nb = 200
        k = 10
        nrep = 5

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)
        Dref, Iref = index_ref.search(xq, k)

        def assert_same(D, I):
            self.assertTrue((Dref == D).all())
            self.assertTrue((Iref == I).all())

        # Replicas that each already hold the full database.
        filled = faiss.IndexBinaryReplicas()
        for _ in range(nrep):
            replica = faiss.IndexBinaryFlat(d)
            replica.add(xb)
            filled.addIndex(replica)
        assert_same(*filled.search(xq, k))

        # Empty replicas, populated through the container's add().
        empty = faiss.IndexBinaryReplicas()
        for _ in range(nrep):
            replica = faiss.IndexBinaryFlat(d)
            empty.addIndex(replica)
        empty.add(xb)
        assert_same(*empty.search(xq, k))
    def test_wrapped_quantizer(self):
        """IVF with a wrapped float quantizer matches IVF with a binary one.

        ``IndexBinaryFromFloat(IndexFlatL2)`` is used as the coarse quantizer
        of one IndexBinaryIVF, ``IndexBinaryFlat`` for the other; their top-10
        search distances must be identical.
        """
        d = 256
        nq = 500
        nlist = 16
        k = 10
        # Named for the sizes make_binary_dataset actually produces, assuming
        # the positional order (d, nt, nb, nq) used at the other call sites
        # in this file: 1500 training vectors and 150 database vectors.
        n_train = 1500
        n_base = 150
        (xt, xb, xq) = make_binary_dataset(d, n_train, n_base, nq)

        ref_quantizer = faiss.IndexBinaryFlat(d)
        index_ref = faiss.IndexBinaryIVF(ref_quantizer, d, nlist)
        index_ref.train(xt)
        index_ref.add(xb)

        float_quantizer = faiss.IndexFlatL2(d)
        wrapped_quantizer = faiss.IndexBinaryFromFloat(float_quantizer)
        index = faiss.IndexBinaryIVF(wrapped_quantizer, d, nlist)
        index.train(xt)
        index.add(xb)

        D_ref, _ = index_ref.search(xq, k)
        D, _ = index.search(xq, k)

        np.testing.assert_array_equal(D_ref, D)
Beispiel #4
0
def search(bit_vectors_test):
    """Return the ids of the ``k`` nearest database vectors for each query.

    Reads module-level settings defined elsewhere: ``db_bit_vectors`` (the
    database), ``bitsize`` (bits per vector), ``bank_size`` / ``q_size``
    (how many database / query vectors to use) and ``k`` (neighbors per
    query).

    :param bit_vectors_test: query vectors, presumably packed uint8 rows of
        ``ceil(bitsize / 8)`` bytes each -- TODO confirm against the caller.
    :return: ``I``, the neighbor-id array from ``IndexBinaryFlat.search``.
    """
    # Binary index dimensions are in bits and must cover whole bytes, so
    # round bitsize up to the next multiple of 8.
    d = int(math.ceil(bitsize / 8.0)) * 8

    db = db_bit_vectors[:bank_size]
    queries = bit_vectors_test[:q_size]

    index = faiss.IndexBinaryFlat(d)
    index.add(db)
    _, I = index.search(queries, k)
    print('Search Complete!')
    return I
    def test_wrapped_quantizer_IMI(self):
        """Binary IVF over a wrapped MultiIndexQuantizer reaches good recall.

        A float ``MultiIndexQuantizer(d, 2, 6)`` (2**(2*6) cells) is wrapped
        with ``IndexBinaryFromFloat``, trained, and used as the coarse
        quantizer of an ``IndexBinaryIVF`` searched with a large nprobe.
        """
        d = 256
        nt = 3500
        nb = 10000
        nq = 500
        # NOTE(review): passed as (d, nb, nt, nq) while other call sites in
        # this file use (d, nt, nb, nq), so xt presumably gets nb rows and xb
        # nt rows. Kept as-is because the recall threshold is tuned to it.
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)

        nlist_exp = 6
        expected_cells = 2 ** (2 * nlist_exp)
        float_quantizer = faiss.MultiIndexQuantizer(d, 2, nlist_exp)
        wrapped_quantizer = faiss.IndexBinaryFromFloat(float_quantizer)
        wrapped_quantizer.train(xt)
        assert expected_cells == float_quantizer.ntotal

        index = faiss.IndexBinaryIVF(wrapped_quantizer, d, float_quantizer.ntotal)
        index.nprobe = 2048
        assert index.is_trained
        index.add(xb)

        D_ref, _ = index_ref.search(xq, 10)
        D, _ = index.search(xq, 10)

        # Distance-based recall: a query is a hit when the exact nearest
        # distance appears among the IVF's top-10 distances. Comparing
        # distances rather than ids makes ties between ids irrelevant.
        hits = 0
        for exact_row, ivf_row in zip(D_ref, D):
            if exact_row[0] in ivf_row[:10]:
                hits += 1
        recall = hits / float(D_ref.shape[0])

        assert recall > 0.82, "recall = %g" % recall
Beispiel #6
0
    def test_binary_flat(self):
        """GpuIndexBinaryFlat must agree with the CPU IndexBinaryFlat.

        Per query, both must have the same largest returned distance. Entries
        at that distance are then dropped before comparing. Presumably this
        is because ties at the cut-off can be resolved to different ids.
        """
        k = 10

        cpu_index = faiss.IndexBinaryFlat(self.d_bin)
        cpu_index.add(self.xb_bin)
        D_ref, I_ref = cpu_index.search(self.xq_bin, k)

        gpu_index = faiss.GpuIndexBinaryFlat(faiss.StandardGpuResources(),
                                             self.d_bin)
        gpu_index.add(self.xb_bin)
        D, I = gpu_index.search(self.xq_bin, k)

        for ref_dis, ref_ids, new_dis, new_ids in zip(D_ref, I_ref, D, I):
            # exclude max distance
            assert ref_dis.max() == new_dis.max()
            cutoff = ref_dis.max()

            # sort by (distance, id) pairs to be reproducible
            ref = sorted((dist, idx) for dist, idx in zip(ref_dis, ref_ids)
                         if dist < cutoff)
            new = sorted((dist, idx) for dist, idx in zip(new_dis, new_ids)
                         if dist < cutoff)

            assert ref == new
Beispiel #7
0
    def test_ivf_reconstruction(self):
        """reconstruct() on IndexBinaryIVF works with both direct-map types.

        First with ``DirectMap.Array`` and sequential ids, then with
        ``DirectMap.Hashtable`` over random user-supplied ids.
        """
        d = self.xq.shape[1] * 8
        quantizer = faiss.IndexBinaryFlat(d)
        sample = range(0, len(self.xb), 13)

        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        index.set_direct_map_type(faiss.DirectMap.Array)

        for i in sample:
            np.testing.assert_array_equal(index.reconstruct(i), self.xb[i])

        # try w/ hashtable: this second index reuses the quantizer trained
        # above and is filled without its own train() call.
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        rng = np.random.RandomState(123)
        ids = rng.choice(10000, size=len(self.xb), replace=False).astype(np.int64)
        index.add_with_ids(self.xb, ids)
        index.set_direct_map_type(faiss.DirectMap.Hashtable)

        for i in sample:
            np.testing.assert_array_equal(index.reconstruct(int(ids[i])),
                                          self.xb[i])
Beispiel #8
0
    def test_ivf_flat(self):
        """An IndexBinaryIVF gives identical results after a write/read cycle."""
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5  # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp returns an already-open OS-level descriptor. Close it right
        # away so it is not leaked, and so the file can be rewritten and
        # removed on platforms that lock open files.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)

        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)
Beispiel #9
0
    def test_remove_id_map_binary(self):
        """remove_ids on IndexBinaryIDMap2 makes reconstruct() fail for that id.

        Other ids stay reconstructible, and the removal persists across a
        write/read round trip.
        """
        sub_index = faiss.IndexBinaryFlat(40)
        xb = np.zeros((10, 5), dtype='uint8')
        # Tag each vector through its first byte so a reconstructed vector
        # can be matched to its id: id 1000 + i  ->  first byte 100 + i.
        xb[:, 0] = np.arange(10) + 100
        index = faiss.IndexBinaryIDMap2(sub_index)
        index.add_with_ids(xb, np.arange(10) + 1000)
        assert index.reconstruct(1004)[0] == 104
        index.remove_ids(np.array([1003]))
        assert index.reconstruct(1004)[0] == 104
        # faiss surfaces C++ exceptions as RuntimeError. Unlike the previous
        # bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
        with self.assertRaises(RuntimeError):
            index.reconstruct(1003)

        # while we are there, let's test I/O as well...
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)  # mkstemp returns an open descriptor; don't leak it
        try:
            faiss.write_index_binary(index, tmpnam)
            index = faiss.read_index_binary(tmpnam)
        finally:
            os.remove(tmpnam)

        assert index.reconstruct(1004)[0] == 104
        with self.assertRaises(RuntimeError):
            index.reconstruct(1003)
    def test_wrapped_quantizer_HNSW(self):
        """Binary IVF with a wrapped IndexHNSWFlat quantizer reaches good recall.

        Float k-means centroids are learned on a +/-1 expansion of the binary
        training vectors. They are stored in an IndexHNSWFlat and exposed to
        IndexBinaryIVF through IndexBinaryFromFloat.
        """
        def bin2float2d(v):
            # Expand every byte into 8 floats in {-1.0, +1.0}, least
            # significant bit first: (n, nbytes) uint8 -> (n, 8 * nbytes).
            n, nbytes = v.shape
            bits = ((v.reshape(-1, 1) >> np.arange(8)) & 1).astype('float32')
            return (bits * 2 - 1).reshape(n, nbytes * 8)

        # Run single-threaded, but restore the caller's thread count in
        # `finally` so the setting does not leak into other tests.
        nthreads = faiss.omp_get_max_threads()
        faiss.omp_set_num_threads(1)
        try:
            d = 256
            nt = 12800
            nb = 10000
            nq = 500
            # NOTE(review): passed as (d, nb, nt, nq) while other call sites
            # in this file use (d, nt, nb, nq), so xt presumably gets nb rows
            # and xb nt rows. Kept because the recall threshold is tuned to it.
            (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

            index_ref = faiss.IndexBinaryFlat(d)

            index_ref.add(xb)

            nlist = 256
            clus = faiss.Clustering(d, nlist)
            clus_index = faiss.IndexFlatL2(d)

            # Vectorized numpy conversion instead of a per-byte Python loop.
            xt_f = bin2float2d(xt)
            clus.train(xt_f, clus_index)

            centroids = faiss.vector_to_array(clus.centroids).reshape(-1, d)
            hnsw_quantizer = faiss.IndexHNSWFlat(d, 32)
            hnsw_quantizer.add(centroids)
            # Centroids come from the k-means above; mark the index trained.
            hnsw_quantizer.is_trained = True
            wrapped_quantizer = faiss.IndexBinaryFromFloat(hnsw_quantizer)

            assert nlist == hnsw_quantizer.ntotal
            assert nlist == wrapped_quantizer.ntotal
            assert wrapped_quantizer.is_trained

            index = faiss.IndexBinaryIVF(wrapped_quantizer, d,
                                         hnsw_quantizer.ntotal)
            index.nprobe = 128

            assert index.is_trained

            index.add(xb)

            D_ref, _ = index_ref.search(xq, 10)
            D, _ = index.search(xq, 10)
        finally:
            faiss.omp_set_num_threads(nthreads)

        # Distance-based recall: a hit when the exact nearest distance shows
        # up among the IVF's top-10 distances.
        recall = sum(gti[0] in Di[:10] for gti, Di in zip(D_ref, D)) \
                 / float(D_ref.shape[0])

        assert recall > 0.77, "recall = %g" % recall
Beispiel #11
0
    def test_replicas(self):
        """IndexBinaryReplicas returns exactly the single-index results.

        Checked with pre-filled replicas and with empty replicas filled
        through the container's own ``add``.
        """
        d = 32
        nq = 100
        nb = 200

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)

        Dref, Iref = index_ref.search(xq, 10)

        # there is an OpenMP bug in this configuration, so disable threading
        if sys.platform == "darwin" and "Clang 12" in sys.version:
            nthreads = faiss.omp_get_max_threads()
            faiss.omp_set_num_threads(1)
        else:
            nthreads = None

        # Restore the thread count in `finally`, so that a failing assertion
        # does not leave OpenMP single-threaded for the tests that follow.
        try:
            nrep = 5
            index = faiss.IndexBinaryReplicas()
            for _i in range(nrep):
                sub_idx = faiss.IndexBinaryFlat(d)
                sub_idx.add(xb)
                index.addIndex(sub_idx)

            D, I = index.search(xq, 10)

            self.assertTrue((Dref == D).all())
            self.assertTrue((Iref == I).all())

            index2 = faiss.IndexBinaryReplicas()
            for _i in range(nrep):
                sub_idx = faiss.IndexBinaryFlat(d)
                index2.addIndex(sub_idx)

            index2.add(xb)
            D2, I2 = index2.search(xq, 10)
        finally:
            if nthreads is not None:
                faiss.omp_set_num_threads(nthreads)

        self.assertTrue((Dref == D2).all())
        self.assertTrue((Iref == I2).all())
Beispiel #12
0
    def test_empty_flat(self):
        """Searching an empty IndexBinaryFlat yields only sentinel results.

        With ``use_heap`` set to both True and False, every returned id must
        be -1 and every distance must be int32 max.
        """
        index = faiss.IndexBinaryFlat(self.xq.shape[1] * 8)

        for heap_flag in (True, False):
            index.use_heap = heap_flag
            Dflat, Iflat = index.search(self.xq, 10)

            assert np.all(Iflat == -1)
            assert np.all(Dflat == 2147483647)  # NOTE(hoss): int32_t max
Beispiel #13
0
    def __init__(self, *args, **kwargs):
        """Build the shared binary dataset and exhaustive reference distances.

        Sets ``self.xt``, ``self.xb`` and ``self.xq`` from
        ``make_binary_dataset`` (d=32, nt=200, nb=1500, nq=500). Also sets
        ``self.Dref``, the top-10 distances of an IndexBinaryFlat search of
        the queries against the database.
        """
        unittest.TestCase.__init__(self, *args, **kwargs)
        d, nt, nb, nq = 32, 200, 1500, 500

        self.xt, self.xb, self.xq = make_binary_dataset(d, nt, nb, nq)

        flat = faiss.IndexBinaryFlat(d)
        flat.add(self.xb)
        self.Dref, _ = flat.search(self.xq, 10)
Beispiel #14
0
    def test_ivf_flat_exhaustive(self):
        """With nprobe equal to nlist, binary IVF matches exhaustive search."""
        d = self.xq.shape[1] * 8
        nlist = 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, nlist)
        index.cp.min_points_per_centroid = 5  # quiet warning
        index.nprobe = nlist  # visit every inverted list
        index.train(self.xt)
        index.add(self.xb)

        Divfflat, _ = index.search(self.xq, 10)
        np.testing.assert_array_equal(self.Dref, Divfflat)
Beispiel #15
0
    def test_ivf_flat2(self):
        """Binary IVF with nprobe=4 of 8 lists matches most exact distances."""
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5  # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        Divfflat, _ = index.search(self.xq, 10)

        # Some centroids are equidistant from the query points, so the exact
        # number of matching distances depends on the heap implementation.
        # Require a lower bound rather than pinning one exact count.
        self.assertGreater((self.Dref == Divfflat).sum(), 4100)
Beispiel #16
0
    def test_ivf_flat_empty(self):
        """Searching a trained but empty binary IVF yields only sentinels.

        With ``use_heap`` set to both True and False, ids must all be -1 and
        distances all int32 max.
        """
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryIVF(faiss.IndexBinaryFlat(d), d, 8)
        index.train(self.xt)

        for heap_flag in (True, False):
            index.use_heap = heap_flag
            Divfflat, Iivfflat = index.search(self.xq, 10)

            assert np.all(Iivfflat == -1)
            assert np.all(Divfflat == 2147483647)  # NOTE(hoss): int32_t max
Beispiel #17
0
    def test_flat(self):
        """IndexBinaryFlat top-3 distances equal the ``binary_dis`` reference."""
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        for q, (row_ids, row_dis) in enumerate(zip(I, D)):
            for j, dj in zip(row_ids, row_dis):
                assert dj == binary_dis(self.xq[q], self.xb[j])
Beispiel #18
0
def train():
    """Train a binary IVF index on all stored data and write it to disk.

    About sqrt(N) inverted lists are used for N training rows. The trained
    index (no vectors are added) is saved to ``./trained_import.index``.
    Exits the process when ``get_all_data()`` returns nothing.
    """
    import sys

    all_data = np.array(get_all_data())
    if len(all_data) == 0:
        print("No images. exit()")
        # sys.exit always exists; the site-provided exit() is meant for the
        # interactive shell and may be missing (e.g. under `python -S`).
        sys.exit()
    d = 32 * 8  # presumably 32-byte (256-bit) hashes -- TODO confirm
    centroids = round(sqrt(all_data.shape[0]))
    print(f'centroids: {centroids}')
    quantizer = faiss.IndexBinaryFlat(d)
    index = faiss.IndexBinaryIVF(quantizer, d, centroids)
    index.nprobe = 8
    index.train(all_data)
    faiss.write_index_binary(index, "./" + "trained_import.index")
Beispiel #19
0
    def build(
        cls,
        passage_ids: List[int],
        passage_embeddings: np.ndarray,
        index: Optional[faiss.Index] = None,
        buffer_size: int = 50000,
    ):
        """Create an instance whose binary index holds ``passage_embeddings``.

        :param passage_ids: ids presumably aligned with the rows of
            ``passage_embeddings``. Its length bounds the rows added.
        :param passage_embeddings: packed binary codes, one row per passage.
            The bit dimension is taken as ``shape[1] * 8``.
        :param index: index to fill. A new ``IndexBinaryFlat`` is created
            when omitted.
        :param buffer_size: number of rows passed to each ``index.add`` call.
        :return: ``cls(index, passage_ids, passage_embeddings)``.
        """
        if index is None:
            n_bits = passage_embeddings.shape[1] * 8
            index = faiss.IndexBinaryFlat(n_bits)

        n_passages = len(passage_ids)
        for offset in trange(0, n_passages, buffer_size):
            chunk = passage_embeddings[offset:offset + buffer_size]
            index.add(chunk)

        return cls(index, passage_ids, passage_embeddings)
    def test_wrapped_quantizer_HNSW(self):
        """Binary IVF with a wrapped IndexHNSWFlat quantizer reaches good recall.

        Float k-means centroids (on a +/-1 expansion of the binary training
        set) are stored in IndexHNSWFlat and wrapped with IndexBinaryFromFloat.
        """
        def bin2float2d(v):
            # Unpack each byte least-significant bit first into 8 floats in
            # {-1, +1}: (n_rows, n_bytes) -> (n_rows, 8 * n_bytes).
            n_rows, n_bytes = v.shape
            bits = ((v.reshape(-1, 1) >> np.arange(8)) & 1).astype("float32")
            signs = bits * 2 - 1
            return signs.reshape(n_rows, n_bytes * 8)

        d = 256
        nt = 12800
        nb = 10000
        nq = 500
        # NOTE(review): passed as (d, nb, nt, nq) while other call sites in
        # this file use (d, nt, nb, nq); kept as the recall bound is tuned to it.
        (xt, xb, xq) = make_binary_dataset(d, nb, nt, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)

        nlist = 256
        clus = faiss.Clustering(d, nlist)
        clus_index = faiss.IndexFlatL2(d)
        xt_f = bin2float2d(xt)
        clus.train(xt_f, clus_index)

        centroids = faiss.vector_to_array(clus.centroids).reshape(-1, clus.d)
        hnsw_quantizer = faiss.IndexHNSWFlat(d, 32)
        hnsw_quantizer.add(centroids)
        # The centroids come from the k-means above: flag as trained.
        hnsw_quantizer.is_trained = True
        wrapped_quantizer = faiss.IndexBinaryFromFloat(hnsw_quantizer)

        assert nlist == hnsw_quantizer.ntotal
        assert nlist == wrapped_quantizer.ntotal
        assert wrapped_quantizer.is_trained

        index = faiss.IndexBinaryIVF(wrapped_quantizer, d, hnsw_quantizer.ntotal)
        index.nprobe = 128
        assert index.is_trained
        index.add(xb)

        D_ref, _ = index_ref.search(xq, 10)
        D, _ = index.search(xq, 10)

        # Distance-based recall: hit when the exact nearest distance is among
        # the IVF's top-10 distances.
        n_hits = sum(1 for ref_row, row in zip(D_ref, D) if ref_row[0] in row[:10])
        recall = n_hits / float(D_ref.shape[0])

        assert recall >= 0.77, "recall = %g" % recall
Beispiel #21
0
    def test_ivf_flat2(self):
        """Binary IVF with nprobe=4 of 8 lists matches most exact distances."""
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        Divfflat, _ = index.search(self.xq, 10)

        # Some centroids are equidistant from the query points.
        # So the answer will depend on the implementation of the heap.
        n_matching = (self.Dref == Divfflat).sum()
        self.assertGreater(n_matching, 4100)
Beispiel #22
0
 def build_index(cls, feature_file, index_file):
     '''Rebuild a flat binary index from a saved feature array and save it.

     :params feature_file: a npy file generated by using utils.build_mol_features
     :params index_file: path the built index is written to
     :returns: the populated ``faiss.IndexBinaryFlat``
     '''
     logging.info("rebuild index from {}".format(feature_file))
     fp_arr = np.load(feature_file)
     packed_rows = [cls.vec2bytes(row) for row in tqdm.tqdm(fp_arr)]
     # Index dimension in bits: the row length rounded up to whole bytes.
     n_bits = int(np.ceil(fp_arr.shape[1] / 8) * 8)
     index = faiss.IndexBinaryFlat(n_bits)
     index.add(np.array(packed_rows).astype("uint8"))
     faiss.write_index_binary(index, index_file)
     return index
Beispiel #23
0
    def test_read_index_ownership(self):
        """An index returned by read_index_binary is owned by its Python proxy."""
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)

        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)  # mkstemp returns an open descriptor; don't leak it
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            assert index2.thisown
        finally:
            os.remove(tmpnam)
Beispiel #24
0
def init_index():
    """Load or create the global binary index, then add all stored records.

    Tries ``trained.index`` first. If faiss cannot read it, falls back to a
    256-bit IndexBinaryIVF with a single list, trained on one all-zero
    vector. The first two fields of each record from ``get_all_data()`` are
    then added as (id, vector).
    """
    global index
    try:
        index = faiss.read_index_binary("trained.index")
    except RuntimeError:
        # faiss reports read failures (e.g. a missing file) as RuntimeError.
        # Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit
        # propagate.
        d = 32 * 8
        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 1)
        index.nprobe = 1
        # One list needs only one training point; an all-zero 32-byte
        # vector is enough to mark the index as trained.
        index.train(np.array([np.zeros(32)], dtype=np.uint8))
    all_data = get_all_data()
    if len(all_data) != 0:
        image_ids = np.array([np.int64(x[0]) for x in all_data])
        phashes = np.array([x[1] for x in all_data])
        index.add_with_ids(phashes, image_ids)
    print("Index is ready")
Beispiel #25
0
    def test_hash_and_multihash(self):
        """Sanity checks for IndexBinaryHash / IndexBinaryMultiHash k-NN search.

        For nh hash tables (0 meaning a single IndexBinaryHash) and 4 or 7
        bits per hash, the test checks that:
        - results contain no duplicate ids;
        - 4 bits find more true neighbors than 7 bits;
        - more tables find more neighbors;
        - the last index built round-trips through serialization.
        """
        d = 128
        nq = 100
        nb = 2000
        k = 10

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)
        Dref, Iref = index_ref.search(xq, k)

        nfound = {}
        for nh in (0, 1, 3, 5):

            for nbit in (4, 7):
                if nh == 0:
                    index = faiss.IndexBinaryHash(d, nbit)
                else:
                    index = faiss.IndexBinaryMultiHash(d, nh, nbit)
                index.add(xb)
                index.nflip = 2
                Dnew, Inew = index.search(xq, k)

                nf = 0
                for ref_row, new_row in zip(Iref, Inew):
                    new_set = set(new_row)
                    # no duplicates
                    self.assertTrue(len(new_row) == len(new_set))
                    nf += len(set(ref_row) & new_set)
                print('nfound', nh, nbit, nf)
                nfound[(nh, nbit)] = nf
            self.assertGreater(nfound[(nh, 4)], nfound[(nh, 7)])

            # test serialization (of the nbit=7 index built last)
            blob = faiss.serialize_index_binary(index)
            index2 = faiss.deserialize_index_binary(blob)

            D2, I2 = index2.search(xq, k)
            np.testing.assert_array_equal(Inew, I2)
            np.testing.assert_array_equal(Dnew, D2)

        print('nfound=', nfound)
        self.assertGreater(3, abs(nfound[(0, 7)] - nfound[(1, 7)]))
        self.assertGreater(nfound[(3, 7)], nfound[(1, 7)])
        self.assertGreater(nfound[(5, 7)], nfound[(3, 7)])
Beispiel #26
0
def train():
    """Train a binary IVF index on all imported descriptors and save it.

    The second field of every record from ``import_get_all_data()`` is
    stacked row-wise into one training matrix. About sqrt(N) inverted lists
    are used for N rows. The trained index (no vectors added) is written to
    ``./trained_import.index``. Exits the process when there is no data.
    """
    import sys

    all_data = import_get_all_data()
    if len(all_data) == 0:
        print("No images. exit()")
        # sys.exit always exists; the site-provided exit() is meant for the
        # interactive shell and may be missing (e.g. under `python -S`).
        sys.exit()
    all_descriptors = np.concatenate([x[1] for x in all_data], axis=0)

    d = 61 * 8  # presumably 61-byte (488-bit) descriptors -- TODO confirm
    centroids = round(sqrt(all_descriptors.shape[0]))
    print(f'centroids: {centroids}')
    quantizer = faiss.IndexBinaryFlat(d)
    index = faiss.IndexBinaryIVF(quantizer, d, centroids)
    index.nprobe = 8
    index.train(all_descriptors)
    faiss.write_index_binary(index, "./" + "trained_import.index")
Beispiel #27
0
def train():
    """Train a binary IVF index on the AKAZE features of all images and save it.

    Features of every id from ``get_all_ids()`` are converted with
    ``convert_array`` and stacked row-wise. About sqrt(N) inverted lists are
    used for N rows. The trained index (no vectors added) is written to
    ``./trained.index``. Exits the process when there are no ids.
    """
    import sys

    all_ids = get_all_ids()
    if len(all_ids) == 0:
        print("No images. exit()")
        # sys.exit always exists; the site-provided exit() is meant for the
        # interactive shell and may be missing (e.g. under `python -S`).
        sys.exit()
    # `image_id`, not `id`, so the builtin is not shadowed.
    all_descriptors = np.concatenate(
        [convert_array(get_akaze_features_by_id(image_id)) for image_id in all_ids],
        axis=0,
    )

    d = 61 * 8  # presumably 61-byte (488-bit) AKAZE descriptors -- TODO confirm
    centroids = round(sqrt(all_descriptors.shape[0]))
    print(f'centroids: {centroids}')
    quantizer = faiss.IndexBinaryFlat(d)
    index = faiss.IndexBinaryIVF(quantizer, d, centroids)
    index.nprobe = 8
    index.train(all_descriptors)
    faiss.write_index_binary(index, "./" + "trained.index")
Beispiel #28
0
    def test_multihash(self):
        """Range search on IndexBinaryMultiHash is a subset of the exact result.

        For 1, 3 and 5 hash tables, the results must be duplicate-free and
        contained in the exact IndexBinaryFlat range-search result. The
        number of results found must not decrease as tables are added.
        """
        d = 128
        nq = 100
        nb = 2000
        radius = 55

        (_, xb, xq) = make_binary_dataset(d, 0, nb, nq)

        index_ref = faiss.IndexBinaryFlat(d)
        index_ref.add(xb)
        Lref, Dref, Iref = index_ref.range_search(xq, radius)

        print("nb res: ", Lref[-1])

        nfound = []
        ndis = []

        for nh in (1, 3, 5):
            index = faiss.IndexBinaryMultiHash(d, nh, 10)
            index.add(xb)
            stats = faiss.cvar.indexBinaryHash_stats
            index.nflip = 2
            # Reset right before searching so ndis reflects the range search.
            stats.reset()
            Lnew, Dnew, Inew = index.range_search(xq, radius)
            for q in range(nq):
                exact = set(Iref[Lref[q]:Lref[q + 1]])
                found = Inew[Lnew[q]:Lnew[q + 1]]
                found_set = set(found)
                # no duplicates
                self.assertTrue(len(found) == len(found_set))
                # subset of real results
                self.assertTrue(found_set <= exact)
            nfound.append(Lnew[-1])
            ndis.append(stats.ndis)

        print('nfound=', nfound)
        print('ndis=', ndis)
        counts = np.array(nfound)
        self.assertTrue(np.all(counts[1:] >= counts[:-1]))
Beispiel #29
0
    def test_flat(self):
        """An IndexBinaryFlat gives identical results after a write/read cycle."""
        d = self.xq.shape[1] * 8

        index = faiss.IndexBinaryFlat(d)
        index.add(self.xb)
        D, I = index.search(self.xq, 3)

        # mkstemp returns an already-open OS-level descriptor. Close it right
        # away so it is not leaked, and so the file can be rewritten and
        # removed on platforms that lock open files.
        fd, tmpnam = tempfile.mkstemp()
        os.close(fd)
        try:
            faiss.write_index_binary(index, tmpnam)

            index2 = faiss.read_index_binary(tmpnam)

            D2, I2 = index2.search(self.xq, 3)

            assert (I2 == I).all()
            assert (D2 == D).all()

        finally:
            os.remove(tmpnam)
Beispiel #30
0
    def test_ivf_range(self):
        """range_search on binary IVF is consistent with its k-NN search.

        The radius is one more than the median 10th-neighbor distance, so
        typically some queries have all 10 neighbors inside it and others
        only part of them.
        """
        d = self.xq.shape[1] * 8

        quantizer = faiss.IndexBinaryFlat(d)
        index = faiss.IndexBinaryIVF(quantizer, d, 8)
        index.cp.min_points_per_centroid = 5    # quiet warning
        index.nprobe = 4
        index.train(self.xt)
        index.add(self.xb)
        D, I = index.search(self.xq, 10)

        radius = int(np.median(D[:, -1]) + 1)
        Lr, Dr, Ir = index.range_search(self.xq, radius)

        for q in range(len(self.xq)):
            in_range = set(Ir[Lr[q]:Lr[q + 1]])
            if D[q, -1] >= radius:
                # The k-NN list reaches beyond the radius: its within-radius
                # part must be exactly the range-search result.
                self.assertTrue(set(I[q, D[q, :] < radius]) == in_range)
            else:
                # All k neighbors are within the radius; range search may
                # return more than them.
                self.assertTrue(set(I[q]) <= in_range)