Beispiel #1
0
    def test_raw_array_search(self):
        """Check search_raw_array_pytorch against a CPU IndexFlatL2 ground truth.

        Random queries and database vectors are searched exactly on the CPU
        with faiss; the same arrays are then moved to the GPU as torch
        tensors and searched with search_raw_array_pytorch. Labels must match
        exactly and distances must agree to within 1e-4.
        """
        dim, n_base, n_query, topk = 32, 1024, 128, 10

        queries = faiss.randn(n_query * dim, 1234).reshape(n_query, dim)
        base = faiss.randn(n_base * dim, 1235).reshape(n_base, dim)

        # exact reference results from the CPU flat index
        cpu_index = faiss.IndexFlatL2(dim)
        cpu_index.add(base)
        ref_dis, ref_ids = cpu_index.search(queries, topk)

        # same data as CUDA tensors
        queries_gpu = torch.from_numpy(queries).cuda()
        base_gpu = torch.from_numpy(base).cuda()

        # the GPU resources object may be shared between several calls
        gpu_res = faiss.StandardGpuResources()

        dis, ids = search_raw_array_pytorch(gpu_res, base_gpu, queries_gpu, topk)

        # copy the results back to host memory for the comparison
        dis_host = dis.cpu().numpy()
        ids_host = ids.cpu().numpy()

        assert np.all(ids_host == ref_ids)
        assert np.abs(dis_host - ref_dis).max() < 1e-4
Beispiel #2
0
    def test_interop(self):
        """Search a GpuIndexFlatIP with torch query tensors on the CPU and GPU.

        The labels returned by index.search on the numpy queries are the
        reference; search_index_pytorch must return the same labels for a
        CPU FloatTensor and for that tensor moved to the GPU.
        """
        dim, n_query, n_base = 16, 5, 20
        topk = 5

        queries = faiss.randn(n_query * dim, 1234).reshape(n_query, dim)
        base = faiss.randn(n_base * dim, 1235).reshape(n_base, dim)

        gpu_res = faiss.StandardGpuResources()
        gpu_index = faiss.GpuIndexFlatIP(gpu_res, dim)
        gpu_index.add(base)

        # reference: search the index with the numpy queries directly
        _, ref_ids = gpu_index.search(queries, topk)

        # 1) torch query living on the CPU
        q_cpu = torch.FloatTensor(queries)
        _, ids_cpu = search_index_pytorch(gpu_index, q_cpu, topk)
        assert np.all(ref_ids == ids_cpu.numpy())

        # 2) the same query on the GPU; no synchronization is required
        #    before issuing the search
        q_gpu = q_cpu.cuda()
        _, ids_gpu = search_index_pytorch(gpu_index, q_gpu, topk)

        # the outputs are GPU tensors too; faiss and pytorch use different
        # CUDA streams, so synchronize before reading them back
        gpu_res.syncDefaultStreamCurrentDevice()

        assert np.all(ref_ids == ids_gpu.cpu().numpy())
Beispiel #3
0
    def test_interop(self):
        """Run search_index_pytorch while a non-default torch CUDA stream is active.

        Both a GPU-tensor query and a CPU-tensor query must yield the same
        labels as a plain index.search on the numpy queries.
        """
        dim, n_query, n_base, topk = 128, 100, 1000, 10

        queries = faiss.randn(n_query * dim, 1234).reshape(n_query, dim)
        base = faiss.randn(n_base * dim, 1235).reshape(n_base, dim)

        gpu_res = faiss.StandardGpuResources()

        # a stream other than torch's default one
        side_stream = torch.cuda.Stream()

        # all torch work below is issued while side_stream is current
        with torch.cuda.stream(side_stream):
            q_cpu = torch.FloatTensor(queries)
            q_gpu = q_cpu.cuda()

            gpu_index = faiss.GpuIndexFlatIP(gpu_res, dim)
            gpu_index.add(base)

            # GPU-tensor query; search_index_pytorch receives `gpu_res` and
            # is expected to run on torch's current stream (side_stream)
            _, ids_gpu = search_index_pytorch(gpu_res, gpu_index, q_gpu, topk)
            _, ref_ids = gpu_index.search(queries, topk)
            assert np.all(ref_ids == ids_gpu.cpu().numpy())

            # CPU-tensor query
            _, ids_cpu = search_index_pytorch(gpu_res, gpu_index, q_cpu, topk)
            assert np.all(ref_ids == ids_cpu.numpy())
    def test_interop(self):
        """search_index_pytorch must agree with index.search for torch inputs.

        A GpuIndexFlatIP over 20 random 16-d vectors is queried with 5
        vectors: first as a numpy array (reference), then as a CPU
        FloatTensor, and finally as a CUDA tensor.
        """
        d, nq, nb, k = 16, 5, 20, 5

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatIP(res, d)
        index.add(xb)

        _, labels_ref = index.search(xq, k)

        xq_cpu = torch.FloatTensor(xq)
        _, labels_cpu = search_index_pytorch(index, xq_cpu, k)
        assert np.all(labels_ref == labels_cpu.numpy())

        # moving the query to the GPU needs no explicit sync beforehand
        xq_gpu = xq_cpu.cuda()
        _, labels_gpu = search_index_pytorch(index, xq_gpu, k)

        # labels_gpu is a CUDA tensor; faiss and pytorch work on different
        # CUDA streams, so wait for faiss before copying it back
        res.syncDefaultStreamCurrentDevice()
        assert np.all(labels_ref == labels_gpu.cpu().numpy())
    def test_raw_array_search(self):
        """Check search_raw_array_pytorch for all row/column-major layouts.

        Ground truth comes from a CPU IndexFlatL2. For every combination of
        row-major / column-major query and database tensors, the full query
        set is searched, then the query rows 60..79. For column-major
        queries, a TypeError on that row slice is accepted and the slice
        checks are skipped for that combination.
        """
        d = 32
        nb = 1024
        nq = 128
        k = 10

        def to_column_major(t):
            # Same values, column-major strides (hence not contiguous).
            # `t.t().clone().t()` is not reliable: clone() defaults to
            # memory_format=torch.preserve_format in recent PyTorch, which
            # keeps the transposed strides, so the final .t() yields a
            # row-major tensor and the is_contiguous() checks below fail.
            # `.contiguous()` always materializes the transposed view in
            # row-major order, so transposing it back gives column-major.
            return t.t().contiguous().t()

        # make GT on Faiss CPU
        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        index = faiss.IndexFlatL2(d)
        index.add(xb)
        gt_D, gt_I = index.search(xq, k)

        # resource object, can be re-used over calls
        res = faiss.StandardGpuResources()
        # put on same stream as pytorch to avoid synchronizing streams
        res.setDefaultNullStreamAllDevices()

        for xq_row_major in True, False:
            for xb_row_major in True, False:

                # move to pytorch & GPU
                xq_t = torch.from_numpy(xq).cuda()
                xb_t = torch.from_numpy(xb).cuda()

                if not xq_row_major:
                    xq_t = to_column_major(xq_t)
                    assert not xq_t.is_contiguous()

                if not xb_row_major:
                    xb_t = to_column_major(xb_t)
                    assert not xb_t.is_contiguous()

                D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)

                # back to CPU for verification
                D = D.cpu().numpy()
                I = I.cpu().numpy()

                assert np.all(I == gt_I)
                assert np.abs(D - gt_D).max() < 1e-4

                # test on a row subset of the queries
                try:
                    D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
                except TypeError:
                    if not xq_row_major:
                        # rejecting a column-major row slice is expected
                        continue
                    # otherwise it is an error
                    raise

                # back to CPU for verification
                D = D.cpu().numpy()
                I = I.cpu().numpy()

                assert np.all(I == gt_I[60:80])
                assert np.abs(D - gt_D[60:80]).max() < 1e-4
Beispiel #6
0
    def do_test(self, index_key):
        """Exercise reconstruct-by-id through a Hashtable direct map.

        The index built from `index_key` first stores 200 vectors under
        sequential ids to record reference reconstructions. It is then reset
        and the same vectors are re-added under 200 distinct random ids, with
        the direct map switched to Hashtable between the two batches. The
        test checks reconstruction by id, a serialize/deserialize round trip,
        removal through both an IDSelectorArray and a plain id array, and
        that unknown ids raise RuntimeError.
        """
        dim = 32
        idx = faiss.index_factory(dim, index_key)
        idx.train(faiss.randn((100, dim), 123))

        batch_a = faiss.randn((100, dim), 345)
        batch_b = faiss.randn((100, dim), 678)

        # reference reconstructions under sequential ids 0..199
        idx.add(batch_a)
        idx.add(batch_b)
        expected = idx.reconstruct_n(0, 200)

        # same vectors under random ids; Hashtable map enabled mid-way
        idx.reset()
        rs = np.random.RandomState(123)
        ids = rs.choice(10000, size=200, replace=False).astype(np.int64)
        idx.add_with_ids(batch_a, ids[:100])
        idx.set_direct_map_type(faiss.DirectMap.Hashtable)
        idx.add_with_ids(batch_b, ids[100:])

        probe = range(0, 200, 13)

        def check_same(target, i):
            # reconstruction by external id must equal the reference row
            self.assertTrue(np.all(target.reconstruct(int(ids[i])) == expected[i]))

        for i in probe:
            check_same(idx, i)

        # serialization round trip must keep the id -> vector mapping
        idx2 = faiss.deserialize_index(faiss.serialize_index(idx))
        for i in probe:
            check_same(idx2, i)

        # drop every third id: 50 via an IDSelectorArray, the rest via an
        # id array
        dropped = np.ascontiguousarray(ids[0:200:3])
        selector = faiss.IDSelectorArray(50, faiss.swig_ptr(dropped[:50]))
        n_dropped = idx2.remove_ids(selector)
        n_dropped += idx2.remove_ids(dropped[50:])
        self.assertEqual(n_dropped, len(dropped))

        for i in probe:
            if i % 3 == 0:
                self.assertRaises(RuntimeError, idx2.reconstruct, int(ids[i]))
            else:
                check_same(idx2, i)

        # an id that was never inserted must raise
        self.assertRaises(RuntimeError, idx.reconstruct, 20000)
Beispiel #7
0
    def do_test(self, d, dsub, nbit=8, metric=None):
        """Compare ProductQuantizer distance tables with a numpy reference.

        With metric=None the check runs for METRIC_INNER_PRODUCT and then for
        METRIC_L2. For one metric it compares the per-query lookup tables
        (compute_inner_prod_tables / compute_distance_tables) and the
        centroid-to-centroid squared-L2 table from compute_sdc_table.
        """
        if metric is None:
            for each_metric in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
                self.do_test(d, dsub, nbit, each_metric)
            return

        n_sub = d // dsub
        pq = faiss.ProductQuantizer(d, n_sub, nbit)
        train_set = faiss.randn((max(1000, pq.ksub * 50), d), 123)
        pq.cp.niter = 4  # keep k-means short so the test does not time out
        pq.train(train_set)

        # trained codebooks as [M, ksub, dsub]
        cents = faiss.vector_to_array(pq.centroids).reshape(
            pq.M, pq.ksub, pq.dsub)

        nx = 100
        x = faiss.randn((nx, d), 555)

        # numpy reference: one sub-quantizer at a time
        expected = np.zeros((nx, n_sub, pq.ksub), "float32")
        for sq in range(n_sub):
            x_part = x[:, sq * dsub:(sq + 1) * dsub]
            codebook = cents[sq]
            if metric == faiss.METRIC_INNER_PRODUCT:
                expected[:, sq, :] = x_part @ codebook.T
            elif metric == faiss.METRIC_L2:
                diff = x_part[:, None, :] - codebook[None, :, :]
                expected[:, sq, :] = (diff ** 2).sum(2)
            else:
                assert False

        ptr = faiss.swig_ptr
        got = np.zeros((nx, n_sub, pq.ksub), "float32")
        if metric == faiss.METRIC_INNER_PRODUCT:
            pq.compute_inner_prod_tables(nx, ptr(x), ptr(got))
        elif metric == faiss.METRIC_L2:
            pq.compute_distance_tables(nx, ptr(x), ptr(got))
        else:
            assert False

        # symmetric tables: squared L2 between every pair of centroids of
        # the same sub-quantizer, shape [M, ksub, ksub]
        expected_sdc = ((cents[:, :, None, :] - cents[:, None, :, :]) ** 2).sum(3)

        pq.compute_sdc_table()
        got_sdc = faiss.vector_to_array(pq.sdc_table).reshape(
            n_sub, pq.ksub, pq.ksub)

        np.testing.assert_array_almost_equal(expected, got, decimal=5)
        np.testing.assert_array_almost_equal(expected_sdc, got_sdc, decimal=5)
Beispiel #8
0
    def test_reconstuct_after_add(self):
        """reconstruct() must work for vectors added after make_direct_map().

        The direct map is built while the index holds the first 100 vectors;
        100 more are added afterwards. Reconstructing one id from each batch
        must not raise.
        """
        index = faiss.index_factory(10, 'IVF5,SQfp16')
        index.train(faiss.randn((100, 10), 123))
        index.add(faiss.randn((100, 10), 345))
        index.make_direct_map()
        index.add(faiss.randn((100, 10), 678))

        # both batches must be in the index (this replaces a leftover debug
        # print of ntotal)
        self.assertEqual(index.ntotal, 200)

        # should not raise an exception: id 5 is from the first batch, id
        # 150 was added after the direct map was built
        index.reconstruct(5)
        index.reconstruct(150)
Beispiel #9
0
    def test_weighted(self):
        """Per-point weights passed to Clustering.train must move centroids.

        The data has 10 clusters around normalized random centers: 5 with
        100 points and 5 with 10 points. k-means runs with k=5. Without
        weights, the centroids favor the large clusters. With weight 0.1 on
        the large clusters and 10 on the small ones, they favor the small
        clusters.
        """
        d = 32
        sigma = 0.1

        centers = faiss.randn((10, d), 123)
        faiss.normalize_L2(centers)

        # cluster i gets sizes[i] points drawn with seed 1234 + i
        sizes = [100] * 5 + [10] * 5
        x = np.vstack([
            centers[i] + sigma * faiss.randn((sizes[i], d), 1234 + i)
            for i in range(10)
        ])

        def gaps(weights=None):
            # run k=5 k-means, then return the summed squared distance from
            # the first 5 and from the last 5 true centers to their nearest
            # centroid
            clus = faiss.Clustering(d, 5)
            index = faiss.IndexFlatL2(d)
            if weights is None:
                clus.train(x, index)
            else:
                clus.train(x, index, weights=weights)
            dis, _ = index.search(centers, 1)
            return dis[:5].sum(), dis[5:].sum()

        first_plain, last_plain = gaps()

        # down-weight the large clusters, up-weight the small ones
        w = np.ones(100 * 5 + 10 * 5, dtype='float32')
        w[:100 * 5] = 0.1
        w[100 * 5:] = 10
        first_weighted, last_weighted = gaps(w)

        print(first_plain, last_plain)
        print(first_weighted, last_weighted)

        # unweighted: small clusters are poorly covered (> 2x farther);
        # weighted: the situation is reversed
        self.assertGreater(last_plain, first_plain * 2)
        self.assertGreater(first_weighted, last_weighted * 2)
Beispiel #10
0
    def test_white(self):
        """PCA whitening should largely undo an anisotropic linear distortion.

        Nearest neighbors on isotropic Gaussian data are the reference. After
        scaling the axes by [10, 4, 1, 0.5] and applying a random rotation,
        plain L2 search recovers few of them. A PCAW index trained on the
        distorted data must recover more than twice as many.
        """
        d = 4
        nt, nb, nq = 1000, 200, 200
        k = 5

        def split(data):
            # training / database / query partitions
            return data[:nt], data[nt:-nq], data[-nq:]

        # normal distribution
        x = faiss.randn((nt + nb + nq) * d, 1234).reshape(nt + nb + nq, d)

        # reference neighbors on the undistorted data
        _, xb, xq = split(x)
        flat = faiss.index_factory(d, 'Flat')
        flat.add(xb)
        _, I_ref = flat.search(xq, k)

        # make the distribution very skewed: scale, then rotate
        x *= [10, 4, 1, 0.5]
        rot, _ = np.linalg.qr(faiss.randn(d * d).reshape(d, d))
        x = np.dot(x, rot).astype('float32')
        xt, xb, xq = split(x)

        # plain L2 search on the skewed distribution
        flat = faiss.index_factory(d, 'Flat')
        flat.add(xb)
        _, I_l2 = flat.search(xq, k)

        # whitening + L2 search on the skewed distribution
        white = faiss.index_factory(d, 'PCAW%d,Flat' % d)
        white.train(xt)
        white.add(xb)
        _, I_white = white.search(xq, k)

        # whitened results must overlap the reference much better than
        # plain L2 ones (expected to be about 961 vs. 264)
        assert (faiss.eval_intersection(I_ref, I_white) >
                2 * faiss.eval_intersection(I_ref, I_l2))
Beispiel #11
0
    def test_white(self):
        """PCAW whitening recovers neighbors lost to a skewing transform.

        Reference neighbors are computed on isotropic data. The data is then
        stretched per axis and rotated. The overlap with the reference
        achieved by a whitening PCA index must exceed twice the overlap of a
        plain L2 flat index.
        """
        d = 4
        nt = 1000
        nb = 200
        nq = 200
        ntot = nt + nb + nq

        # isotropic Gaussian sample
        x = faiss.randn(ntot * d, 1234).reshape(ntot, d)

        # ground-truth neighbors before the distortion
        ref_index = faiss.index_factory(d, 'Flat')
        ref_index.add(x[nt:-nq])
        _, I_orig = ref_index.search(x[-nq:], 5)

        # skew: per-axis scaling followed by a random orthogonal rotation
        x *= [10, 4, 1, 0.5]
        q_mat, _ = np.linalg.qr(faiss.randn(d * d).reshape(d, d))
        x = np.dot(x, q_mat).astype('float32')
        train_x, base_x, query_x = x[:nt], x[nt:-nq], x[-nq:]

        # plain L2 search on the skewed data
        l2_index = faiss.index_factory(d, 'Flat')
        l2_index.add(base_x)
        _, I_l2 = l2_index.search(query_x, 5)

        # whitening PCA followed by L2 search on the skewed data
        pcaw_index = faiss.index_factory(d, 'PCAW%d,Flat' % d)
        pcaw_index.train(train_x)
        pcaw_index.add(base_x)
        _, I_white = pcaw_index.search(query_x, 5)

        # expected to be roughly 961 vs. 264
        overlap_white = faiss.eval_intersection(I_orig, I_white)
        overlap_l2 = faiss.eval_intersection(I_orig, I_l2)
        assert overlap_white > 2 * overlap_l2
    def test_indexflat(self):
        """reconstruct_batch on an IndexFlatL2 returns the stored rows exactly."""
        index = faiss.IndexFlatL2(32)
        data = faiss.randn((100, 32), 1234)
        index.add(data)

        wanted = [4, 7, 45]
        recons = index.reconstruct_batch(wanted)
        np.testing.assert_equal(data[wanted], recons)
Beispiel #13
0
    def test_chain(self):
        """The factory string "L2norm,PCA2,L2norm,Flat" builds a 3-step chain.

        Checks the chain layout, that training the index trains its PCA
        step, and that search results equal those of a 2-d IndexFlatL2 on
        vectors transformed by hand with the same steps.
        """
        d = 4
        nt, nb, nq = 1000, 200, 200

        # normal distribution
        x = faiss.randn((nt + nb + nq) * d, 1234).reshape(nt + nb + nq, d)

        # make the distribution very skewed: scale the axes, then rotate
        x *= [10, 4, 1, 0.5]
        rot, _ = np.linalg.qr(faiss.randn(d * d).reshape(d, d))
        x = np.dot(x, rot).astype('float32')

        xt, xb, xq = x[:nt], x[nt:-nq], x[-nq:]

        index = faiss.index_factory(d, "L2norm,PCA2,L2norm,Flat")
        assert index.chain.size() == 3

        first_norm = faiss.downcast_VectorTransform(index.chain.at(0))
        assert first_norm.norm == 2
        pca = faiss.downcast_VectorTransform(index.chain.at(1))
        assert not pca.is_trained
        index.train(xt)
        assert pca.is_trained

        index.add(xb)
        _, I = index.search(xq, 5)

        def by_hand(v):
            # L2-normalize, project with the trained PCA, L2-normalize again
            out = v.copy()  # normalize_L2 modifies its argument in place
            faiss.normalize_L2(out)
            out = pca.apply_py(out)
            faiss.normalize_L2(out)
            return out

        ref = faiss.IndexFlatL2(2)
        ref.add(by_hand(xb))
        _, I_ref = ref.search(by_hand(xq), 5)

        assert np.all(I == I_ref)
Beispiel #14
0
    def test_chain(self):
        """An "L2norm,PCA2,L2norm,Flat" index must match a manual pipeline.

        After verifying the three transforms in the chain and that training
        trains the PCA, the index's search results are compared with a flat
        2-d index built on vectors that were normalized, PCA-projected and
        normalized again by hand.
        """
        d = 4
        nt = 1000
        nb = 200
        nq = 200
        n = nt + nb + nq

        # skewed sample: per-axis scaling, then a random rotation
        data = faiss.randn(n * d, 1234).reshape(n, d)
        data *= [10, 4, 1, 0.5]
        rotation = np.linalg.qr(faiss.randn(d * d).reshape(d, d))[0]
        data = np.dot(data, rotation).astype('float32')

        train_set = data[:nt]
        base_set = data[nt:-nq]
        query_set = data[-nq:]

        index = faiss.index_factory(d, "L2norm,PCA2,L2norm,Flat")
        assert index.chain.size() == 3

        norm_step = faiss.downcast_VectorTransform(index.chain.at(0))
        assert norm_step.norm == 2

        pca_step = faiss.downcast_VectorTransform(index.chain.at(1))
        assert not pca_step.is_trained
        index.train(train_set)
        assert pca_step.is_trained

        index.add(base_set)
        _, labels = index.search(query_set, 5)

        def apply_chain(vectors):
            # work on a copy because normalize_L2 modifies its argument
            tmp = vectors.copy()
            faiss.normalize_L2(tmp)
            tmp = pca_step.apply_py(tmp)
            faiss.normalize_L2(tmp)
            return tmp

        manual_index = faiss.IndexFlatL2(2)
        manual_index.add(apply_chain(base_set))
        _, manual_labels = manual_index.search(apply_chain(query_set), 5)

        assert np.all(labels == manual_labels)
Beispiel #15
0
    def dump_load_factory(self, fs):
        """Pickle round trip of an index built from factory string `fs`.

        The index is trained and filled with 25 random 10-d vectors. After
        pickling and unpickling through an in-memory buffer, it must return
        exactly the same distances and labels for the same queries.
        """
        xq = faiss.randn((25, 10), 123)
        xb = faiss.randn((25, 10), 124)

        index = faiss.index_factory(10, fs)
        index.train(xb)
        index.add(xb)
        ref_dis, ref_ids = index.search(xq, 4)

        with io.BytesIO() as stream:
            pickle.dump(index, stream)
            stream.seek(0)
            restored = pickle.load(stream)

        new_dis, new_ids = restored.search(xq, 4)

        np.testing.assert_array_equal(ref_ids, new_ids)
        np.testing.assert_array_equal(ref_dis, new_dis)
Beispiel #16
0
    def test_raw_array_search(self):
        """search_raw_array_pytorch on the whole query set and on a row slice.

        Ground truth comes from a CPU IndexFlatL2. The GPU resources are
        switched to the null stream with setDefaultNullStreamAllDevices so
        that faiss and pytorch share a stream and no explicit
        synchronization is needed.
        """
        d, nb, nq, k = 32, 1024, 128, 10

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        # exact CPU ground truth
        cpu_index = faiss.IndexFlatL2(d)
        cpu_index.add(xb)
        gt_D, gt_I = cpu_index.search(xq, k)

        xq_t = torch.from_numpy(xq).cuda()
        xb_t = torch.from_numpy(xb).cuda()

        # reusable resources object, on the same stream as pytorch
        res = faiss.StandardGpuResources()
        res.setDefaultNullStreamAllDevices()

        cases = [
            (xq_t, gt_D, gt_I),                       # every query
            (xq_t[60:80], gt_D[60:80], gt_I[60:80]),  # a row subset
        ]
        for queries, want_D, want_I in cases:
            D, I = search_raw_array_pytorch(res, xb_t, queries, k)
            D, I = D.cpu().numpy(), I.cpu().numpy()
            assert np.all(I == want_I)
            assert np.abs(D - want_D).max() < 1e-4
Beispiel #17
0
    def do_test(self, d, dsub, nbit=8, metric=None):
        """Check PQ query-to-centroid lookup tables against numpy.

        With metric=None the check runs for METRIC_INNER_PRODUCT and then for
        METRIC_L2; otherwise only for the given metric.
        """
        if metric is None:
            for each_metric in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
                self.do_test(d, dsub, nbit, each_metric)
            return

        nsq = d // dsub
        pq = faiss.ProductQuantizer(d, nsq, nbit)
        pq.train(faiss.randn((max(1000, pq.ksub * 50), d), 123))
        ksub = pq.ksub

        # trained codebooks as [M, ksub, dsub]
        codebooks = faiss.vector_to_array(pq.centroids).reshape(
            pq.M, ksub, pq.dsub)

        nx = 100
        x = faiss.randn((nx, d), 555)

        want = np.zeros((nx, nsq, ksub), "float32")
        for m in range(nsq):
            x_m = x[:, m * dsub:(m + 1) * dsub]
            cb = codebooks[m]
            if metric == faiss.METRIC_INNER_PRODUCT:
                want[:, m, :] = x_m @ cb.T
            elif metric == faiss.METRIC_L2:
                sq_diff = (x_m[:, np.newaxis, :] - cb[np.newaxis, :, :]) ** 2
                want[:, m, :] = sq_diff.sum(2)
            else:
                assert False

        got = np.zeros((nx, nsq, ksub), "float32")
        ptr = faiss.swig_ptr
        if metric == faiss.METRIC_INNER_PRODUCT:
            pq.compute_inner_prod_tables(nx, ptr(x), ptr(got))
        elif metric == faiss.METRIC_L2:
            pq.compute_distance_tables(nx, ptr(x), ptr(got))
        else:
            assert False

        np.testing.assert_array_almost_equal(want, got, decimal=5)
    def test_exception(self):
        """reconstruct_batch on an IVF index must raise RuntimeError.

        No direct map is built for the index (presumably the reason
        reconstruction fails — confirm). 1200 ids are requested so that the
        error is raised from inside the OpenMP loop and must still reach
        Python.
        """
        index = faiss.index_factory(32, "IVF2,Flat")
        x = faiss.randn((100, 32), 1234)
        index.train(x)
        index.add(x)

        batch_ids = np.zeros(1200, dtype=int)
        with self.assertRaises(RuntimeError):
            index.reconstruct_batch(batch_ids)
Beispiel #19
0
def run_bench(d, dsub, nbit=8, metric=None):
    """Time PQ lookup-table computation for batches of 1, 10 and 100 queries.

    Builds a ProductQuantizer with ``d // dsub`` sub-quantizers of ``nbit``
    bits and trains it on random data. It then times
    ``compute_inner_prod_tables`` (METRIC_INNER_PRODUCT) or
    ``compute_distance_tables`` (METRIC_L2). Each batch size is run 100
    times; the first 20 runs are discarded as warm-up.

    Returns:
        list of mean times in milliseconds, one per batch size (1, 10, 100).
        A summary line is also printed.

    Raises:
        ValueError: if ``metric`` is neither METRIC_INNER_PRODUCT nor
            METRIC_L2 (including the default ``None``).
    """
    # Validate before the (slow) training step, and with a real exception:
    # an `assert False` is stripped under `python -O`, which would silently
    # benchmark only the np.zeros allocation.
    if metric not in (faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2):
        raise ValueError(f"unsupported metric {metric!r}: expected "
                         "METRIC_INNER_PRODUCT or METRIC_L2")

    M = d // dsub
    pq = faiss.ProductQuantizer(d, M, nbit)
    pq.train(faiss.randn((max(1000, pq.ksub * 50), d), 123))

    # resolve the table routine once instead of on every timed iteration
    compute_tables = (pq.compute_inner_prod_tables
                      if metric == faiss.METRIC_INNER_PRODUCT
                      else pq.compute_distance_tables)

    sp = faiss.swig_ptr
    nrun = 100
    nwarmup = nrun // 5  # leading runs excluded from the statistics

    print(f"d={d} dsub={dsub} ksub={pq.ksub}", end="\t")
    res = []
    for nx in 1, 10, 100:
        x = faiss.randn((nx, d), 555)

        times = []
        for run in range(nrun):
            # perf_counter is monotonic and high-resolution; time.time() can
            # be too coarse for sub-millisecond calls on some platforms
            t0 = time.perf_counter()
            new_tab = np.zeros((nx, M, pq.ksub), "float32")
            compute_tables(nx, sp(x), sp(new_tab))
            t1 = time.perf_counter()
            if run >= nwarmup:
                times.append(t1 - t0)
        times = np.array(times) * 1000  # seconds -> milliseconds

        print(f"nx={nx}: {np.mean(times):.3f} ms (± {np.std(times):.4f})",
              end="\t")
        res.append(times.mean())
    print()
    return res
"""small test script to benchmark the SIMD implementation of the
distance computations for the additional metrics. Call eg. with L1 to
get L1 distance computations.
"""

import faiss

import sys
import time

d = 64
nq = 4096
nb = 16384

print("sample")

xq = faiss.randn((nq, d), 123)
xb = faiss.randn((nb, d), 123)

mt_name = "L2" if len(sys.argv) < 2 else sys.argv[1]

mt = getattr(faiss, "METRIC_" + mt_name)

print("distances")
t0 = time.time()
dis = faiss.pairwise_distances(xq, xb, mt)
t1 = time.time()

print("nq=%d nb=%d d=%d %s: %.3f s" % (nq, nb, d, mt_name, t1 - t0))
Beispiel #21
0
def random_unitary(n, d, seed):
    """Return ``n`` random ``d``-dimensional vectors rescaled to unit L2 norm.

    The values come from ``faiss.randn(n * d, seed)``, are reshaped to
    ``(n, d)`` and normalized row-wise in place with ``faiss.normalize_L2``.
    """
    vecs = faiss.randn(n * d, seed)
    vecs = vecs.reshape(n, d)
    faiss.normalize_L2(vecs)
    return vecs
Beispiel #22
0
    def test_raw_array_search(self):
        """search_raw_array_pytorch on a non-default torch stream, all layouts.

        For every row/column-major combination of query and database
        tensors, the full query set and the query rows 60..79 are searched
        and compared with a CPU IndexFlatL2 ground truth. For column-major
        queries, a TypeError on the row slice is the accepted outcome.
        """
        d = 32
        nb = 1024
        nq = 128
        k = 10

        xq = faiss.randn(nq * d, 1234).reshape(nq, d)
        xb = faiss.randn(nb * d, 1235).reshape(nb, d)

        # exact CPU ground truth
        cpu_index = faiss.IndexFlatL2(d)
        cpu_index.add(xb)
        gt_D, gt_I = cpu_index.search(xq, k)

        # reusable GPU resources object
        res = faiss.StandardGpuResources()

        def verify(D, I, want_D, want_I):
            # bring the results to the host and compare with ground truth
            D = D.cpu().numpy()
            I = I.cpu().numpy()
            assert np.all(I == want_I)
            assert np.abs(D - want_D).max() < 1e-4

        # (query row-major?, database row-major?) in the original nested order
        layouts = [(q_rm, b_rm) for q_rm in (True, False)
                   for b_rm in (True, False)]

        # make pytorch issue its work on a non-default stream
        side_stream = torch.cuda.Stream()
        with torch.cuda.stream(side_stream):
            for xq_row_major, xb_row_major in layouts:
                xq_t = torch.from_numpy(xq).cuda()
                xb_t = torch.from_numpy(xb).cuda()

                if not xq_row_major:
                    xq_t = to_column_major(xq_t)
                    assert not xq_t.is_contiguous()

                if not xb_row_major:
                    xb_t = to_column_major(xb_t)
                    assert not xb_t.is_contiguous()

                D, I = search_raw_array_pytorch(res, xb_t, xq_t, k)
                verify(D, I, gt_D, gt_I)

                # row subset; the search uses the current pytorch stream
                try:
                    D, I = search_raw_array_pytorch(
                        res, xb_t, xq_t[60:80], k)
                except TypeError:
                    if xq_row_major:
                        # unexpected for row-major queries
                        raise
                    # expected for column-major queries
                    continue

                verify(D, I, gt_D[60:80], gt_I[60:80])