def test_constructor(self):
    """IndexPQFastScan built via its constructor must match a plain PQ index.

    The reference is a regular PQ index with d/2 sub-quantizers of 4 bits;
    top-1 results must agree on more than 99% of queries, and the fast-scan
    index must survive a serialize/deserialize round trip unchanged.
    """
    dim = 32
    dataset = datasets.SyntheticDataset(dim, 2000, 5000, 200)
    queries = dataset.get_queries()

    reference = faiss.index_factory(dim, f'PQ{dim//2}x4np')
    reference.train(dataset.get_train())
    reference.add(dataset.get_database())
    _, ref_ids = reference.search(queries, 10)
    num_queries = ref_ids.shape[0]

    fast = faiss.IndexPQFastScan(dim, dim // 2, 4)
    fast.train(dataset.get_train())
    fast.add(dataset.get_database())
    fast_dis, fast_ids = fast.search(queries, 10)

    top1_agreement = (ref_ids[:, 0] == fast_ids[:, 0]).sum() / num_queries
    self.assertGreater(top1_agreement, 0.99)

    # I/O round trip: implementation choice and results must be identical.
    restored = faiss.deserialize_index(faiss.serialize_index(fast))
    self.assertEqual(fast.implem, restored.implem)
    restored_dis, restored_ids = restored.search(queries, 10)
    np.testing.assert_array_equal(restored_dis, fast_dis)
    np.testing.assert_array_equal(restored_ids, fast_ids)
def test_IndexLocalSearchQuantizer(self):
    """IndexLocalSearchQuantizer: recall floor, ST_norm_float agreement, I/O."""
    ds = datasets.SyntheticDataset(32, 1000, 200, 100)
    gt = ds.get_groundtruth(10)
    xq = ds.get_queries()

    lsq_index = faiss.IndexLocalSearchQuantizer(ds.d, 4, 5)
    lsq_index.train(ds.get_train())
    lsq_index.add(ds.get_database())
    Dref, Iref = lsq_index.search(xq, 10)
    # observed intersection is around 467
    self.assertGreater(faiss.eval_intersection(Iref, gt), 460)

    # Same codebooks but a different search type (norms stored as float):
    # distances must agree to 5 decimals and almost all ids must match.
    norm_index = faiss.IndexLocalSearchQuantizer(
        ds.d, 4, 5, faiss.METRIC_L2, faiss.AdditiveQuantizer.ST_norm_float)
    norm_index.train(ds.get_train())  # only so the trained flags get set
    norm_index.lsq.codebooks = lsq_index.lsq.codebooks
    norm_index.add(ds.get_database())
    D2, I2 = norm_index.search(xq, 10)
    np.testing.assert_array_almost_equal(Dref, D2, decimal=5)
    self.assertLess((Iref != I2).sum(), Iref.size * 0.01)

    # I/O round trip must reproduce the reference results exactly.
    restored = faiss.deserialize_index(faiss.serialize_index(lsq_index))
    D3, I3 = restored.search(xq, 10)
    np.testing.assert_array_equal(Iref, I3)
    np.testing.assert_array_equal(Dref, D3)
def do_write_callback(self, bsz):
    """Write an index through a Python callback writer and read it back.

    bsz: block size of the BufferedIOWriter wrapped around the callback
    writer; a value <= 0 writes through the callback writer directly.
    """
    d, n = 32, 1000
    x = np.random.uniform(size=(n, d)).astype('float32')
    index = faiss.IndexFlatL2(d)
    index.add(x)

    sink = io.BytesIO()
    # small block size (1234) so the write goes through many callbacks
    writer = faiss.PyCallbackIOWriter(sink.write, 1234)
    if bsz > 0:
        writer = faiss.BufferedIOWriter(writer, bsz)
    faiss.write_index(index, writer)
    del writer  # make sure all writes are committed before reading

    buf = sink.getvalue() if sys.version_info[0] < 3 else sink.getbuffer()
    index2 = faiss.deserialize_index(np.frombuffer(buf, dtype='uint8'))

    self.assertEqual(index.d, index2.d)
    xb_before = faiss.vector_to_array(index.xb)
    xb_after = faiss.vector_to_array(index2.xb)
    self.assertTrue(np.all(xb_before == xb_after))

    # The "write function" below is not callable: writing should raise.
    writer = faiss.PyCallbackIOWriter("blabla")
    self.assertRaises(Exception, faiss.write_index, index, writer)
def test_pca_epsilon(self):
    """Whitening PCA on subspace data yields non-finite output unless epsilon is set."""
    d = 64
    n = 1000
    np.random.seed(123)
    x = np.random.random(size=(n, d)).astype('float32')
    # zero every other coordinate so the data lies in a sub-space
    x[:, ::2] = 0

    def train_and_apply(epsilon=None):
        # fresh PCAMatrix(d -> 60, eigen power -0.5), optionally with epsilon
        pca = faiss.PCAMatrix(d, 60, -0.5)
        if epsilon is not None:
            pca.epsilon = epsilon
        pca.train(x)
        return pca.apply(x)

    # default computation: division by zero shows up as non-finite values
    self.assertFalse(np.all(np.isfinite(train_and_apply())))
    # with a small epsilon the output stays finite
    self.assertTrue(np.all(np.isfinite(train_and_apply(1e-5))))

    # epsilon set on a deserialized factory index must be honored too
    index = faiss.index_factory(d, "PCAW60,Flat")
    index = faiss.deserialize_index(faiss.serialize_index(index))
    faiss.downcast_VectorTransform(index.chain.at(0)).epsilon = 1e-5
    index.train(x)
    pca = faiss.downcast_VectorTransform(index.chain.at(0))
    self.assertTrue(np.all(np.isfinite(pca.apply(x))))
def eval_index_accuracy(self, factory_key):
    """Check that recall does not decrease with nprobe, then do a small I/O test.

    Only a single search path is exercised here; most search functions are
    already stress-tested in test_residual_quantizer.py.
    """
    ds = datasets.SyntheticDataset(32, 3000, 1000, 100)
    index = faiss.index_factory(ds.d, factory_key)
    index.train(ds.get_train())
    index.add(ds.get_database())

    inters = []
    for nprobe in (1, 2, 5, 10, 20, 50):
        index.nprobe = nprobe
        D, I = index.search(ds.get_queries(), 10)
        inters.append(faiss.eval_intersection(I, ds.get_groundtruth(10)))
    inters = np.array(inters)

    # probing more lists must never lower the intersection with ground truth
    self.assertTrue(np.all(inters[1:] >= inters[:-1]))

    # the deserialized copy must reproduce the last search (nprobe=50)
    index2 = faiss.deserialize_index(faiss.serialize_index(index))
    D2, I2 = index2.search(ds.get_queries(), 10)
    np.testing.assert_array_equal(I2, I)
    np.testing.assert_array_equal(D2, D)
def compare_accuracy(self, lowac, highac, max_errs=(1e10, 1e10)):
    """Check that codec ``highac`` reconstructs more accurately than ``lowac``.

    Both factory strings are trained and used as standalone codecs
    (sa_encode / sa_decode) on the same data. The total squared
    reconstruction error of ``lowac`` must exceed that of ``highac``, and
    each error must stay below the matching bound in ``max_errs``.
    When ``highac`` is a lattice codec, an I/O round trip is checked too.
    """
    d = 96
    nb = 1000
    nq = 0
    nt = 2000

    xt, x, _ = get_dataset_2(d, nt, nb, nq)

    errs = []

    for factory_string in lowac, highac:
        codec = faiss.index_factory(d, factory_string)
        print('sa codec: code size %d' % codec.sa_code_size())
        codec.train(xt)

        codes = codec.sa_encode(x)
        x2 = codec.sa_decode(codes)

        err = ((x - x2) ** 2).sum()
        errs.append(err)

    print(errs)
    self.assertGreater(errs[0], errs[1])

    self.assertGreater(max_errs[0], errs[0])
    self.assertGreater(max_errs[1], errs[1])

    # just a small IndexLattice I/O test. After the loop, ``codec`` and
    # ``x2`` belong to ``highac``; the *deserialized* copy must do the
    # encode/decode, otherwise the round trip is not tested at all.
    if 'Lattice' in highac:
        codec2 = faiss.deserialize_index(faiss.serialize_index(codec))
        codes = codec2.sa_encode(x)
        x3 = codec2.sa_decode(codes)
        self.assertTrue(np.all(x2 == x3))
def do_test(self, index_key):
    """Reconstruct by id through a hashtable direct map, incl. I/O and removal."""
    d = 32
    index = faiss.index_factory(d, index_key)
    index.train(faiss.randn((100, d), 123))

    # reference reconstructions, stored under sequential ids 0..199
    index.add(faiss.randn((100, d), 345))
    index.add(faiss.randn((100, d), 678))
    ref_recons = index.reconstruct_n(0, 200)

    # same vectors again, this time under 200 distinct random ids
    index.reset()
    rs = np.random.RandomState(123)
    ids = rs.choice(10000, size=200, replace=False).astype(np.int64)
    index.add_with_ids(faiss.randn((100, d), 345), ids[:100])
    # the direct map type is switched to a hashtable between the two adds
    index.set_direct_map_type(faiss.DirectMap.Hashtable)
    index.add_with_ids(faiss.randn((100, d), 678), ids[100:])

    checked = range(0, 200, 13)

    def assert_same_recons(idx, i):
        # vector stored under ids[i] must equal the i-th reference vector
        recons = idx.reconstruct(int(ids[i]))
        self.assertTrue(np.all(recons == ref_recons[i]))

    for i in checked:
        assert_same_recons(index, i)

    # the id mapping must survive serialization
    index2 = faiss.deserialize_index(faiss.serialize_index(index))
    for i in checked:
        assert_same_recons(index2, i)

    # remove every third id: the first 50 through an IDSelectorArray, the
    # rest through a plain array, to cover both ways of removing
    toremove = np.ascontiguousarray(ids[0:200:3])
    sel = faiss.IDSelectorArray(50, faiss.swig_ptr(toremove[:50]))
    nremove = index2.remove_ids(sel)
    nremove += index2.remove_ids(toremove[50:])
    self.assertEqual(nremove, len(toremove))

    for i in checked:
        if i % 3 == 0:
            # removed ids can no longer be reconstructed
            self.assertRaises(RuntimeError, index2.reconstruct, int(ids[i]))
        else:
            assert_same_recons(index2, i)

    # an id that was never added (ids are drawn below 10000) must raise
    self.assertRaises(RuntimeError, index.reconstruct, 20000)
def do_test_accuracy(self, by_residual, st):
    """IndexIVFResidualQuantizer: recall vs nprobe, I/O, and flat-RQ agreement.

    by_residual: value assigned to ``index.by_residual`` on the IVF index.
    st: additive-quantizer search type (e.g. ``ST_norm_float``) passed to
        both the IVF and the flat residual-quantizer constructors.
    """
    ds = datasets.SyntheticDataset(32, 3000, 1000, 100)

    quantizer = faiss.IndexFlatL2(ds.d)

    index = faiss.IndexIVFResidualQuantizer(
        quantizer, ds.d, 100, 3, 4,
        faiss.METRIC_L2, st
    )
    index.by_residual = by_residual

    index.rq.train_type = faiss.ResidualQuantizer.Train_default
    index.rq.max_beam_size = 30

    index.train(ds.get_train())
    index.add(ds.get_database())

    inters = []
    for nprobe in 1, 2, 5, 10, 20, 50:
        index.nprobe = nprobe
        D, I = index.search(ds.get_queries(), 10)
        inter = faiss.eval_intersection(I, ds.get_groundtruth(10))
        inters.append(inter)

    # do a little I/O test (against the last search, nprobe=50)
    index2 = faiss.deserialize_index(faiss.serialize_index(index))
    D2, I2 = index2.search(ds.get_queries(), 10)
    np.testing.assert_array_equal(I2, I)
    np.testing.assert_array_equal(D2, D)

    inters = np.array(inters)

    if by_residual:
        # check that we have increasing intersection measures with
        # nprobe
        self.assertTrue(np.all(inters[1:] >= inters[:-1]))
    else:
        # only the first three nprobe values (1, 2, 5) must be non-decreasing
        self.assertTrue(np.all(inters[1:3] >= inters[:2]))

    # check that we have the same result as the flat residual quantizer
    iflat = faiss.IndexResidualQuantizer(
        ds.d, 3, 4, faiss.METRIC_L2, st)
    iflat.rq.train_type = faiss.ResidualQuantizer.Train_default
    iflat.rq.max_beam_size = 30
    iflat.train(ds.get_train())
    # make the flat index use the IVF index's codebooks
    iflat.rq.codebooks = index.rq.codebooks

    iflat.add(ds.get_database())
    Dref, Iref = iflat.search(ds.get_queries(), 10)

    # nprobe equals the 100 lists passed to the IVF constructor
    index.nprobe = 100
    D2, I2 = index.search(ds.get_queries(), 10)
    np.testing.assert_array_almost_equal(Dref, D2, decimal=5)
    # there are many ties because the codes are so short
    self.assertLess((Iref != I2).sum(), Iref.size * 0.2)
def load_faiss_index(path_to_faiss="models/faiss_index.pickle"):
    """Download, load and deserialize the Faiss index.

    On every call the pickled index is downloaded from Google Drive to
    ``path_to_faiss``. The file is then unpickled, and the result is passed
    to ``faiss.deserialize_index``.

    Returns:
        The deserialized Faiss index.
    """
    import os

    # Download from Google Drive. Create the target directory first so the
    # download has somewhere to write when e.g. ``models/`` does not exist.
    id_data = "1AFNS2rdO4_x_XzKa4nAcMtwreODoeF_P"
    url = 'https://drive.google.com/uc?id=' + id_data
    target_dir = os.path.dirname(path_to_faiss)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    gdown.download(url, path_to_faiss, quiet=False)

    # SECURITY NOTE(review): this unpickles a file fetched over the network.
    # pickle.load can execute arbitrary code, so this is only safe if the
    # Drive file is fully trusted; storing the raw faiss.serialize_index
    # bytes (or using faiss.write_index / faiss.read_index) would avoid it.
    with open(path_to_faiss, "rb") as h:
        data = pickle.load(h)
    return faiss.deserialize_index(data)
def test_downcast_Refine(self):
    """A deserialized IndexRefineFlat must come back as IndexRefineFlat."""
    index = faiss.IndexRefineFlat(
        faiss.IndexScalarQuantizer(10, faiss.ScalarQuantizer.QT_8bit))

    # serialize and deserialize
    index2 = faiss.deserialize_index(faiss.serialize_index(index))

    # unittest assertion instead of a bare ``assert``: it is not stripped
    # under ``python -O`` and reports the actual type on failure
    self.assertIsInstance(index2, faiss.IndexRefineFlat)
def test_factory(self):
    """An IVF index with an RCQ coarse quantizer must round-trip through I/O."""
    ds = datasets.SyntheticDataset(16, 500, 1000, 100)
    xq = ds.get_queries()

    index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),Flat")
    index.train(ds.get_train())
    index.add(ds.get_database())
    Dref, Iref = index.search(xq, 10)

    restored = faiss.deserialize_index(faiss.serialize_index(index))
    Dnew, Inew = restored.search(xq, 10)

    np.testing.assert_equal(Dref, Dnew)
    np.testing.assert_equal(Iref, Inew)
def test_io(self):
    """A deserialized IndexResidualQuantizer must produce the same codes."""
    ds = datasets.SyntheticDataset(32, 1000, 100, 0)
    xt, xb = ds.get_train(), ds.get_database()

    ir = faiss.IndexResidualQuantizer(ds.d, 3, 4)
    ir.rq.train_type = faiss.ResidualQuantizer.Train_default
    ir.train(xt)
    ref_codes = ir.sa_encode(xb)

    ir2 = faiss.deserialize_index(faiss.serialize_index(ir))
    np.testing.assert_array_equal(ref_codes, ir2.sa_encode(xb))
def test_serialize(self):
    """GPU indexes survive GPU -> CPU -> bytes -> CPU -> GPU unchanged."""
    res = faiss.StandardGpuResources()
    d = 32
    k = 10
    nlist = 5

    train = make_t(10000, d)
    add = make_t(10000, d)
    query = make_t(10, d)

    # Flat, IVFFlat, IVFSQ (fp16) and IVFPQ GPU indexes
    indexes = [
        faiss.GpuIndexFlatL2(res, d),
        faiss.GpuIndexIVFFlat(res, d, nlist, faiss.METRIC_L2),
        faiss.GpuIndexIVFScalarQuantizer(
            res, d, nlist, faiss.ScalarQuantizer.QT_fp16),
        faiss.GpuIndexIVFPQ(res, d, nlist, 4, 8, faiss.METRIC_L2),
    ]

    for index in indexes:
        index.train(train)
        index.add(add)
        orig_d, orig_i = index.search(query, k)

        cpu_copy = faiss.deserialize_index(
            faiss.serialize_index(faiss.index_gpu_to_cpu(index)))
        restored = faiss.index_cpu_to_gpu(res, 0, cpu_copy)
        restore_d, restore_i = restored.search(query, k)

        self.assertTrue(np.array_equal(orig_d, restore_d))
        self.assertTrue(np.array_equal(orig_i, restore_i))

        # the restored index must still accept additions without error
        restored.add(query)
def do_test(self, by_residual=False, metric=faiss.METRIC_L2, d=32, bbs=32):
    """Compare IndexIVFPQFastScan with a regular IVF32,PQ(d/2)x4 index.

    by_residual: residual-encoding flag applied to both indexes.
    metric: faiss metric; the reference table covers values 0 and 1.
    d: dimensionality; the reference table covers 32 and 30.
    bbs: block size forwarded to the IndexIVFPQFastScan constructor
        (previously overwritten with 32, which ignored the argument).
    """
    ds = datasets.SyntheticDataset(d, 2000, 5000, 200)

    index = faiss.index_factory(d, f"IVF32,PQ{d//2}x4np", metric)
    index.by_residual = by_residual
    index.train(ds.get_train())
    index.add(ds.get_database())
    index.nprobe = 4
    Dref, Iref = index.search(ds.get_queries(), 10)

    # fast-scan index sharing the reference index's coarse quantizer
    index2 = faiss.IndexIVFPQFastScan(
        index.quantizer, d, 32, d // 2, 4, metric, bbs)
    index2.by_residual = by_residual
    index2.train(ds.get_train())
    index2.add(ds.get_database())
    index2.nprobe = 4
    Dnew, Inew = index2.search(ds.get_queries(), 10)

    m3 = three_metrics(Dref, Iref, Dnew, Inew)
    # print((by_residual, metric, d), ":", m3)
    ref_m3_tab = {
        (True, 1, 32) : (0.995, 1.0, 9.91),
        (True, 0, 32) : (0.99, 1.0, 9.91),
        (True, 1, 30) : (0.99, 1.0, 9.885),
        (False, 1, 32) : (0.99, 1.0, 9.875),
        (False, 0, 32) : (0.99, 1.0, 9.92),
        (False, 1, 30) : (1.0, 1.0, 9.895)
    }
    ref_m3 = ref_m3_tab[(by_residual, metric, d)]
    # allow 1% slack on each of the three metrics
    self.assertGreater(m3[0], ref_m3[0] * 0.99)
    self.assertGreater(m3[1], ref_m3[1] * 0.99)
    self.assertGreater(m3[2], ref_m3[2] * 0.99)

    # Test I/O
    data = faiss.serialize_index(index2)
    index3 = faiss.deserialize_index(data)
    D3, I3 = index3.search(ds.get_queries(), 10)

    np.testing.assert_array_equal(I3, Inew)
    np.testing.assert_array_equal(D3, Dnew)
def test_rcq_LUT(self):
    """LUT-based RCQ search (beam factor -1) must match an exact coarse quantizer.

    Also checks that the residual coarse quantizer survives I/O.
    """
    ds = datasets.SyntheticDataset(32, 3000, 1000, 100)

    xt = ds.get_train()
    xb = ds.get_database()
    xq = ds.get_queries()

    # RQ 2x5 = 10 bits = 1024 centroids
    index = faiss.index_factory(ds.d, "IVF1024(RCQ2x5),SQ8")

    quantizer = faiss.downcast_index(index.quantizer)
    rq = quantizer.rq
    rq.train_type = faiss.ResidualQuantizer.Train_default

    index.train(xt)
    index.add(xb)
    index.nprobe = 10

    # set exact centroids as coarse quantizer (dimension taken from the
    # dataset instead of a hard-coded 32)
    all_centroids = quantizer.reconstruct_n(0, quantizer.ntotal)
    q2 = faiss.IndexFlatL2(ds.d)
    q2.add(all_centroids)
    index.quantizer = q2
    Dref, Iref = index.search(xq, 10)
    index.quantizer = quantizer

    # search with LUT
    quantizer.set_beam_factor(-1)
    Dnew, Inew = index.search(xq, 10)

    np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
    np.testing.assert_array_equal(Iref, Inew)

    # check i/o (a single search on the copy is enough; the extra,
    # discarded search that used to precede it has been dropped)
    CDref, CIref = quantizer.search(xq, 10)
    quantizer2 = faiss.deserialize_index(faiss.serialize_index(quantizer))
    CDnew, CInew = quantizer2.search(xq, 10)
    np.testing.assert_array_almost_equal(CDref, CDnew, decimal=5)
    np.testing.assert_array_equal(CIref, CInew)
def do_test(self, key1, key2):
    """Codec ``key2`` must reproduce codec ``key1`` once it reuses key1's trained parts.

    ``codec_new`` (built from ``key2``) takes over the PQ and the coarse
    quantizer of the trained ``codec_ref`` (built from ``key1``). Its codes
    and reconstructions must then equal the reference exactly, both
    directly and after a serialize/deserialize round trip.
    """
    d = 96
    nb = 1000
    nq = 0
    nt = 2000

    xt, x, _ = get_dataset_2(d, nt, nb, nq)

    codec_ref = faiss.index_factory(d, key1)
    codec_ref.train(xt)

    code_ref = codec_ref.sa_encode(x)
    x_recons_ref = codec_ref.sa_decode(code_ref)

    codec_new = faiss.index_factory(d, key2)
    codec_new.pq = codec_ref.pq

    # Replace the quantizer while avoiding a memory leak:
    # - ``this.own(True)`` hands ownership of the old quantizer to the
    #   Python proxy ``oldq``, so it is freed when ``oldq`` is collected.
    #   A bare ``this.own()`` only *returns* the ownership flag without
    #   changing it, which left the old quantizer leaked.
    # - ``own_fields = False`` stops q1 from deleting the borrowed
    #   ``codec_ref.quantizer`` (presumably still owned by codec_ref).
    # The statement order matters; keep it.
    oldq = codec_new.q1.quantizer
    oldq.this.own(True)
    codec_new.q1.own_fields = False
    codec_new.q1.quantizer = codec_ref.quantizer
    codec_new.is_trained = True

    code_new = codec_new.sa_encode(x)
    x_recons_new = codec_new.sa_decode(code_new)

    self.assertTrue(np.all(code_new == code_ref))
    self.assertTrue(np.all(x_recons_new == x_recons_ref))

    codec_new_2 = faiss.deserialize_index(
        faiss.serialize_index(codec_new))

    code_new = codec_new_2.sa_encode(x)
    x_recons_new = codec_new_2.sa_decode(code_new)

    self.assertTrue(np.all(code_new == code_ref))
    self.assertTrue(np.all(x_recons_new == x_recons_ref))
def do_test_knn(self, mt):
    """IndexFlat k-NN with metric ``mt`` must match brute-force pairwise distances.

    Also checks that a serialized/deserialized copy returns identical
    ids and distances.
    """
    d = 10
    nb = 100
    nq = 50
    nt = 0

    _, xb, xq = get_dataset_2(d, nt, nb, nq)

    index = faiss.IndexFlat(d, mt)
    index.add(xb)

    D, I = index.search(xq, 10)

    dis = faiss.pairwise_distances(xq, xb, mt)
    # NOTE(review): an ascending argsort assumes smaller-is-better
    # distances; confirm callers never pass a similarity metric here.
    o = dis.argsort(axis=1)
    # unittest-friendly assertions (bare ``assert`` is stripped under -O)
    np.testing.assert_array_equal(I, o[:, :10])

    for q in range(nq):
        np.testing.assert_array_equal(D[q], dis[q, I[q]])

    index2 = faiss.deserialize_index(faiss.serialize_index(index))
    D2, I2 = index2.search(xq, 10)

    # both ids and distances must survive the round trip
    np.testing.assert_array_equal(I2, I)
    np.testing.assert_array_equal(D2, D)
_HOW_TO_USE_TEXT = """
This application is intended for the visual exploration and discovery of research publications that have been presented at the ACL (Annual Meeting of the Association for Computational Linguistics).

Every particle in the scatterplot is an academic publication. The particles are positioned in space based on the semantic similarity of the paper titles; the closer two points are, the more semantically similar their titles. You can hover over the particles to read their titles and you can click them to be redirected to the original source. You can zoom in the visualisation by scrolling and you can reset the view by double clicking the white space within the figure.

Regarding the bar chart, it shows the most used Fields of Study for the papers shown in the scatterplot.

You can also **search** for publications by paper titles (more information below).

#### Filters
You can refine your query based on the publication year, paper content, field of study and author. You can also combine any of the filters for more granular searches.

- **Filter by year**: Select a time range for the papers. For example, drag both sliders to 2020 to find out the papers that will be presented at ACL 2020.
- **Field of Study level**: Microsoft Academic Graph uses a 6-level hierarchy where level 0 contains high level disciplines such as Computer science and level 5 contains the most granular paper keywords. This filter will change what's shown in the bar chart as well as the available options in the filter below.
- **Fields of Study**: Select the Fields of Study to be displayed in the visualisations. The available options are affected by your selection in the above filter.
- **Search by author name**: Find an author's publications. **Note**: You need to type in the exact name.
- **Search by paper title**: Type in a paper title and find its most relevant publications. You should use at least a sentence to receive meaningful results.
- **Number of search results**: Specify the number of papers to be returned when you search by paper title.
"""

# NOTE(review): the e-mail address at the end of this text looks like it was
# replaced by an e-mail-obfuscation placeholder ("[email protected]");
# restore the real address.
_ABOUT_TEXT = """
I am [Kostas](http://kstathou.github.io/) and I work at the intersection of knowledge discovery, data engineering and scientometrics. I am a Mozilla Open Science Fellow and a Principal Data Science Researcher at Nesta. I am currently working on [Orion](https://orion-search.org/) (work in progress), an open-source knowledge discovery and research measurement tool.

If you have any questions or would like to learn more about it, you can find me on [twitter](https://twitter.com/kstathou) or send me an email at [email protected]
"""

_APPENDIX_TEXT = """
I collected all of the publications from [Microsoft Academic Graph](https://www.microsoft.com/en-us/research/project/academic-knowledge/) that were published between 2000 and 2020 and were presented at the ACL. I fetched 8,724 publications.

To create the 2D visualisation, I encoded the paper titles to dense vectors using a [sentence-DistilBERT](https://github.com/UKPLab/sentence-transformers) model. That produced a 768-dimensional vector for each paper which I projected to a 2D space with [UMAP](https://umap-learn.readthedocs.io/en/latest/).

For the paper title search engine, I indexed the vectors with [Faiss](https://github.com/facebookresearch/faiss/tree/master/python).
"""


def _filter_publications(data, author_data, model, faiss_index, filter_year,
                         filter_fos, author_input, user_input, num_results):
    """Return the rows of ``data`` that satisfy every active sidebar filter.

    The year range always applies. Fields of study, the title search
    (``vector_search``) and the author name apply only when set, and all
    active filters are AND-ed together.
    """
    mask = ((data.year >= str(filter_year[0]))
            & (data.year <= str(filter_year[1])))
    if filter_fos:
        mask &= data.name.isin(filter_fos)
    if user_input:
        # Title + author search without a field filter uses a wider pool of
        # 150 results before intersecting with the author's papers.
        k = 150 if (author_input and not filter_fos) else num_results
        encoded_user_input = vector_search(
            [user_input], model, faiss_index, num_results=k)
        mask &= data.id.isin(encoded_user_input)
    if author_input:
        ids = author_data[author_data.name == author_input]['paper_id']
        mask &= data.id.isin(ids)
    return data[mask]


def _scatter_chart(frame, color_on_fos):
    """Build the 2D paper scatterplot, coloured by field of study if requested."""
    encodings = [
        alt.X('Component 1', scale=alt.Scale(domain=(1, 16))),
        alt.Y('Component 2', scale=alt.Scale(domain=(0, 18))),
    ]
    if color_on_fos:
        encodings.append(alt.Color('name', title='Field of Study'))
    encodings.append(alt.Size('citations', scale=alt.Scale(range=[10, 500]),
                              title='Citations'))
    return alt.Chart(frame.drop_duplicates('id')).mark_point().encode(
        *encodings, href='source:N', tooltip=['title', 'year']
    ).interactive().properties(width=650, height=500)


def _fos_bar_chart(frame, filter_fos_level):
    """Bar chart of the 30 most frequent fields of study at the chosen level."""
    bar_data = pd.DataFrame(frame[frame.level == filter_fos_level].groupby(
        'name')['id'].count()).reset_index().sort_values('id', ascending=False)[:30]
    return alt.Chart(bar_data).mark_bar().encode(
        alt.X('name', sort='-y', title='Fields of Study'),
        alt.Y('id', title='Count')).properties(width=650, height=150)


def main():
    """Render the ACL Publications Explorer Streamlit page.

    Reads the sidebar filters (year range, field-of-study level and fields,
    author name, title search text, number of search results). It narrows the
    publication table accordingly and draws a scatterplot plus a
    field-of-study bar chart, followed by the usage, about and
    methods sections.
    """
    data = read_data()
    fos_level = unique_fos_level(data)
    model = load_bert_model()
    # NOTE(review): some variants of ``load_faiss_index`` already return a
    # deserialized index; if this one does, the extra deserialize_index call
    # below would fail -- confirm which variant this app imports.
    faiss_index = faiss.deserialize_index(load_faiss_index())
    author_data = read_author_data()

    st.title("ACL Publications Explorer")

    # sidebar widgets, created in display order
    filter_year = st.sidebar.slider("Filter by year", 2000, 2020, (2000, 2020), 1)
    filter_fos_level = st.sidebar.selectbox("Choose Field of Study level", fos_level)
    fields_of_study = unique_fos(data, filter_fos_level, 25)
    filter_fos = st.sidebar.multiselect("Choose Fields of Study", fields_of_study)
    author_input = st.sidebar.text_input("Search by author name")
    # User search
    user_input = st.sidebar.text_area("Search by paper title")
    num_results = st.sidebar.slider("Number of search results", 10, 150, 10)

    frame = _filter_publications(
        data, author_data, model, faiss_index, filter_year, filter_fos,
        author_input, user_input, num_results)
    # colour the scatterplot by field of study only when fields were chosen
    color_on_fos = bool(filter_fos)

    chart = _scatter_chart(frame, color_on_fos)
    barchart = _fos_bar_chart(frame, filter_fos_level)
    c = (chart & barchart)
    st.altair_chart(c, use_container_width=True)

    st.subheader("How to use this app")
    st.write(_HOW_TO_USE_TEXT)

    st.subheader("About")
    st.write(_ABOUT_TEXT)

    st.subheader("Appendix: Data & methods")
    st.write(_APPENDIX_TEXT)
def load(self, path):
    """Restore this object's attributes from the pickle file at ``path``.

    The file is expected to hold a pickled dict whose ``_index`` entry is
    serialized index data. That entry is turned back into a Faiss index with
    ``faiss.deserialize_index``, and the dict is merged into ``self.__dict__``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path, 'rb') as fh:
        state = pickle.load(fh)
    state['_index'] = faiss.deserialize_index(state['_index'])
    self.__dict__.update(state)
def save(self, path):
    """Pickle this object's attributes to ``path``.

    The Faiss index in ``_index`` is written in the form returned by
    ``faiss.serialize_index``. Only a shallow copy of ``__dict__`` is
    modified, so ``self._index`` stays the same live index object. The
    previous implementation temporarily replaced ``self._index`` with bytes,
    which left the object broken if opening or writing the file raised. It
    also swapped in a fresh deserialized copy afterwards.
    """
    state = dict(self.__dict__)
    state['_index'] = faiss.serialize_index(self._index)
    with open(path, 'wb') as f:
        pickle.dump(state, f)
import faiss app = Flask(__name__) api = Api(app) # Parse arguments parser = reqparse.RequestParser() parser.add_argument("query") parser.add_argument("results") parser.add_argument("citation_count") # Load env load_dotenv(find_dotenv()) # VectorSimilarity models faiss_index = faiss.deserialize_index( load_from_s3(os.getenv("s3_bucket"), os.getenv("faiss_index"))) model = SentenceTransformer("distilbert-base-nli-stsb-mean-tokens") # ES setup es_port = os.getenv("es_port") es_host = os.getenv("es_host") region = os.getenv("region") es_index = os.getenv("es_index") es = aws_es_client(es_host, es_port, region) class VectorSimilarity(Resource): @cors.crossdomain(origin="*") def get(self): # Parse user's query
def load_faiss_index(path_to_faiss="faiss_indobert_docs/faiss_index.pickle"):
    """Load and deserialize the Faiss index.

    Unpickles the contents of ``path_to_faiss`` and rebuilds the index from
    them with ``faiss.deserialize_index``.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(path_to_faiss, "rb") as handle:
        serialized = pickle.load(handle)
    index = faiss.deserialize_index(serialized)
    return index