def main():
    parser = argparse.ArgumentParser()
    BERT_DIR = "./model/uncased_L-12_H-768_A-12/"

    ## Required parameters
    parser.add_argument("--bert_config_file", default=BERT_DIR + "bert_config.json", type=str,
                        help="The config json file corresponding to the pre-trained BERT model. "
                        "This specifies the model architecture.")
    parser.add_argument("--vocab_file", default=BERT_DIR + "vocab.txt", type=str,
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--output_dir", default="out", type=str,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--train_file", type=str, default="",
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", type=str, default="",
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--init_checkpoint", type=str, default=BERT_DIR + "pytorch_model.bin",
                        help="Initial checkpoint (usually from a pre-trained BERT model).")
    parser.add_argument("--do_lower_case", default=True, action='store_true',
                        help="Whether to lower case the input text. Should be True for uncased "
                        "models and False for cased models.")
    parser.add_argument("--max_seq_length", default=300, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                        "Sequences longer than this will be truncated, and sequences shorter than "
                        "this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to "
                        "take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than "
                        "this will be truncated to this length.")
    parser.add_argument("--do_train", default=False, action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict", default=False, action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=128, type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=10.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--save_checkpoints_steps", default=1000, type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--iterations_per_loop", default=1000, type=int,
                        help="How many steps to make in each estimator call.")
    parser.add_argument("--n_best_size", default=3, type=int,
                        help="The total number of n-best predictions to generate in the "
                        "nbest_predictions.json output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is "
                        "needed because the start and end predictions are not conditioned "
                        "on one another.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be "
                        "printed. A number of warnings are expected for a normal SQuAD "
                        "evaluation.")
" "A number of warnings are expected for a normal SQuAD evaluation.") parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument( "--accumulate_gradients", type=int, default=1, help= "Number of steps to accumulate gradient on (divide the batch_size and accumulate)" ) parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--gradient_accumulation_steps', type=int, default=1, help= "Number of updates steps to accumualte before performing a backward/update pass." ) parser.add_argument('--eval_period', type=int, default=2000) parser.add_argument('--max_n_answers', type=int, default=5) parser.add_argument('--merge_query', type=int, default=-1) parser.add_argument('--reduce_layers', type=int, default=-1) parser.add_argument('--reduce_layers_to_tune', type=int, default=-1) parser.add_argument('--only_comp', action="store_true", default=False) parser.add_argument('--train_subqueries_file', type=str, default="") #500 parser.add_argument('--predict_subqueries_file', type=str, default="") #500 parser.add_argument('--prefix', type=str, default="") #500 parser.add_argument('--model', type=str, default="qa") #500 parser.add_argument('--pooling', type=str, default="max") parser.add_argument('--debug', action="store_true", default=False) parser.add_argument('--output_dropout_prob', type=float, default=0) parser.add_argument('--wait_step', type=int, default=30) parser.add_argument('--with_key', action="store_true", default=False) parser.add_argument('--add_noise', action="store_true", default=False) args = parser.parse_args() if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) if args.accumulate_gradients < 1: raise ValueError( "Invalid accumulate_gradients parameter: {}, should be >= 1". format(args.accumulate_gradients)) args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) if not args.do_train and not args.do_predict: raise ValueError( "At least one of `do_train` or `do_predict` must be True.") if args.do_train: if not args.train_file: raise ValueError( "If `do_train` is True, then `train_file` must be specified.") if not args.predict_file: raise ValueError( "If `do_train` is True, then `predict_file` must be specified." ) if args.do_predict: if not args.predict_file: raise ValueError( "If `do_predict` is True, then `predict_file` must be specified." 
    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError("Cannot use sequence length %d because the BERT model "
                         "was only trained up to sequence length %d" %
                         (args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        logger.info("Output directory (%s) already exists and is not empty.", args.output_dir)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None

    eval_dataloader, eval_examples, eval_features, _ = get_dataloader(
        logger=logger, args=args,
        input_file=args.predict_file,
        subqueries_file=args.predict_subqueries_file,
        is_training=False,
        batch_size=args.predict_batch_size,
        num_epochs=1,
        tokenizer=tokenizer)
    if args.do_train:
        train_dataloader, train_examples, _, num_train_steps = get_dataloader(
            logger=logger, args=args,
            input_file=args.train_file,
            subqueries_file=args.train_subqueries_file,
            is_training=True,
            batch_size=args.train_batch_size,
            num_epochs=args.num_train_epochs,
            tokenizer=tokenizer)

    if args.model == 'qa':
        model = BertForQuestionAnswering(bert_config, 4)
        metric_name = "F1"
    elif args.model == 'classifier':
        if args.reduce_layers != -1:
            bert_config.num_hidden_layers = args.reduce_layers
        model = BertClassifier(bert_config, 2, args.pooling)
        metric_name = "F1"
    elif args.model == "span-predictor":
        if args.reduce_layers != -1:
            bert_config.num_hidden_layers = args.reduce_layers
        if args.with_key:
            Model = BertForQuestionAnsweringWithKeyword
        else:
            Model = BertForQuestionAnswering
        model = Model(bert_config, 2)
        metric_name = "Accuracy"
    else:
        raise NotImplementedError()

    if args.init_checkpoint is not None and args.do_predict and \
            len(args.init_checkpoint.split(',')) > 1:
        # Load an ensemble of QA models for prediction.
        assert args.model == "qa"
        model = [model]
        for i, checkpoint in enumerate(args.init_checkpoint.split(',')):
            if i > 0:
                model.append(BertForQuestionAnswering(bert_config, 4))
            print("Loading from", checkpoint)
            state_dict = torch.load(checkpoint, map_location='cpu')
            # Strip the 'module.' prefix left by (Distributed)DataParallel checkpoints.
            filter = lambda x: x[7:] if x.startswith('module.') else x
            state_dict = {filter(k): v for (k, v) in state_dict.items()}
            model[-1].load_state_dict(state_dict)
            model[-1].to(device)
    else:
        if args.init_checkpoint is not None:
            print("Loading from", args.init_checkpoint)
            state_dict = torch.load(args.init_checkpoint, map_location='cpu')
            if args.reduce_layers != -1:
                # Drop the weights of the layers removed from the config.
                state_dict = {k: v for k, v in state_dict.items()
                              if '.'.join(k.split('.')[:3]) not in
                              ['encoder.layer.{}'.format(i)
                               for i in range(args.reduce_layers, 12)]}
            if args.do_predict:
                filter = lambda x: x[7:] if x.startswith('module.') else x
                state_dict = {filter(k): v for (k, v) in state_dict.items()}
                model.load_state_dict(state_dict)
            else:
                model.bert.load_state_dict(state_dict)
            if args.reduce_layers_to_tune != -1:
                # Freeze the embeddings and everything below the top
                # `reduce_layers_to_tune` encoder layers.
                for p in model.bert.embeddings.parameters():
                    p.requires_grad = False
                n_layers = 12 if args.reduce_layers == -1 else args.reduce_layers
                for i in range(n_layers - args.reduce_layers_to_tune):
                    for p in model.bert.encoder.layer[i].parameters():
                        p.requires_grad = False
        model.to(device)

        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank],
                output_device=args.local_rank)
        elif n_gpu > 1:
            model = torch.nn.DataParallel(model)
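    # Aside: a minimal, self-contained sketch (illustrative only, never called
    # here) of why the checkpoint loading above strips a 'module.' prefix:
    # nn.DataParallel registers the wrapped model under a `module` attribute,
    # so every key in its state_dict gains that prefix.
    def _demo_dataparallel_prefix():
        net = torch.nn.Linear(4, 2)
        wrapped = torch.nn.DataParallel(net)
        # Loading such a checkpoint into a bare `net` requires stripping 'module.'.
        assert set(wrapped.state_dict()) == {'module.weight', 'module.bias'}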
    if args.do_train:
        # Apply weight decay to all parameters except biases and LayerNorm weights.
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [{
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.01
        }, {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.0
        }]
        optimizer = BERTAdam(optimizer_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)

        global_step = 0
        best_f1 = 0
        wait_step = 0
        model.train()
        stop_training = False

        for epoch in range(int(args.num_train_epochs)):
            for step, batch in tqdm(enumerate(train_dataloader)):
                global_step += 1
                batch = [t.to(device) for t in batch]
                loss = model(batch, global_step)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if global_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()  # We have accumulated enough gradients
                    model.zero_grad()
                if global_step % args.eval_period == 0:
                    model.eval()
                    f1 = predict(args, model, eval_dataloader, eval_examples,
                                 eval_features, device, write_prediction=False)
                    logger.info("%s: %.3f on epoch=%d" % (metric_name, f1 * 100.0, epoch))
                    if best_f1 < f1:
                        logger.info("Saving model with best %s: %.3f -> %.3f on epoch=%d" %
                                    (metric_name, best_f1 * 100.0, f1 * 100.0, epoch))
                        model_state_dict = {k: v.cpu()
                                            for (k, v) in model.state_dict().items()}
                        torch.save(model_state_dict,
                                   os.path.join(args.output_dir, "best-model.pt"))
                        model = model.cuda()
                        best_f1 = f1
                        wait_step = 0
                        stop_training = False
                    else:
                        wait_step += 1
                        if best_f1 > 0.1 and wait_step == args.wait_step:
                            stop_training = True
                    model.train()
                if stop_training:
                    break
            if stop_training:
                break
    elif args.do_predict:
        if type(model) == list:
            model = [m.eval() for m in model]
        else:
            model.eval()
        f1 = predict(args, model, eval_dataloader, eval_examples, eval_features, device)
        logger.info("Final %s score: %.3f%%" % (metric_name, f1 * 100.0))
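# Aside: a minimal, self-contained sketch (illustrative only, not called by
# this script) of the gradient-accumulation cadence used in the training loop
# above: the loss is divided by the number of accumulation steps k so that k
# small backward passes add up to one full-batch gradient, and the optimizer
# steps once every k micro-batches. `batches` is assumed to yield
# (inputs, targets) pairs.
def _demo_gradient_accumulation(model, optimizer, batches, k=4):
    model.train()
    for step, (inputs, targets) in enumerate(batches, start=1):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        (loss / k).backward()      # gradients accumulate in .grad across calls
        if step % k == 0:
            optimizer.step()       # one parameter update per k micro-batches
            model.zero_grad()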
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default='./data', type=str,
                        help="Data dir containing model dir, etc.")
    parser.add_argument("--tfidf_file",
                        default='wiki_first_paras-tfidf-ngram=2-hash=16777216-tokenizer=spacy.npz',
                        type=str,
                        help="tf-idf .npz file placed inside the data_dir.")
    parser.add_argument("--wiki_jsonl", default='wiki_firstpara_sents.jsonl', type=str,
                        help="Processed wikipedia .jsonl placed inside data_dir.")
    parser.add_argument("--qdmr_jsonl", default='./data/qdmr_data/qdmrs_hotpotqa_gold.jsonl',
                        type=str,
                        help="Path to processed qdmr .jsonl file.")
    parser.add_argument("--predict_batch_size", default=128, type=int,
                        help="Batch size for predictions in eval mode.")
    parser.add_argument("--tasks", default='break_rc,ques_ir,break_ir', type=str,
                        help="The IR, RC tasks to perform.")
    parser.add_argument("--suffix", default='gold', type=str,
                        help="Suffix to add to the output files.")
    parser.add_argument("--debug", action='store_true',
                        help="If on, only keep a small number of qdmrs.")
    parser.add_argument("--input_results_file", default='', type=str,
                        help="File containing results of the task to be reused.")
    args = parser.parse_args()

    # we use an already finetuned single-hop RC ensemble by Min et al.
    # (https://github.com/shmsw25/DecompRC/tree/master/DecompRC)
    rc_args = {
        'bert_config_file': 'data/onehop_rc/uncased_L-12_H-768_A-12/bert_config.json',
        'do_lower_case': True,
        'doc_stride': 128,
        'init_checkpoint': (f'{args.data_dir}/onehop_rc/uncased_L-12_H-768_A-12/model1.pt,'
                            f'{args.data_dir}/onehop_rc/uncased_L-12_H-768_A-12/model2.pt,'
                            f'{args.data_dir}/onehop_rc/uncased_L-12_H-768_A-12/model3.pt'),
        'iterations_per_loop': 1000,
        'local_rank': -1,
        'max_answer_length': 30,
        'max_n_answers': 5,
        'max_query_length': 64,
        'max_seq_length': 300,
        'model': 'qa',
        'n_best_size': 4,
        'no_cuda': False,
        'output_dropout_prob': 0,
        'pooling': 'max',
        'seed': 42,
        'verbose_logging': False,
        'vocab_file': 'data/onehop_rc/uncased_L-12_H-768_A-12/vocab.txt',
        'with_key': False
    }
    rc_args = SimpleNamespace(**rc_args)

    # load hotpotQA
    logging.info(f'loading datasets from {args.data_dir}/hotpot_data/ ...')
    data = read_file(f'{args.data_dir}/hotpot_data/hotpot_train_v1.json')
    # data += read_file(f'{args.data_dir}/hotpot_data/hotpot_dev_distractor_v1.json')
    data += read_file(f'{args.data_dir}/hotpot_data/hotpot_dev_fullwiki_v1.json')
    for d in data:
        d['gold_titles'] = {x[0] for x in d['supporting_facts']}
    hotpot = {d['_id']: d for d in data}

    # load qdmr data processed using prepare_break.jsonl
    qdmr_path = args.qdmr_jsonl
    logging.info(f'loading processed qdmr data from {qdmr_path} ...')
    qdmrs = read_file(qdmr_path)

    # load spacy tokenizer
    nlp = en_core_web_sm.load()
    tokenize = lambda s: [x.text for x in nlp.tokenizer(s)]

    # load IR
    logging.info('loading IR ...')
    ranker = IR(tfidf_path=f'{args.data_dir}/{args.tfidf_file}')

    # load wikipedia
    wiki_path = f'{args.data_dir}/{args.wiki_jsonl}'
    logging.info(f'loading wikipedia from {wiki_path} ...')
    with jsonlines.open(wiki_path, 'r') as reader:
        wiki = {d['title']: d['para'] for d in tqdm(reader.iter())}

    # prepare and load the RC ensemble for inference
    device = torch.device("cuda")
    n_gpu = torch.cuda.device_count()
    logging.info(f'{n_gpu} cuda devices available.')
    logging.info('loading 1-hop RC ensemble ...')
    random.seed(rc_args.seed)
    np.random.seed(rc_args.seed)
    torch.manual_seed(rc_args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(rc_args.seed)
    tokenizer = tokenization.FullTokenizer(vocab_file=rc_args.vocab_file,
                                           do_lower_case=rc_args.do_lower_case)
    bert_config = BertConfig.from_json_file(rc_args.bert_config_file)
    model = BertForQuestionAnswering(bert_config, 4)

    if rc_args.init_checkpoint is not None:
        model = [model]
        for i, checkpoint in enumerate(rc_args.init_checkpoint.split(',')):
            if i > 0:
                model.append(BertForQuestionAnswering(bert_config, 4))
            logging.info(f"Loading from {checkpoint}")
            state_dict = torch.load(checkpoint, map_location='cpu')
            # Strip the 'module.' prefix left by (Distributed)DataParallel checkpoints.
            filter = lambda x: x[7:] if x.startswith('module.') else x
            state_dict = {filter(k): v for (k, v) in state_dict.items()}
            model[-1].load_state_dict(state_dict)
            model[-1].to(device)
    if type(model) == list:
        model = [m.eval() for m in model]
    else:
        model.eval()

    # 1-hop RC wrapper
    simpQA = partial(_simpQA, args=args, rc_args=rc_args, tokenizer=tokenizer,
                     model=model, device=device)
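    # Note: functools.partial pre-binds the keyword arguments above, so a call
    # `simpQA(articles)` is equivalent to
    # `_simpQA(articles, args=args, rc_args=rc_args, tokenizer=tokenizer,
    #          model=model, device=device)`.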
    if args.input_results_file:
        logging.info(f'Reading the supplied results file {args.input_results_file} ...')
        all_results = read_file(args.input_results_file)
    else:
        all_results = {}
        for i_d, d in enumerate(qdmrs):
            [_, data_split, _id] = d['question_id'].split('_')
            # assert data_split == 'dev'
            assert _id in hotpot
            all_results[_id] = d
            assert d['steps'] and d['op_types'], 'QDMRs must be pre-processed and non-empty.'

    if args.debug:
        all_results = {key: val for key, val in all_results.items()
                       if random.random() < 0.01}
        logging.info(f'\nTruncating to only {len(all_results)} samples!!!\n')

    tasks = [x.strip() for x in args.tasks.split(',')]

    if 'break_rc' in tasks:
        logging.info(f'Running BREAK IR+RC on {len(all_results)} samples ...')
        max_n_parts = max([len(v['steps']) for v in all_results.values()])
        for i_p in range(max_n_parts):
            logging.info(f'Processing qdmr step #{i_p} ...')
            # process the i_p'th part of all samples
            articles = []  # hotpot articles corresponding to queries to the RC
            for _id, v in tqdm(all_results.items()):
                parts = v['steps']
                if i_p >= (len(parts) -
                           int(v['op_types'][-1] in ['COMPARISON', 'INTERSECTION'])):
                    # the last discrete comparison/intersection step is processed later
                    continue
                rc_outputs = v['rc_outputs'] if 'rc_outputs' in v else {}
                nbest_outputs = v['nbest_outputs'] if 'nbest_outputs' in v else {}
                l_top = v['titles'] if 'titles' in v else []
                part = parts[i_p]
                # replace placeholders with the respective RC outputs of previous parts
                for j in range(i_p):
                    ph = '#' + str(j + 1)  # 1...i_p
                    if ph in part:
                        part = part.replace(ph, rc_outputs[ph])
                # get top 10 titles from IR
                top_titles = ranker.closest_docs(part, k=10)[0]
                l_top.append(top_titles)
                v.update({'titles': l_top,
                          'rc_outputs': rc_outputs,
                          'nbest_outputs': nbest_outputs})
                context = []
                # use all paras retrieved for the sample so far instead of just the
                # current 10, sorted by similarity wrt the part
                set_l_top = set(sum(l_top, []))
                scores = ranker.rank_titles(part, set_l_top)
                sorted_l_top = sorted(scores.keys(),
                                      key=lambda title: scores[title],
                                      reverse=True)
                for title in sorted_l_top:
                    context.append([title, wiki[title]['sents']])  # get para from wiki
                if not sorted_l_top:  # rare case of no valid titles
                    context = [['Random Title 1', 'Random Text 1'],
                               ['Random Title 2', 'Random Text 2']]
                d, article = hotpot[_id], {}
                # append ' ?' to the part to form the query
                article['question'], article['context'] = part + ' ?', context
                article.update({k: d[k] for k in ['_id', 'type', 'answer']})
                # keys: '_id', 'type', 'context', 'question', 'answer'
                articles.append(article)
            if not articles:
                continue

            # querying the 1-hop RC
            all_nbest_out = simpQA([to_squad(article) for article in articles])[1]
            for _id, v in all_results.items():
                if _id not in all_nbest_out:
                    continue
                nbest_i_p = all_nbest_out[_id]
                op = v['op_types'][i_p]
                nbest_id = v['nbest_outputs']
                # handle filter steps
                if 'FILTER' in op:
                    ref_ph = op.split('_')[1]
                    nbest_ref = Counter(nbest_id[ref_ph])
                    # accumulate the logits of the n-best of this part and the referenced part
                    nbest_ref.update(nbest_i_p)
                    nbest_i_p = dict(nbest_ref)
                rc_out = max(nbest_i_p.keys(), key=lambda key: nbest_i_p[key])
                v['rc_outputs'][f'#{i_p+1}'] = rc_out
                v['nbest_outputs'][f'#{i_p+1}'] = nbest_i_p

        # discrete processing of the last comparison/intersection step
        logging.info('Discrete processing of the last comparison/intersection steps ...')
        for _id, v in all_results.items():
            if v['op_types'][-1] not in ['COMPARISON', 'INTERSECTION']:
                continue
            question, answer, gold_titles = (hotpot[_id]['question'],
                                             hotpot[_id]['answer'],
                                             hotpot[_id]['gold_titles'])
            parts, rc_outputs = v['steps'], v['rc_outputs']
            if v['op_types'][-1] == 'COMPARISON':
                ents, rc_outs = [], []
                for i_p, part in enumerate(parts[:-1]):
                    # get the named entity in the part
                    part_without_phs = part
                    for x in ['#' + str(j) for j in range(1, 8)]:
                        part_without_phs = part_without_phs.replace(x, '')
                    ent = get_ent(part_without_phs, nlp, only_longest=True)
                    ent = '' if ent is None else ent
                    ents.append(ent)
                    rc_outs.append(normalize_answer(rc_outputs['#' + str(i_p + 1)]))
                if 'same as' in parts[-1]:
                    pred_ans = 'yes' if rc_outs[-2] == rc_outs[-1] else 'no'
                else:
                    pred_ans = ents[compare(parts[-1], rc_outs[-2], rc_outs[-1])]
                v['rc_outputs'][f'#{len(parts)}'] = pred_ans
            elif v['op_types'][-1] == 'INTERSECTION':
                part = parts[-1]
                phs = ['#' + str(j) for j in range(1, 10) if '#' + str(j) in part]
                phs = list(set(phs))
                # accumulate logits of the parts and take the argmax
                nbest_id = v['nbest_outputs']
                nbest = Counter(nbest_id[phs[0]])
                for ph in phs[1:]:
                    if ph in nbest_id:
                        nbest.update(nbest_id[ph])
                nbest = dict(nbest)
                pred_ans = max(nbest.keys(), key=lambda key: nbest[key])
                v['rc_outputs'][f'#{len(parts)}'] = pred_ans
                v['nbest_outputs'][f'#{len(parts)}'] = nbest

        for v in all_results.values():
            assert len(v['rc_outputs']) == len(v['steps'])
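    # Aside: a minimal, self-contained sketch (illustrative only, not called by
    # this script) of the Counter-based logit accumulation used for the FILTER
    # and INTERSECTION steps above: Counter.update() with a mapping sums values
    # per key, so merging two n-best dicts adds their span logits, and the
    # argmax picks the answer jointly favored by both steps. Toy numbers.
    def _demo_logit_accumulation():
        nbest_a = {'paris': 2.5, 'london': 1.0}  # hypothetical n-best logits
        nbest_b = {'paris': 0.5, 'berlin': 1.8}
        merged = Counter(nbest_a)
        merged.update(nbest_b)  # {'paris': 3.0, 'london': 1.0, 'berlin': 1.8}
        return max(merged.keys(), key=lambda key: merged[key])  # -> 'paris'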
    if 'break_ir' in tasks:
        # this can only be run after the break_rc task & requires the all_results dict
        logging.info(f'Forming context using the titles used by Break RC for '
                     f'{len(all_results)} samples ...')
        # prepare hotpot-like data for Bert RC
        new_hotpot = []
        for _id, v in tqdm(all_results.items()):
            d = hotpot[_id]
            d_new = deepcopy(d)
            used_titles = sum(v['titles'], [])
            # sort wrt similarity to the question
            scores = ranker.rank_titles(d['question'], set(used_titles))
            titles = sorted(scores.keys(), key=lambda title: scores[title], reverse=True)
            context = []
            for title in titles:
                context.append([title, wiki[title]['sents']])
            d_new['context'] = context
            if 'gold_titles' in d_new:
                del d_new['gold_titles']
            new_hotpot.append(d_new)
        out_break_ir_file = f'{args.data_dir}/hotpot_data/hotpot_after_break_ir_{args.suffix}.json'
        logging.info(f'Writing hotpot version with the Break IR context to {out_break_ir_file} ...')
        write_file(new_hotpot, out_break_ir_file)
        # store the retrieved titles
        for d in new_hotpot:
            all_results[d['_id']]['titles_found_by_break_rc'] = list(
                set([x[0] for x in d['context']]))

    if 'ques_ir' in tasks:
        # this can only be run after the break_rc task & requires the all_results dict,
        # which determines the number of titles to be retrieved for each sample
        logging.info(f'Running baseline IR using the whole question for '
                     f'{len(all_results)} samples ...')
        # prepare hotpot-like data for Bert RC
        new_hotpot = []
        for _id in tqdm(all_results.keys()):
            d = hotpot[_id]
            d_new = deepcopy(d)
            # for a fair comparison with Break RC, retrieve the same number of titles
            n_titles = len(sum(all_results[_id]['titles'], []))
            titles = ranker.closest_docs(d['question'], k=n_titles)[0]
            context = []
            for title in titles:
                context.append([title, wiki[title]['sents']])
            d_new['context'] = context
            if 'gold_titles' in d_new:
                del d_new['gold_titles']
            new_hotpot.append(d_new)
        out_ques_ir_file = f'{args.data_dir}/hotpot_data/hotpot_after_ques_ir_{args.suffix}.json'
        logging.info(f'Writing hotpot version with the baseline IR context to {out_ques_ir_file} ...')
        write_file(new_hotpot, out_ques_ir_file)
        # store the retrieved titles
        for d in new_hotpot:
            all_results[d['_id']]['titles_found_using_whole_ques'] = list(
                set([x[0] for x in d['context']]))

    # save the Break RC outputs
    out_break_rc_file = f'{args.data_dir}/predictions/break_rc_results_{args.suffix}.json'
    logging.info(f'Writing the break RC results to {out_break_rc_file} ...')
    os.makedirs(dirname(out_break_rc_file), exist_ok=True)
    write_file(all_results, out_break_rc_file)
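# Example invocation (illustrative; the script name is a placeholder and the
# paths simply echo the defaults above, so adjust both to your local setup):
#   python run_break_rc.py --data_dir ./data \
#       --qdmr_jsonl ./data/qdmr_data/qdmrs_hotpotqa_gold.jsonl \
#       --tasks break_rc,ques_ir,break_ir --suffix gold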