def get_data(self):
    """Yield batches of size ``self.batch_size`` from ``self.data_source``.

    The source may be:
      * ``str``  -- a jsonline filename, read lazily when
        ``self.args.lazy_loading`` is set, otherwise loaded and shuffled;
      * ``list`` -- in-memory items, shuffled in place (NOTE: mutates the
        caller's list, preserving the original behavior);
      * ``dict`` -- config for building rerank inputs, with keys
        ``golden_filename``, ``search_filename``, ``topk`` and
        ``searched_id_list``.

    Returns an iterable of batches (lists of items).
    Raises ValueError for any other source type.
    """
    if isinstance(self.data_source, str):
        if self.args.lazy_loading:
            data = get_chunk(read_jsonline_lazy(self.data_source),
                             self.batch_size)
        else:
            total_data = read_jsonline(self.data_source)
            # A single Fisher-Yates shuffle already yields a uniform random
            # permutation; the original shuffled three times redundantly.
            random.shuffle(total_data)
            data = get_chunk(total_data, self.batch_size)
    elif isinstance(self.data_source, list):
        random.shuffle(self.data_source)
        data = get_chunk(self.data_source, self.batch_size)
    elif isinstance(self.data_source, dict):
        golden_filename = self.data_source['golden_filename']
        # description_id -> golden item, for query-text lookup below
        desc_id2item = {}
        for item in read_jsonline_lazy(golden_filename):
            desc_id2item[item['description_id']] = item
        search_filename = self.data_source['search_filename']
        topk = self.data_source['topk']
        searched_ids = set(self.data_source['searched_id_list'])

        def build_batch(search_item):
            """Build the batches of query/doc pairs for one search item."""
            qd_pairs = []
            desc_id = search_item['description_id']
            if desc_id in searched_ids:
                # Already processed: emit a single empty batch so the
                # flattened stream keeps one entry per search item.
                return [[]]
            query_text = desc_id2item[desc_id][self.args.query_field]
            if self.args.rerank_model_name == 'pairwise':
                docs = search_item['docs'][:topk]
                # All ordered doc pairs (i, j) with i != j.
                for i, doc_id in enumerate(docs):
                    for p_doc_id in docs[:i] + docs[i + 1:]:
                        qd_pairs.append({
                            'description_id': desc_id,
                            'query': query_text,
                            'first_doc_id': doc_id,
                            'second_doc_id': p_doc_id
                        })
            else:
                for doc_id in search_item['docs'][:topk]:
                    qd_pairs.append({
                        'description_id': desc_id,
                        'query': query_text,
                        'doc_id': doc_id
                    })
            return get_chunk(qd_pairs, self.batch_size)

        data = chain.from_iterable(
            map(build_batch, read_jsonline_lazy(search_filename)))
    else:
        raise ValueError('data type error')
    return data
def predict(self, src_filename, dest_filename):
    """Measure pairwise ranking accuracy of the model on *src_filename*.

    For every batch, embeds the query, the true document and a false
    document, scores both with cosine similarity, and counts how often
    the true document outranks the false one.

    NOTE: *dest_filename* is only scanned for already-processed ids;
    nothing is written here (``existed_ids`` is currently unused below,
    kept for interface compatibility).

    Returns the accuracy as a float in [0, 1] (0.0 on empty input).
    """
    self.model.eval()
    existed_ids = set()
    for item in read_jsonline_lazy(dest_filename, default=[]):
        existed_ids.add(item[self.id_field])
    loader = VectorizationDataLoader(src_filename, self.tokenizer, self.args)
    cos = nn.CosineSimilarity(dim=1)
    tp_count = 0
    total_count = 0
    for batch in loader:
        with torch.no_grad():
            query_embed = self.model(batch, 'query')
            true_embed = self.model(batch, 'true')
            false_embed = self.model(batch, 'false')
            true_scores = cos(query_embed, true_embed)
            false_scores = cos(query_embed, false_embed)
            # Removed leftover debug print of the raw score tensors.
            total_count += query_embed.size(0)
            # .item() replaces the .cpu().numpy().tolist() round trip.
            tp_count += (true_scores > false_scores).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = tp_count / total_count if total_count else 0.0
    return accuracy
def rerank(self, src_filename, dest_filename, vector_filename, topk=100):
    """Rerank the top-*topk* candidate papers of each description by
    cosine similarity between description and paper vectors.

    Descriptions without a vector keep their original candidate order;
    papers without a vector get score 0. Results are appended to
    *dest_filename* as jsonlines of ``{'description_id', 'docs'}``.
    """
    desc_id2vector = VectorIndexer(vector_filename).load_vector()
    for item in read_jsonline_lazy(src_filename):
        desc_id = item['description_id']
        paper_id_list = item['docs'][:topk]
        if desc_id in desc_id2vector:
            desc_vector = desc_id2vector[desc_id]
            paper_id_with_scores = []
            for paper_id in paper_id_list:
                score = 0
                if paper_id in self.paper_id2vector:
                    paper_vector = self.paper_id2vector[paper_id]
                    norm_score = (np.linalg.norm(paper_vector) *
                                  np.linalg.norm(desc_vector))
                    # Guard against zero-norm vectors: the original divided
                    # unconditionally, producing nan/inf for zero vectors.
                    if norm_score:
                        score = np.dot(paper_vector, desc_vector) / norm_score
                paper_id_with_scores.append((paper_id, score))
            # Sort by score desc; ties broken by paper_id desc (reverse=True
            # applies to the whole (score, id) key, as in the original).
            sorted_id_scores = sorted(paper_id_with_scores,
                                      key=lambda i: (i[1], i[0]),
                                      reverse=True)
            reranked_paper_id_list = [p for p, s in sorted_id_scores]
        else:
            reranked_paper_id_list = paper_id_list
        result_item = {
            'description_id': desc_id,
            'docs': reranked_paper_id_list
        }
        append_jsonline(dest_filename, result_item)
def build_data(self):
    """Build training pairs from search results and golden annotations.

    For each searched item, samples ``self.args.sample_count`` train pairs
    (true paper vs. sampled false paper). Depending on
    ``self.args.aggregate_sample``, either one aggregated item per query
    (with a list of false ids) or one item per sampled pair is emitted.
    Items are built in parallel and appended to ``self.dest_filename``.
    """
    pool = Pool(20)
    # description_id -> golden item (true paper id, citation text, ...)
    desc_id2item = {}
    for item in read_jsonline_lazy(self.golden_filename):
        desc_id = item['description_id']
        desc_id2item[desc_id] = item
    chunk_size = 50
    for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename), chunk_size):
        new_item_chunk = []
        for item in item_chunk:
            true_item = desc_id2item[item['description_id']]
            true_paper_id = true_item['paper_id']
            cites_text = true_item['cites_text']
            docs = item['docs']
            # Strip bulky fields so they are not copied into every sample;
            # NOTE: mutates `item` before it is spread into new_item below.
            item.pop('docs')
            item.pop('keywords')
            new_item_list = []
            # Deep copy taken AFTER the pops, so the aggregate item carries
            # the same slimmed-down fields as the per-pair items.
            new_item_dict = copy.deepcopy(item)
            new_item_dict['true_paper_id'] = true_paper_id
            new_item_dict['false_paper_id'] = []
            new_item_dict['cites_text'] = cites_text
            new_item_dict['description_text'] = true_item[
                'description_text']
            for idx in range(self.args.sample_count):
                # Sample one (true, false) pair per index; `idx` lets the
                # strategy vary its pick across samples.
                train_pair = self.select_train_pair(
                    docs, true_paper_id, self.args.select_strategy, idx)
                new_item = {
                    **train_pair,
                    **item,
                    'cites_text': cites_text,
                    'description_text': true_item['description_text']
                }
                new_item_list.append(new_item)
                new_item_dict['false_paper_id'].append(
                    train_pair['false_paper_id'])
            if self.args.aggregate_sample:
                # One item per query, carrying all sampled false ids.
                new_item_chunk.append(new_item_dict)
            else:
                # One item per sampled pair.
                new_item_chunk.extend(new_item_list)
        built_items = pool.map(self.build_single_query, new_item_chunk)
        # Drop falsy results (queries the builder rejected).
        built_items = [i for i in built_items if i]
        append_jsonlines(self.dest_filename, built_items)
def result_format(src_filename, dest_filename=None):
    """Convert a jsonline result file into tab-separated submission lines.

    Each output line is ``description_id`` followed by its top-3 paper
    ids, tab-joined and appended to *dest_filename* (defaults to the
    source path with a ``.csv`` extension).

    Raises ValueError when an item has no candidate papers.
    """
    if dest_filename is None:
        base, _ = os.path.splitext(src_filename)
        dest_filename = base + '.csv'
    for record in read_jsonline_lazy(src_filename):
        top_papers = record['docs'][:3]
        if not top_papers:
            raise ValueError('result is empty')
        fields = [record['description_id']] + top_papers
        append_line(dest_filename, '\t'.join(fields))
def view(self):
    """Print the first 21 test descriptions next to their papers' titles
    and abstracts, for quick manual inspection."""
    for idx, item in enumerate(read_jsonline_lazy(DATA_DIR + 'test.jsonl')):
        if idx > 20:
            break
        paper = get_paper(item['paper_id'])
        for line in ('============',
                     item['description_text'],
                     '-------------',
                     paper['title'],
                     paper['abstract']):
            print(line)
def load_data(self, chunk_size):
    """Return an iterator of chunks of size *chunk_size* over the data.

    ``self.data_source`` may be a jsonline filename (loaded lazily when
    ``self.lazy_loading`` is set, otherwise fully loaded and optionally
    shuffled) or an in-memory list.

    Raises TypeError for any other source type.
    """
    if isinstance(self.data_source, str):
        if self.lazy_loading:
            data = read_jsonline_lazy(self.data_source)
        else:
            data = read_jsonline(self.data_source)
            if self.loader.shuffle:
                # One shuffle already gives a uniform permutation; the
                # original shuffled three times redundantly.
                random.shuffle(data)
    elif isinstance(self.data_source, list):
        data = iter(self.data_source)
    else:
        raise TypeError('input filename type is error')
    return get_chunk(data, chunk_size)
def process(self):
    """Tokenize source records in parallel, append the processed records
    to ``self.dest_filename``, and (when ``self.dest_vocab_path`` is set)
    build and write a vocabulary from the collected tokens."""
    worker_pool = Pool(self.parallel_count)
    vocab_tokens = []
    batch_size = 100
    for record_batch in get_chunk(read_jsonline_lazy(self.src_filename),
                                  batch_size):
        tokenized = worker_pool.map(self.tokenize_record, record_batch)
        if self.dest_vocab_path:
            # Accumulate every token for vocabulary building.
            for rec in tokenized:
                vocab_tokens.extend(rec['title_and_abstract_tokens'])
                vocab_tokens.extend(rec['flatten_keyword_tokens'])
        # The flattened keyword tokens are only needed for the vocab.
        for rec in tokenized:
            rec.pop('flatten_keyword_tokens')
        append_jsonlines(self.dest_filename, tokenized)
    if self.dest_vocab_path:
        vocab = self.build_vocab(vocab_tokens)
        write_lines(self.dest_vocab_path, vocab)
def indexing_runner(self):
    """Index all candidate documents in parallel batches of 500, logging
    cumulative progress, then retry failed documents sequentially."""
    worker_pool = Pool(self.parallel_size)
    started_at = time.time()
    processed = 0
    failures = []
    for doc_batch in get_chunk(read_jsonline_lazy(CANDIDATE_FILENAME), 500):
        results = worker_pool.map(self.index_doc, doc_batch)
        # index_doc returns a truthy value for documents that failed.
        failures.extend(r for r in results if r)
        elapsed = time.time() - started_at
        processed += len(doc_batch)
        self.logger.info('{} completed, {}min {:.2f}s'.format(
            processed, elapsed // 60, elapsed % 60))
    # Second pass: retry the failed documents one at a time.
    for doc in failures:
        self.index_doc(doc)
def get_searched_desc_id(self, filename):
    """Collect the description ids already present in *filename*
    (treated as empty when the file is missing)."""
    searched = set()
    for record in read_jsonline_lazy(filename, default=[]):
        searched.add(record['description_id'])
    return searched