def __init__(self, args, model):
    """Set up the keyphrase trainer: model placement, data loader, optimizer,
    LR schedule, experiment directory, file logging and evaluators.

    Args:
        args: parsed command-line namespace (vocab_path, learning_rate,
            train_filename, dest_base_dir, exp_name, ... as read below).
        model: the keyphrase generation model to be trained.
    """
    # Fixed seed for reproducible runs.
    torch.manual_seed(0)
    # NOTE(review): anomaly detection slows training considerably; it looks
    # like a leftover debugging aid — confirm before long training runs.
    torch.autograd.set_detect_anomaly(True)
    self.args = args
    self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)
    self.model = model
    if torch.cuda.is_available():
        self.model = self.model.cuda()
        if args.train_parallel:
            self.model = nn.DataParallel(self.model)
    # Padding positions do not contribute to the loss.
    self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
    self.optimizer = optim.Adam(self.model.parameters(),
                                lr=self.args.learning_rate)
    self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                               self.args.schedule_step,
                                               self.args.schedule_gamma)
    self.logger = get_logger('train')
    self.train_loader = KeyphraseDataLoader(
        data_source=self.args.train_filename,
        vocab2id=self.vocab2id,
        mode='train',
        args=args)
    # Resuming reuses the checkpoint's directory; a fresh run gets a new
    # timestamped experiment directory (created here, so it must not exist).
    if self.args.train_from:
        self.dest_dir = os.path.dirname(self.args.train_from) + '/'
    else:
        timemark = time.strftime('%Y%m%d-%H%M%S',
                                 time.localtime(time.time()))
        self.dest_dir = os.path.join(
            self.args.dest_base_dir,
            self.args.exp_name + '-' + timemark) + '/'
        os.mkdir(self.dest_dir)
    # Mirror log records into a file inside the experiment directory.
    fh = logging.FileHandler(os.path.join(self.dest_dir, args.logfile))
    fh.setLevel(logging.INFO)
    fh.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
    self.logger.addHandler(fh)
    # TensorBoard output defaults to <dest_dir>/logs/ unless overridden.
    if not self.args.tensorboard_dir:
        tensorboard_dir = self.dest_dir + 'logs/'
    else:
        tensorboard_dir = self.args.tensorboard_dir
    self.writer = SummaryWriter(tensorboard_dir)
    # Evaluation reports F1@5 and F1@10, both macro- and micro-averaged.
    self.eval_topn = (5, 10)
    self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro',
                                              args.token_field,
                                              args.keyphrase_field)
    self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro',
                                              args.token_field,
                                              args.keyphrase_field)
    # Early-stopping bookkeeping.
    self.best_f1 = None
    self.best_step = 0
    self.not_update_count = 0
def __init__(self):
    """Set up the PLM vectorization trainer: model, tokenizer, data loader,
    destination directory and TensorBoard writer."""
    self.args = parse_args()
    self.model = PlmModel(self.args)
    if torch.cuda.is_available():
        # nn.Module.cuda() moves parameters in place; no reassignment needed.
        self.model.cuda()
    self.tokenizer = self.model.tokenizer
    self.loader = VectorizationDataLoader(self.args.train_filename,
                                          self.model.tokenizer,
                                          args=self.args)
    self.logger = get_logger('vectorization trainer')
    # makedirs(exist_ok=True) replaces the original exists()/mkdir() pair:
    # it is race-free and also creates missing parent directories.
    os.makedirs(self.args.dest_base_dir, exist_ok=True)
    # os.path.join is separator-safe, unlike '+'-concatenated paths.
    self.logdir = os.path.join(self.args.dest_base_dir, 'logs')
    os.makedirs(self.logdir, exist_ok=True)
    self.writer = SummaryWriter(self.logdir)
def __init__(self):
    """Set up the trainer: vocabulary, TensorBoard writer, evaluators and
    early-stopping bookkeeping."""
    args = self.parse_args()
    self.args = args
    vocab2id = load_vocab(args.vocab_path)
    self.vocab2id = vocab2id
    self.dest_base_dir = args.dest_base_dir
    self.writer = tf.summary.create_file_writer(self.dest_base_dir + '/logs')
    self.exp_name = args.exp_name
    self.pad_idx = vocab2id[PAD_WORD]
    # Evaluation reports scores at top-5 and top-10 predictions.
    self.eval_topn = (5, 10)
    # Macro- and micro-averaged evaluators share every other setting.
    self.macro_evaluator, self.micro_evaluator = (
        KeyphraseEvaluator(self.eval_topn, avg_mode,
                           args.token_field, args.keyphrase_field)
        for avg_mode in ('macro', 'micro'))
    # Early-stopping bookkeeping.
    self.best_f1 = None
    self.best_step = 0
    self.not_update_count = 0
    self.logger = get_logger(__name__)
    # Copy-mechanism output space: in-vocab ids plus per-batch OOV slots.
    self.total_vocab_size = len(vocab2id) + args.max_oov_count
def __init__(self):
    """Set up the rerank trainer: pretrained tokenizer, destination
    directories, data loader and TensorBoard writer."""
    self.args = parse_args()
    self.plm_model_name = self.args.plm_model_name
    self.rerank_model_name = self.args.rerank_model_name
    self.model_info = MODEL_DICT[self.plm_model_name]
    # A local 'path' entry points at a downloaded checkpoint directory;
    # otherwise fall back to loading by hub model name.
    if 'path' in self.model_info:
        # os.path.join tolerates a present-or-absent trailing slash in the
        # configured path, unlike the original string concatenation.
        tokenizer_path = os.path.join(self.model_info['path'], 'vocab.txt')
    else:
        tokenizer_path = self.plm_model_name
    self.tokenizer = self.model_info['tokenizer_class'].from_pretrained(
        tokenizer_path)
    # makedirs(exist_ok=True) replaces the exists()/mkdir() pair: race-free
    # and creates missing parents too.
    os.makedirs(self.args.dest_base_dir, exist_ok=True)
    self.train_loader = RerankDataLoader(self.args.train_filename,
                                         self.tokenizer, self.args, 'train')
    self.logger = get_logger('rerank_trainer')
    # The original built this as dest_base_dir + 'logs/', which silently
    # required a trailing slash on dest_base_dir (the sibling trainers use
    # '/logs'); join makes it separator-safe.
    logdir = os.path.join(self.args.dest_base_dir, 'logs')
    os.makedirs(logdir, exist_ok=True)
    self.writer = SummaryWriter(logdir)
class RerankDataBuilder(object):
    """Builds (query, true_doc, false_doc) training triples for the reranker
    from a search-result file and a golden-answer file."""

    logger = get_logger('rank data builder')

    def __init__(self):
        self.args = self.parse_args()
        self.search_filename = self.args.search_filename
        self.golden_filename = self.args.golden_filename
        self.dest_filename = self.args.dest_filename
        # Output is appended chunk by chunk, so remove any stale file first.
        if os.path.exists(self.dest_filename):
            os.remove(self.dest_filename)
        # Pool of candidate paper ids used for random negative sampling.
        self.candidate_paper_id_list = read_lines(
            DATA_DIR + 'candidate_paper_id.txt')

    def parse_args(self):
        """Parse command-line options controlling negative-sample selection."""
        parser = argparse.ArgumentParser()
        parser.add_argument('-search_filename', type=str, required=True)
        parser.add_argument('-golden_filename', type=str, required=True)
        parser.add_argument('-dest_filename', type=str, required=True)
        parser.add_argument('-select_strategy', type=str,
                            choices=['random', 'search_result_offset',
                                     'search_result_false_top'],
                            required=True)
        parser.add_argument('-query_field', type=str, default='cites_text',
                            choices=['cites_text', 'description_text'])
        parser.add_argument('-sample_count', type=int, default=1)
        parser.add_argument('-aggregate_sample', action='store_true')
        parser.add_argument('-offset', type=int, default=50)
        args = parser.parse_args()
        return args

    def run(self):
        self.build_data()

    def build_data(self):
        """Stream search results in chunks, pair each with its golden item,
        sample negatives, and append the built triples to the output file."""
        # NOTE(review): the pool is never closed/joined; consider
        # 'with Pool(20) as pool:' — confirm the Pool implementation in use.
        pool = Pool(20)
        # Index golden items by description id for O(1) lookup below.
        desc_id2item = {}
        for item in read_jsonline_lazy(self.golden_filename):
            desc_id = item['description_id']
            desc_id2item[desc_id] = item
        chunk_size = 50
        for item_chunk in get_chunk(read_jsonline_lazy(self.search_filename),
                                    chunk_size):
            new_item_chunk = []
            for item in item_chunk:
                true_item = desc_id2item[item['description_id']]
                true_paper_id = true_item['paper_id']
                cites_text = true_item['cites_text']
                docs = item['docs']
                # Drop bulky fields before the item is copied into samples.
                item.pop('docs')
                item.pop('keywords')
                new_item_list = []
                # Aggregated variant: one record carrying all negatives.
                new_item_dict = copy.deepcopy(item)
                new_item_dict['true_paper_id'] = true_paper_id
                new_item_dict['false_paper_id'] = []
                new_item_dict['cites_text'] = cites_text
                new_item_dict['description_text'] = true_item[
                    'description_text']
                # Draw sample_count negatives per query.
                for idx in range(self.args.sample_count):
                    train_pair = self.select_train_pair(
                        docs, true_paper_id, self.args.select_strategy, idx)
                    new_item = {
                        **train_pair,
                        **item,
                        'cites_text': cites_text,
                        'description_text': true_item['description_text']
                    }
                    new_item_list.append(new_item)
                    new_item_dict['false_paper_id'].append(
                        train_pair['false_paper_id'])
                if self.args.aggregate_sample:
                    new_item_chunk.append(new_item_dict)
                else:
                    new_item_chunk.extend(new_item_list)
            built_items = pool.map(self.build_single_query, new_item_chunk)
            # build_single_query may return falsy entries; keep only real ones.
            built_items = [i for i in built_items if i]
            append_jsonlines(self.dest_filename, built_items)

    def select_train_pair(self, doc_list, true_doc_id, select_strategy,
                          intra_offset):
        """Pick one false paper id for *true_doc_id* according to
        *select_strategy*; returns a {'true_paper_id', 'false_paper_id'} dict.

        Raises:
            ValueError: on an unknown *select_strategy*.
        """
        offset = self.args.offset + intra_offset
        if select_strategy == 'search_result_offset':
            # Negative taken `offset` positions below the true hit.
            true_idx = index(doc_list, true_doc_id, -1)
            if true_idx == -1 or true_idx + offset >= len(doc_list):
                if len(doc_list) <= offset:
                    # When doc_list has fewer entries than the offset, fall
                    # back to a randomly selected false id, because results
                    # produced by the following cases would hurt training:
                    # 1. an unusual description text returns the 3 predefined
                    #    paper ids (see search.py)
                    # 2. with a too-small topk in the benchmark, the last
                    #    instance of the result list can have context similar
                    #    to the true paper and would confuse the model
                    false_paper_id = self.random_choose_false_id(true_doc_id)
                else:
                    # Take negatives from the tail of the result list.
                    false_idx = -self.args.sample_count + intra_offset
                    false_paper_id = doc_list[false_idx]
            else:
                false_paper_id = doc_list[true_idx + offset]
        elif select_strategy == 'random':
            false_paper_id = self.random_choose_false_id(true_doc_id)
        elif select_strategy == 'search_result_false_top':
            # Hardest negative: the top-ranked result that is not the truth.
            true_idx = index(doc_list, true_doc_id, -1)
            if true_idx == 0:
                false_paper_id = doc_list[1]
            else:
                false_paper_id = doc_list[0]
        else:
            raise ValueError('false instance select strategy error')
        return {'true_paper_id': true_doc_id,
                'false_paper_id': false_paper_id}

    def random_choose_false_id(self, true_doc_id):
        """Draw a random candidate id that differs from *true_doc_id*."""
        false_paper_id = random.choice(self.candidate_paper_id_list)
        if false_paper_id == true_doc_id:
            # Re-draw until the sample differs from the true id.
            while True:
                false_paper_id = random.choice(self.candidate_paper_id_list)
                if false_paper_id != true_doc_id:
                    break
        return false_paper_id

    def build_single_query(self, item):
        """Resolve paper ids into text, producing the final training record
        with 'query', 'true_doc' and 'false_doc' fields.

        Raises:
            ValueError: when 'false_paper_id' is neither str nor list.
        """
        query = item[self.args.query_field]
        true_paper = get_paper(item['true_paper_id'])
        true_text = true_paper['title'] + ' ' + true_paper['abstract']
        # 'false_paper_id' is a single id, or a list in aggregate mode.
        if isinstance(item['false_paper_id'], str):
            false_paper = get_paper(item['false_paper_id'])
            false_text = false_paper['title'] + ' ' + false_paper['abstract']
        elif isinstance(item['false_paper_id'], list):
            false_text = []
            for pid in item['false_paper_id']:
                paper = get_paper(pid)
                false_text.append(paper['title'] + ' ' + paper['abstract'])
        else:
            raise ValueError('false paper id type error')
        train_item = {
            'query': query,
            'true_doc': true_text,
            'false_doc': false_text,
            **item
        }
        # cites_text was only needed to build the query; drop it from output.
        train_item.pop('cites_text')
        return train_item
class ElasticSearchIndexer(object):
    """Recreates the Elasticsearch index and bulk-indexes candidate papers
    via the REST API, retrying transient request failures."""

    logger = get_logger('elastic indexer')
    parallel_size = 2
    # Number of additional attempts after the first failed PUT.
    retry_count = 10

    def __init__(self):
        self.base_url = ES_API_URL

    def run(self):
        """Drop any existing index, recreate settings/mappings, then index."""
        self.delete_index()
        self.create_fields()
        self.indexing_runner()

    def create_fields(self):
        """Create the index with its settings and field mappings.

        Raises:
            Exception: when either REST call does not return HTTP 200.
        """
        mapping_url = self.base_url + '/_mapping'
        headers = {"Content-Type": "application/json"}
        base_data = json.dumps(read_json(DATA_DIR + 'setting.json'))
        field_data = json.dumps(read_json(DATA_DIR + 'fields.json'))
        ret = requests.put(self.base_url, data=base_data, headers=headers)
        if ret.status_code != 200:
            raise Exception('setting es error, {}'.format(ret.text))
        ret = requests.put(mapping_url, data=field_data, headers=headers)
        if ret.status_code != 200:
            raise Exception('create index error, {}'.format(ret.text))
        self.logger.info('create index success')

    def delete_index(self):
        # Best-effort delete; a 404 on a missing index is fine.
        requests.delete(self.base_url)

    def indexing_runner(self):
        """Index all candidate papers in parallel chunks, then make one more
        sequential pass over documents that failed every retry."""
        filename = CANDIDATE_FILENAME
        pool = Pool(self.parallel_size)
        start = time.time()
        count = 0
        failed_doc_list = []
        for item_chunk in get_chunk(read_jsonline_lazy(filename), 500):
            # index_doc returns the doc on failure and None on success.
            ret = pool.map(self.index_doc, item_chunk)
            failed_doc_list.extend([i for i in ret if i])
            duration = time.time() - start
            count += len(item_chunk)
            msg = '{} completed, {}min {:.2f}s'.format(
                count, duration // 60, duration % 60)
            self.logger.info(msg)
        for doc in failed_doc_list:
            self.index_doc(doc)

    def index_doc(self, doc):
        """Prepare and PUT a single paper document.

        Mutates *doc* in place (splits keywords, adds TA/TAK search fields).
        Returns the doc when all attempts fail so the caller can retry it
        later; returns None otherwise.
        """
        base_url = self.base_url + '/_doc/{}'
        headers = {"Content-Type": "application/json"}
        url = base_url.format(doc['paper_id'])
        # Keywords arrive ';'-separated; normalize to a list.
        if doc['keywords'] == '':
            doc['keywords'] = []
        else:
            doc['keywords'] = doc['keywords'].split(';')
        # Derived full-text fields: title+abstract, optionally +keywords.
        doc['TA'] = doc.get('title', '') + ' ' + doc.get('abstract', '')
        keyword_str = ' '.join(doc['keywords']).lower()
        if keyword_str:
            doc['TAK'] = doc['TA'] + ' ' + keyword_str
        else:
            doc['TAK'] = doc['TA']
        # Skip documents without a usable id or serializable body.
        if not doc['paper_id'].strip():
            return
        input_str = json.dumps(doc)
        if not input_str:
            return
        try:
            ret = requests.put(url, input_str, headers=headers)
        # Narrowed from a bare 'except:', which also swallowed
        # KeyboardInterrupt/SystemExit.
        except requests.RequestException:
            is_success = False
            for _ in range(self.retry_count):
                try:
                    ret = requests.put(url, input_str, headers=headers)
                except requests.RequestException:
                    continue
                is_success = True
                break
            if not is_success:
                err_msg = '{} index failed, {}'.format(
                    doc['paper_id'], traceback.format_exc())
                self.logger.error(err_msg)
                return doc
        # Elasticsearch answers 200 (updated) or 201 (created) on success.
        if ret.status_code not in (200, 201):
            msg = '{} error {} {}'.format(json.dumps(doc), ret.status_code,
                                          ret.text)
            self.logger.error(msg)