def get_batches_raw(self, item_chunk, batch):
    """Split raw items into batches, carrying the unfinished tail batch over to the next call."""
    batches = []
    for item in item_chunk:
        if self.fix_batch_size:
            # Flush every full batch and keep only the remainder in `batch`,
            # so emitted batches all have exactly `batch_size` entries.
            if batch and len(batch) > self.batch_size:
                tail_count = len(batch) % self.batch_size
                if tail_count:
                    batch_chunk = batch[:-tail_count]
                    batch = batch[-tail_count:]
                else:
                    batch_chunk = batch
                    batch = []
                for sliced_batch in get_chunk(batch_chunk, self.batch_size):
                    batches.append(sliced_batch)
            flatten_items = self.flatten_raw_item(item)
            batch.extend(flatten_items)
        else:
            if batch and len(batch) > self.batch_size:
                for sliced_batch in get_chunk(batch, self.batch_size):
                    batches.append(sliced_batch)
                batch = []
            flatten_items = self.flatten_raw_item(item)
            # Cut batches at record boundaries when possible.
            if batch and len(batch) + len(flatten_items) > self.batch_size:
                batches.append(batch)
                batch = flatten_items
            else:
                batch.extend(flatten_items)
    # batches = self.reorder_batch_list(batches)
    return batches, batch

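# `get_chunk` is used throughout this code but is not defined in this section. The
# sketch below shows the behaviour the calls above assume: yield consecutive
# fixed-size slices of an iterable, with a possibly shorter final slice. This is an
# illustrative assumption, not the project's actual implementation.
from itertools import islice


def get_chunk(iterable, chunk_size):
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            return
        yield chunk
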
def get_data(self):
    if isinstance(self.data_source, str):
        if self.args.lazy_loading:
            data = get_chunk(read_jsonline_lazy(self.data_source), self.batch_size)
        else:
            total_data = read_jsonline(self.data_source)
            # shuffle once before slicing into batches
            random.shuffle(total_data)
            data = get_chunk(total_data, self.batch_size)
    elif isinstance(self.data_source, list):
        random.shuffle(self.data_source)
        data = get_chunk(self.data_source, self.batch_size)
    elif isinstance(self.data_source, dict):
        golden_filename = self.data_source['golden_filename']
        desc_id2item = {}
        for item in read_jsonline_lazy(golden_filename):
            desc_id2item[item['description_id']] = item
        search_filename = self.data_source['search_filename']
        topk = self.data_source['topk']
        searched_ids = set(self.data_source['searched_id_list'])

        def build_batch(search_item):
            qd_pairs = []
            desc_id = search_item['description_id']
            if desc_id in searched_ids:
                return [[]]
            query_text = desc_id2item[desc_id][self.args.query_field]
            if self.args.rerank_model_name == 'pairwise':
                # pairwise reranking: one item for every ordered pair of candidate docs
                docs = search_item['docs'][:topk]
                for i, doc_id in enumerate(docs):
                    for p_doc_id in docs[:i] + docs[i + 1:]:
                        raw_item = {
                            'description_id': desc_id,
                            'query': query_text,
                            'first_doc_id': doc_id,
                            'second_doc_id': p_doc_id
                        }
                        qd_pairs.append(raw_item)
            else:
                for doc_id in search_item['docs'][:topk]:
                    raw_item = {
                        'description_id': desc_id,
                        'query': query_text,
                        'doc_id': doc_id
                    }
                    qd_pairs.append(raw_item)
            return get_chunk(qd_pairs, self.batch_size)

        data = map(build_batch, read_jsonline_lazy(search_filename))
        data = chain.from_iterable(data)
    else:
        raise ValueError('data type error')
    return data

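# `read_jsonline`, `read_jsonline_lazy`, and `append_jsonlines` are project I/O
# helpers that are not shown in this section. The sketches below illustrate the
# behaviour the callers assume (one JSON object per line); they are assumptions,
# not the repository's real helpers.
import json


def read_jsonline(filename):
    with open(filename, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def read_jsonline_lazy(filename):
    with open(filename, encoding='utf-8') as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def append_jsonlines(filename, items):
    with open(filename, 'a', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
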
def load_data(self, chunk_size):
    if isinstance(self.data_source, str):
        if self.lazy_loading:
            data = read_jsonline_lazy(self.data_source)
        else:
            data = read_jsonline(self.data_source)
            if self.loader.shuffle:
                random.shuffle(data)
    elif isinstance(self.data_source, list):
        data = iter(self.data_source)
    else:
        raise TypeError('data_source must be a filename or a list')
    return get_chunk(data, chunk_size)

def process(self):
    pool = Pool(self.parallel_count)
    tokens = []
    chunk_size = 100
    for item_chunk in get_chunk(read_jsonline_lazy(self.src_filename), chunk_size):
        processed_records = pool.map(self.tokenize_record, item_chunk)
        if self.dest_vocab_path:
            for record in processed_records:
                tokens.extend(record['title_and_abstract_tokens'] + record['flatten_keyword_tokens'])
        for record in processed_records:
            record.pop('flatten_keyword_tokens')
        append_jsonlines(self.dest_filename, processed_records)
    if self.dest_vocab_path:
        vocab = self.build_vocab(tokens)
        write_lines(self.dest_vocab_path, vocab)

def get_batches_raw(self, item_chunk, batch):
    batches = []
    for item in item_chunk:
        if batch and len(batch) > self.batch_size:
            for sliced_batch in get_chunk(batch, self.batch_size):
                batches.append(sliced_batch)
            batch = []
        flatten_items = self.flatten_raw_item(item)
        if batch and len(batch) + len(flatten_items) > self.batch_size:
            batches.append(batch)
            batch = flatten_items
        else:
            batch.extend(flatten_items)
    # batches = self.reorder_batch_list(batches)
    return batches, batch

def indexing_runner(self):
    filename = CANDIDATE_FILENAME
    pool = Pool(self.parallel_size)
    start = time.time()
    count = 0
    failed_doc_list = []
    for item_chunk in get_chunk(read_jsonline_lazy(filename), 500):
        ret = pool.map(self.index_doc, item_chunk)
        failed_doc_list.extend([i for i in ret if i])
        duration = time.time() - start
        count += len(item_chunk)
        msg = '{} completed, {}min {:.2f}s'.format(count, duration // 60, duration % 60)
        self.logger.info(msg)
    for doc in failed_doc_list:
        self.index_doc(doc)

def build_data(self):
    pool = Pool(20)
    desc_id2item = {}
    for item in read_jsonline_lazy(self.golden_filename):
        desc_id = item['description_id']
        desc_id2item[desc_id] = item
    chunk_size = 50
    for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename), chunk_size):
        new_item_chunk = []
        for item in item_chunk:
            true_item = desc_id2item[item['description_id']]
            true_paper_id = true_item['paper_id']
            cites_text = true_item['cites_text']
            docs = item['docs']
            item.pop('docs')
            item.pop('keywords')
            new_item_list = []
            new_item_dict = copy.deepcopy(item)
            new_item_dict['true_paper_id'] = true_paper_id
            new_item_dict['false_paper_id'] = []
            new_item_dict['cites_text'] = cites_text
            new_item_dict['description_text'] = true_item['description_text']
            # pair the gold paper with negatives sampled from the search results
            for idx in range(self.args.sample_count):
                train_pair = self.select_train_pair(docs, true_paper_id, self.args.select_strategy, idx)
                new_item = {
                    **train_pair,
                    **item,
                    'cites_text': cites_text,
                    'description_text': true_item['description_text']
                }
                new_item_list.append(new_item)
                new_item_dict['false_paper_id'].append(train_pair['false_paper_id'])
            if self.args.aggregate_sample:
                # one record per query with all sampled negatives attached
                new_item_chunk.append(new_item_dict)
            else:
                # one record per (query, positive, negative) sample
                new_item_chunk.extend(new_item_list)
        built_items = pool.map(self.build_single_query, new_item_chunk)
        built_items = [i for i in built_items if i]
        append_jsonlines(self.dest_filename, built_items)

def get_batches_processed(self, item_chunk, batch):
    batches = []
    for new_batch in get_chunk(batch + item_chunk, self.batch_size):
        batches.append(new_batch)
    return batches, []

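# Illustrative driver for the batching methods above: the leftover `batch` returned
# by each call is carried into the next chunk and flushed once the input is
# exhausted. `iterate_batches` and the `loader` object are hypothetical names used
# only to show the assumed calling contract; this is a sketch, not repository code.
def iterate_batches(loader, chunk_size):
    batch = []
    for item_chunk in loader.load_data(chunk_size):
        batches, batch = loader.get_batches_raw(item_chunk, batch)
        for full_batch in batches:
            yield full_batch
    # flush whatever is left over after the last chunk
    for tail_batch in get_chunk(batch, loader.batch_size):
        yield tail_batch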