Example #1
    def get_batches_raw(self, item_chunk, batch):
        # `batch` carries leftover items from the previous chunk; whatever is
        # still unfinished at the end is returned so the caller can pass it
        # back in with the next chunk.
        batches = []

        for item in item_chunk:
            if self.fix_batch_size:
                # Emit only full batches of exactly batch_size items; keep the
                # remainder in `batch` so later items can fill it up.
                if batch and len(batch) > self.batch_size:
                    tail_count = len(batch) % self.batch_size
                    if tail_count:
                        batch_chunk = batch[:-tail_count]
                        batch = batch[-tail_count:]
                    else:
                        batch_chunk = batch
                        batch = []
                    for sliced_batch in get_chunk(batch_chunk,
                                                  self.batch_size):
                        batches.append(sliced_batch)

                flatten_items = self.flatten_raw_item(item)
                batch.extend(flatten_items)
            else:
                # Variable batch sizes are allowed here: slice an oversized
                # running batch into batch_size pieces, then either close the
                # current batch or extend it with this record's flattened items.
                if batch and len(batch) > self.batch_size:
                    for sliced_batch in get_chunk(batch, self.batch_size):
                        batches.append(sliced_batch)
                    batch = []
                flatten_items = self.flatten_raw_item(item)
                if batch and len(batch) + len(flatten_items) > self.batch_size:
                    batches.append(batch)
                    batch = flatten_items
                else:
                    batch.extend(flatten_items)
        # batches = self.reorder_batch_list(batches)
        return batches, batch
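The (batches, batch) return value is designed so a partially filled batch survives across successive item chunks: the caller threads the leftover back into the next call and only emits it once the stream ends. The stand-alone sketch below re-creates that carry-over pattern; the chunk_iter helper and the simplified batching logic are illustrations written for this note, not the repository's own get_chunk or flatten_raw_item.

    from typing import Iterable, Iterator, List

    def chunk_iter(items: Iterable, size: int) -> Iterator[List]:
        # stand-in for a get_chunk-style helper: yield lists of at most `size` items
        buf = []
        for item in items:
            buf.append(item)
            if len(buf) == size:
                yield buf
                buf = []
        if buf:
            yield buf

    def iter_batches(raw_items: Iterable, chunk_size: int,
                     batch_size: int) -> Iterator[List]:
        batch = []  # leftover carried between chunks, like the `batch` argument above
        for item_chunk in chunk_iter(raw_items, chunk_size):
            batches = []
            for item in item_chunk:
                # flush the running batch in batch_size slices once it overflows
                if batch and len(batch) > batch_size:
                    batches.extend(chunk_iter(batch, batch_size))
                    batch = []
                batch.append(item)  # a real loader would flatten the item first
            yield from batches
        if batch:
            yield batch  # emit whatever is left once the stream ends

    # e.g. iter_batches(range(10), chunk_size=4, batch_size=3)
    # -> [0, 1, 2], [3], [4, 5, 6], [7], [8, 9]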
Example #2
    def get_data(self):
        if isinstance(self.data_source, str):
            if self.args.lazy_loading:
                data = get_chunk(read_jsonline_lazy(self.data_source),
                                 self.batch_size)
            else:
                total_data = read_jsonline(self.data_source)
                random.shuffle(total_data)
                random.shuffle(total_data)
                random.shuffle(total_data)
                data = get_chunk(total_data, self.batch_size)
        elif isinstance(self.data_source, list):
            random.shuffle(self.data_source)
            random.shuffle(self.data_source)
            random.shuffle(self.data_source)
            data = get_chunk(self.data_source, self.batch_size)
        elif isinstance(self.data_source, dict):
            golden_filename = self.data_source['golden_filename']
            desc_id2item = {}
            for item in read_jsonline_lazy(golden_filename):
                desc_id2item[item['description_id']] = item
            search_filename = self.data_source['search_filename']
            topk = self.data_source['topk']
            searched_ids = set(self.data_source['searched_id_list'])

            def build_batch(search_item):
                qd_pairs = []
                desc_id = search_item['description_id']
                # queries already searched are skipped; an empty batch is
                # returned instead of their query-doc pairs
                if desc_id in searched_ids:
                    return [[]]

                query_text = desc_id2item[desc_id][self.args.query_field]
                if self.args.rerank_model_name == 'pairwise':
                    # pairwise reranking: one item for every ordered pair of
                    # distinct candidate docs of the same query
                    docs = search_item['docs'][:topk]
                    for i, doc_id in enumerate(docs):
                        for p_doc_id in docs[:i] + docs[i + 1:]:
                            raw_item = {
                                'description_id': desc_id,
                                'query': query_text,
                                'first_doc_id': doc_id,
                                'second_doc_id': p_doc_id
                            }
                            qd_pairs.append(raw_item)
                else:
                    # otherwise one item per (query, candidate doc)
                    for doc_id in search_item['docs'][:topk]:
                        raw_item = {
                            'description_id': desc_id,
                            'query': query_text,
                            'doc_id': doc_id
                        }
                        qd_pairs.append(raw_item)

                return get_chunk(qd_pairs, self.batch_size)

            # build batches lazily for each search item and flatten them into
            # a single stream of batches
            data = map(build_batch, read_jsonline_lazy(search_filename))
            data = chain.from_iterable(data)
        else:
            raise ValueError('data type error')
        return data
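The dict branch above never materializes the whole dataset: map keeps build_batch lazy per search item, and chain.from_iterable flattens the per-query lists of batches into one stream. A small self-contained sketch of that flattening idiom is below; the toy build_batch and the data are made up for illustration and are not the repository's code.

    from itertools import chain

    def build_batch(search_item):
        # hypothetical stand-in: turn one search result into a list of small batches
        pairs = [{'description_id': search_item['description_id'], 'doc_id': doc_id}
                 for doc_id in search_item['docs']]
        return [pairs[i:i + 2] for i in range(0, len(pairs), 2)]  # batch_size = 2

    search_items = [
        {'description_id': 'q1', 'docs': ['d1', 'd2', 'd3']},
        {'description_id': 'q2', 'docs': ['d4']},
    ]

    # map() defers the per-item work; chain.from_iterable() turns the nested
    # "list of batches per query" into a flat, still-lazy stream of batches
    for batch in chain.from_iterable(map(build_batch, search_items)):
        print(batch)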
Example #3
            def build_batch(search_item):
                qd_pairs = []
                desc_id = search_item['description_id']
                if desc_id in searched_ids:
                    return [[]]

                query_text = desc_id2item[desc_id][self.args.query_field]
                if self.args.rerank_model_name == 'pairwise':
                    docs = search_item['docs'][:topk]
                    for i, doc_id in enumerate(docs):
                        for p_doc_id in docs[:i] + docs[i + 1:]:
                            raw_item = {
                                'description_id': desc_id,
                                'query': query_text,
                                'first_doc_id': doc_id,
                                'second_doc_id': p_doc_id
                            }
                            qd_pairs.append(raw_item)
                else:
                    for doc_id in search_item['docs'][:topk]:
                        raw_item = {
                            'description_id': desc_id,
                            'query': query_text,
                            'doc_id': doc_id
                        }
                        qd_pairs.append(raw_item)

                return get_chunk(qd_pairs, self.batch_size)
Example #4
    def load_data(self, chunk_size):
        if isinstance(self.data_source, str):
            if self.lazy_loading:
                data = read_jsonline_lazy(self.data_source)
            else:
                data = read_jsonline(self.data_source)
                if self.loader.shuffle:
                    random.shuffle(data)
                    random.shuffle(data)
                    random.shuffle(data)
        elif isinstance(self.data_source, list):
            data = iter(self.data_source)
        else:
            raise TypeError('unsupported data_source type')
        return get_chunk(data, chunk_size)
Example #5
    def process(self):
        pool = Pool(self.parallel_count)
        tokens = []
        chunk_size = 100
        for item_chunk in get_chunk(read_jsonline_lazy(self.src_filename),
                                    chunk_size):
            processed_records = pool.map(self.tokenize_record, item_chunk)
            if self.dest_vocab_path:
                for record in processed_records:
                    tokens.extend(record['title_and_abstract_tokens'] +
                                  record['flatten_keyword_tokens'])
            for record in processed_records:
                record.pop('flatten_keyword_tokens')
            append_jsonlines(self.dest_filename, processed_records)
        if self.dest_vocab_path:
            vocab = self.build_vocab(tokens)
            write_lines(self.dest_vocab_path, vocab)
Example #6
    def get_batches_raw(self, item_chunk, batch):
        batches = []

        for item in item_chunk:
            # slice an oversized running batch into batch_size pieces
            if batch and len(batch) > self.batch_size:
                for sliced_batch in get_chunk(batch, self.batch_size):
                    batches.append(sliced_batch)
                batch = []
            flatten_items = self.flatten_raw_item(item)
            if batch and len(batch) + len(flatten_items) > self.batch_size:
                batches.append(batch)
                batch = flatten_items
            else:
                batch.extend(flatten_items)
        # batches = self.reorder_batch_list(batches)
        return batches, batch
Example #7
    def indexing_runner(self):
        filename = CANDIDATE_FILENAME
        pool = Pool(self.parallel_size)
        start = time.time()
        count = 0
        failed_doc_list = []
        for item_chunk in get_chunk(read_jsonline_lazy(filename), 500):
            ret = pool.map(self.index_doc, item_chunk)
            # collect failed docs (index_doc returns a truthy value on failure)
            # so they can be retried after the parallel pass
            failed_doc_list.extend([i for i in ret if i])
            duration = time.time() - start
            count += len(item_chunk)
            msg = '{} completed, {}min {:.2f}s'.format(count, duration // 60,
                                                       duration % 60)
            self.logger.info(msg)

        # retry the failed documents one by one
        for doc in failed_doc_list:
            self.index_doc(doc)
Example #8
    def build_data(self):
        pool = Pool(20)

        # map each description_id to its golden (labeled) record
        desc_id2item = {}

        for item in read_jsonline_lazy(self.golden_filename):
            desc_id = item['description_id']
            desc_id2item[desc_id] = item

        chunk_size = 50
        for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename),
                                    chunk_size):
            new_item_chunk = []
            for item in item_chunk:
                true_item = desc_id2item[item['description_id']]
                true_paper_id = true_item['paper_id']
                cites_text = true_item['cites_text']
                docs = item['docs']
                item.pop('docs')
                item.pop('keywords')
                # new_item_list: one record per sampled (true, false) pair;
                # new_item_dict: one aggregated record with all sampled negatives
                new_item_list = []
                new_item_dict = copy.deepcopy(item)
                new_item_dict['true_paper_id'] = true_paper_id
                new_item_dict['false_paper_id'] = []
                new_item_dict['cites_text'] = cites_text
                new_item_dict['description_text'] = true_item[
                    'description_text']
                # draw sample_count (true, false) training pairs for this query
                for idx in range(self.args.sample_count):
                    train_pair = self.select_train_pair(
                        docs, true_paper_id, self.args.select_strategy, idx)
                    new_item = {
                        **train_pair,
                        **item, 'cites_text': cites_text,
                        'description_text': true_item['description_text']
                    }
                    new_item_list.append(new_item)
                    new_item_dict['false_paper_id'].append(
                        train_pair['false_paper_id'])
                if self.args.aggregate_sample:
                    new_item_chunk.append(new_item_dict)
                else:
                    new_item_chunk.extend(new_item_list)
            built_items = pool.map(self.build_single_query, new_item_chunk)
            # keep only non-empty results from the worker pool
            built_items = [i for i in built_items if i]
            append_jsonlines(self.dest_filename, built_items)
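build_data leans on select_train_pair to draw one negative (false) paper id per sampling round, with a strategy chosen by select_strategy. The repository's strategies are not shown in these examples, so the stand-in below is purely hypothetical: it assumes the simplest case of uniform random sampling over the non-golden candidates and ignores select_strategy and idx.

    import random

    def select_train_pair(docs, true_paper_id, select_strategy='random', idx=0):
        # hypothetical stand-in: pick one false (negative) paper id for a query;
        # select_strategy and idx are accepted but ignored in this sketch
        negatives = [doc_id for doc_id in docs if doc_id != true_paper_id]
        false_paper_id = random.choice(negatives) if negatives else None
        return {'true_paper_id': true_paper_id, 'false_paper_id': false_paper_id}

    # e.g. select_train_pair(['p1', 'p2', 'p3'], 'p2')
    # -> {'true_paper_id': 'p2', 'false_paper_id': 'p1' or 'p3'}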
Example #9
    def get_batches_processed(self, item_chunk, batch):
        # leftover items from the previous call are prepended and everything is
        # re-chunked; an empty leftover is always returned
        batches = []
        for new_batch in get_chunk(batch + item_chunk, self.batch_size):
            batches.append(new_batch)
        return batches, []