Example #1
    def get_data(self):
        if isinstance(self.data_source, str):
            if self.args.lazy_loading:
                data = get_chunk(read_jsonline_lazy(self.data_source),
                                 self.batch_size)
            else:
                total_data = read_jsonline(self.data_source)
                random.shuffle(total_data)  # one shuffle fully randomizes the order
                data = get_chunk(total_data, self.batch_size)
        elif isinstance(self.data_source, list):
            random.shuffle(self.data_source)  # one shuffle fully randomizes the order
            data = get_chunk(self.data_source, self.batch_size)
        elif isinstance(self.data_source, dict):
            golden_filename = self.data_source['golden_filename']
            desc_id2item = {}
            for item in read_jsonline_lazy(golden_filename):
                desc_id2item[item['description_id']] = item
            search_filename = self.data_source['search_filename']
            topk = self.data_source['topk']
            searched_ids = set(self.data_source['searched_id_list'])

            def build_batch(search_item):
                qd_pairs = []
                desc_id = search_item['description_id']
                if desc_id in searched_ids:
                    return [[]]

                query_text = desc_id2item[desc_id][self.args.query_field]
                if self.args.rerank_model_name == 'pairwise':
                    docs = search_item['docs'][:topk]
                    for i, doc_id in enumerate(docs):
                        for p_doc_id in docs[:i] + docs[i + 1:]:
                            raw_item = {
                                'description_id': desc_id,
                                'query': query_text,
                                'first_doc_id': doc_id,
                                'second_doc_id': p_doc_id
                            }
                            qd_pairs.append(raw_item)
                else:
                    for doc_id in search_item['docs'][:topk]:
                        raw_item = {
                            'description_id': desc_id,
                            'query': query_text,
                            'doc_id': doc_id
                        }
                        qd_pairs.append(raw_item)

                return get_chunk(qd_pairs, self.batch_size)

            # build batches per search item and flatten them into one stream
            data = map(build_batch, read_jsonline_lazy(search_filename))
            data = chain.from_iterable(data)
        else:
            raise ValueError(
                'unsupported data_source type: {}'.format(type(self.data_source)))
        return data
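
The JSONL helpers these snippets lean on (read_jsonline, read_jsonline_lazy, get_chunk; chain is itertools.chain) are never defined in the source. A minimal sketch of what they would need to do, inferred from the call sites and assuming one JSON object per line:

import json
import os
from itertools import islice

def read_jsonline(filename):
    # Assumed helper: eagerly load a whole JSONL file into a list.
    with open(filename, encoding='utf-8') as f:
        return [json.loads(line) for line in f]

def read_jsonline_lazy(filename, default=None):
    # Assumed helper: yield items one at a time; fall back to `default`
    # when the file is missing (call sites pass default=[] for optional files).
    if not os.path.exists(filename):
        yield from (default if default is not None else [])
        return
    with open(filename, encoding='utf-8') as f:
        for line in f:
            yield json.loads(line)

def get_chunk(iterable, chunk_size):
    # Assumed helper: group any iterable into lists of at most chunk_size items.
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            return
        yield chunk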
Example #2
    def predict(self, src_filename, dest_filename):
        self.model.eval()
        existed_ids = set()
        for item in read_jsonline_lazy(dest_filename, default=[]):
            existed_ids.add(item[self.id_field])

        loader = VectorizationDataLoader(src_filename, self.tokenizer, self.args)
        cos = nn.CosineSimilarity(dim=1)
        tp_count = 0

        total_count = 0
        for batch in loader:
            with torch.no_grad():
                query_embed = self.model(batch, 'query')
                true_embed = self.model(batch, 'true')
                false_embed = self.model(batch, 'false')
                true_scores = cos(query_embed, true_embed)
                false_scores = cos(query_embed, false_embed)
                total_count += query_embed.size(0)
                tp_count += (true_scores > false_scores).sum().item()

        accuracy = tp_count / total_count

        return accuracy
Example #3
    def rerank(self, src_filename, dest_filename, vector_filename, topk=100):
        desc_id2vector = VectorIndexer(vector_filename).load_vector()
        for item in read_jsonline_lazy(src_filename):
            desc_id = item['description_id']
            paper_id_list = item['docs'][:topk]
            if desc_id in desc_id2vector:
                desc_vector = desc_id2vector[desc_id]
                paper_id_with_scores = []
                for paper_id in paper_id_list:
                    if paper_id not in self.paper_id2vector:
                        score = 0
                    else:
                        # cosine similarity between the description and the paper
                        paper_vector = self.paper_id2vector[paper_id]
                        dot_score = np.dot(paper_vector, desc_vector)
                        norm_score = (np.linalg.norm(paper_vector) *
                                      np.linalg.norm(desc_vector))
                        score = dot_score / norm_score
                    paper_id_with_scores.append((paper_id, score))
                # highest score first; ties broken by paper id
                sorted_id_scores = sorted(paper_id_with_scores,
                                          key=lambda i: (i[1], i[0]),
                                          reverse=True)
                reranked_paper_id_list = [p for p, s in sorted_id_scores]
            else:
                reranked_paper_id_list = paper_id_list
            result_item = {
                'description_id': desc_id,
                'docs': reranked_paper_id_list
            }
            append_jsonline(dest_filename, result_item)
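
The inner loop of rerank scores one candidate at a time; the same cosine similarity can be computed for all candidates in a single vectorized call. A sketch under that assumption (the function name and arguments are illustrative, not from the source):

import numpy as np

def cosine_scores(desc_vector, paper_matrix):
    # Row-wise cosine similarity between one description vector and a
    # (num_papers, dim) matrix of paper vectors: (P @ d) / (|P| * |d|).
    dots = paper_matrix @ desc_vector
    norms = np.linalg.norm(paper_matrix, axis=1) * np.linalg.norm(desc_vector)
    return dots / norms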
Example #4
    def build_data(self):
        pool = Pool(20)  # worker pool for building queries in parallel

        desc_id2item = {}

        for item in read_jsonline_lazy(self.golden_filename):
            desc_id = item['description_id']
            desc_id2item[desc_id] = item

        chunk_size = 50
        for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename),
                                    chunk_size):
            new_item_chunk = []
            for item in item_chunk:
                true_item = desc_id2item[item['description_id']]
                true_paper_id = true_item['paper_id']
                cites_text = true_item['cites_text']
                docs = item['docs']
                item.pop('docs')
                item.pop('keywords')
                new_item_list = []
                new_item_dict = copy.deepcopy(item)
                new_item_dict['true_paper_id'] = true_paper_id
                new_item_dict['false_paper_id'] = []
                new_item_dict['cites_text'] = cites_text
                new_item_dict['description_text'] = true_item[
                    'description_text']
                for idx in range(self.args.sample_count):
                    train_pair = self.select_train_pair(
                        docs, true_paper_id, self.args.select_strategy, idx)
                    new_item = {
                        **train_pair,
                        **item,
                        'cites_text': cites_text,
                        'description_text': true_item['description_text']
                    }
                    new_item_list.append(new_item)
                    new_item_dict['false_paper_id'].append(
                        train_pair['false_paper_id'])
                if self.args.aggregate_sample:
                    new_item_chunk.append(new_item_dict)
                else:
                    new_item_chunk.extend(new_item_list)
            built_items = pool.map(self.build_single_query, new_item_chunk)
            built_items = [i for i in built_items if i]
            append_jsonlines(self.dest_filename, built_items)
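
select_train_pair is not shown; the call above implies it returns a dict with at least 'false_paper_id' (and, given the merge into new_item, presumably 'true_paper_id') for one positive/negative training pair. A hypothetical sketch of that contract, with made-up strategy names:

import random

def select_train_pair(self, docs, true_paper_id, select_strategy, idx):
    # Hypothetical sketch: pair the known positive with one negative
    # sampled from the retrieved candidates (assumes at least one exists).
    candidates = [doc_id for doc_id in docs if doc_id != true_paper_id]
    if select_strategy == 'random':  # made-up strategy name
        false_paper_id = random.choice(candidates)
    else:  # e.g. walk the ranked list to pick progressively harder negatives
        false_paper_id = candidates[idx % len(candidates)]
    return {'true_paper_id': true_paper_id, 'false_paper_id': false_paper_id}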
Example #5
def result_format(src_filename, dest_filename=None):
    if dest_filename is None:
        dest_filename = os.path.splitext(src_filename)[0] + '.csv'
    for item in read_jsonline_lazy(src_filename):
        desc_id = item['description_id']
        paper_ids = item['docs'][:3]
        if not paper_ids:
            raise ValueError('result is empty')
        line = desc_id + '\t' + '\t'.join(paper_ids)  # tab-separated fields, despite the .csv default extension
        append_line(dest_filename, line)
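
The writers used across these examples (append_line, append_jsonline, append_jsonlines, write_lines) are likewise assumptions; a minimal sketch consistent with how they are called:

import json

def append_line(filename, line):
    # Assumed helper: append one text line, creating the file on first write.
    with open(filename, 'a', encoding='utf-8') as f:
        f.write(line + '\n')

def append_jsonline(filename, item):
    # Assumed helper: append one JSON object as a JSONL line.
    append_line(filename, json.dumps(item, ensure_ascii=False))

def append_jsonlines(filename, items):
    # Assumed helper: append a batch of JSON objects.
    for item in items:
        append_jsonline(filename, item)

def write_lines(filename, lines):
    # Assumed helper: overwrite the file with one entry per line.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')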
Example #6
    def view(self):
        # Inspect the first 21 test descriptions next to their gold papers.
        for idx, item in enumerate(read_jsonline_lazy(DATA_DIR + 'test.jsonl')):
            if idx > 20:
                break
            paper = get_paper(item['paper_id'])
            print('============')
            print(item['description_text'])
            print('-------------')
            print(paper['title'])
            print(paper['abstract'])
Example #7
    def load_data(self, chunk_size):
        if isinstance(self.data_source, str):
            if self.lazy_loading:
                data = read_jsonline_lazy(self.data_source)
            else:
                data = read_jsonline(self.data_source)
                if self.loader.shuffle:
                    random.shuffle(data)  # one shuffle fully randomizes the order
        elif isinstance(self.data_source, list):
            data = iter(self.data_source)
        else:
            raise TypeError(
                'unsupported data_source type: {}'.format(type(self.data_source)))
        return get_chunk(data, chunk_size)
Example #8
    def process(self):
        pool = Pool(self.parallel_count)
        tokens = []
        chunk_size = 100
        for item_chunk in get_chunk(read_jsonline_lazy(self.src_filename),
                                    chunk_size):
            processed_records = pool.map(self.tokenize_record, item_chunk)
            if self.dest_vocab_path:
                # collect tokens for the vocabulary built after the loop
                for record in processed_records:
                    tokens.extend(record['title_and_abstract_tokens'] +
                                  record['flatten_keyword_tokens'])
            for record in processed_records:
                record.pop('flatten_keyword_tokens')
            append_jsonlines(self.dest_filename, processed_records)
        if self.dest_vocab_path:
            vocab = self.build_vocab(tokens)
            write_lines(self.dest_vocab_path, vocab)
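
build_vocab is also undefined here; a hypothetical sketch that ranks tokens by frequency, matching what write_lines would then persist one token per line (min_count is a made-up parameter):

from collections import Counter

def build_vocab(self, tokens, min_count=1):
    # Hypothetical sketch: keep tokens seen at least min_count times,
    # most frequent first.
    counts = Counter(tokens)
    return [token for token, count in counts.most_common()
            if count >= min_count]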
Example #9
    def indexing_runner(self):
        filename = CANDIDATE_FILENAME
        pool = Pool(self.parallel_size)
        start = time.time()
        count = 0
        failed_doc_list = []
        for item_chunk in get_chunk(read_jsonline_lazy(filename), 500):
            ret = pool.map(self.index_doc, item_chunk)
            failed_doc_list.extend([i for i in ret if i])
            duration = time.time() - start
            count += len(item_chunk)
            msg = '{} completed, {:.0f}min {:.2f}s'.format(
                count, duration // 60, duration % 60)
            self.logger.info(msg)

        # retry documents that failed in the worker pool, serially this time
        for doc in failed_doc_list:
            self.index_doc(doc)
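
index_doc is not shown; the filtering of the pool.map results above implies it returns a falsy value on success and the document itself on failure, so failures can be collected and retried. A hypothetical sketch of that contract (_send_to_index is a made-up helper standing in for the real indexing call):

def index_doc(self, doc):
    # Hypothetical contract inferred from indexing_runner(): return None
    # on success and the failed doc on error, so the caller can retry it.
    try:
        self._send_to_index(doc)  # made-up helper for the actual indexing call
        return None
    except Exception:
        return doc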
Example #10
    def get_searched_desc_id(self, filename):
        # ids already written to `filename`; empty when the file is missing
        desc_ids = {item['description_id']
                    for item in read_jsonline_lazy(filename, default=[])}
        return desc_ids
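
These ids appear to feed the 'searched_id_list' key consumed by the dict-typed data_source in Example #1, letting an interrupted rerank run skip queries that already have results. A hypothetical wiring (all filenames are made up):

searched = loader.get_searched_desc_id('output/rerank_result.jsonl')
data_source = {
    'golden_filename': 'data/golden.jsonl',
    'search_filename': 'data/search_result.jsonl',
    'topk': 50,
    'searched_id_list': list(searched),
}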