-
Notifications
You must be signed in to change notification settings - Fork 1
/
FilterEmbeddings.py
60 lines (46 loc) · 2.3 KB
/
FilterEmbeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import nmslib
import pandas as pd
class FilterEmbeddings:
    """Keep only the rows of a generic embedding set that lie near reference embeddings.

    An nmslib HNSW index (cosine similarity) is built over the vectors in
    ``VECS_FILE``. Every reference vector in ``EMBEDDING_FILE`` is queried
    for its nearest neighbours. The union of the neighbour row ids is then
    used to subset both ``VECS_FILE`` and its companion metadata file
    ``NAME_FILE``, which are assumed to be row-aligned because the same row
    selection is applied to both.
    """

    def __init__(self, params):
        """Resolve the input paths, build the search index and load the query vectors.

        Args:
            params: mapping with the keys
                ``'VECS_FILE'``: TSV of vectors, one per row, no header row;
                ``'NAME_FILE'``: TSV metadata, row-aligned with ``VECS_FILE``;
                ``'EMBEDDING_FILE'``: TSV of reference vectors used as queries.
        """
        self._vecs_file = os.path.abspath(params['VECS_FILE'])
        self._meta_file = os.path.abspath(params['NAME_FILE'])
        self._vector_space = self._create_vector_space(self._vecs_file)
        print(self._vector_space)
        embedding_file = os.path.abspath(params['EMBEDDING_FILE'])
        self._filter_embeddings = self._read_tsv_file(embedding_file)

    def _create_vector_space(self, file_path):
        """Build and return an nmslib HNSW index (cosine similarity) over the rows of *file_path*."""
        vector_data = self._read_tsv_file(file_path)
        vector_space = nmslib.init(method='hnsw', space='cosinesimil')
        # No explicit ids are passed. filter_data relies on nmslib numbering the
        # points 0..n-1 in insertion order, so that a neighbour id equals the
        # row position in the TSV file.
        vector_space.addDataPointBatch(vector_data)
        # 'post' is an nmslib HNSW index-time parameter (graph post-processing level).
        vector_space.createIndex({'post': 2}, print_progress=True)
        return vector_space

    def _read_tsv_file(self, tsv_file):
        """Read a headerless tab-separated file and return its contents as a 2-D array."""
        df = pd.read_csv(tsv_file, sep='\t', header=None)
        return df.values

    def filter_data(self, nearest_point, num_threads=4):
        """Subset the vector and metadata files to the neighbours of the reference embeddings.

        Args:
            nearest_point: number of nearest neighbours (k) retrieved for each
                reference vector.
            num_threads: number of threads nmslib uses for the batch query
                (defaults to 4, the previously hard-coded value).

        Side effects:
            Writes ``filtered_<basename>`` copies of ``VECS_FILE`` and
            ``NAME_FILE`` into the current working directory. Each copy holds
            only the selected rows, in ascending row order.
        """
        neighbours = self._vector_space.knnQueryBatch(
            self._filter_embeddings, k=nearest_point, num_threads=num_threads)
        # Each result is an (ids, distances) pair. Take the union of ids over all
        # queries. Sorting keeps the output in original file order, and the same
        # selection is applied to both row-aligned files.
        filtered_ids = sorted(set().union(*(ids for ids, _ in neighbours)))
        self._filter_rows_and_store(self._vecs_file, filtered_ids)
        self._filter_rows_and_store(self._meta_file, filtered_ids)

    def _filter_rows_and_store(self, file_path, filter_rows):
        """Save rows *filter_rows* of the TSV at *file_path* as ``filtered_<basename>`` in the CWD.

        Args:
            file_path: path of a headerless tab-separated file.
            filter_rows: 0-based row positions to keep.

        The output keeps the input's format: tab-separated, no header row and
        no index column.
        """
        # NOTE(review): pandas' default quote handling applies on both read and
        # write, so metadata text containing '"' may not round-trip verbatim.
        # Confirm this is acceptable for NAME_FILE.
        df = pd.read_csv(file_path, sep='\t', header=None)
        df = df.iloc[filter_rows, :]
        save_file_path = os.path.abspath('filtered_' + os.path.basename(file_path))
        # Match the headerless TSV input. The to_csv defaults would emit commas
        # and a "0,1,2,..." header row.
        df.to_csv(save_file_path, sep='\t', header=False, index=False)
        print('Filtered data has been saved at ', save_file_path)
if __name__ == '__main__':
    # Input TSV files; relative paths are resolved against the current directory.
    params = dict(
        VECS_FILE='vecs_tf1.tsv',              # Generic embeddings out of all your sentences generated from PDF files
        NAME_FILE='meta_tf1.tsv',              # Meta file of your generic embeddings
        EMBEDDING_FILE='custom_vecs_tf1.tsv',  # Embeddings generated out of ideal sample statements
    )
    fe = FilterEmbeddings(params)
    # Keep every row that is among the 500 nearest neighbours of any reference embedding.
    fe.filter_data(nearest_point=500)