def test_model(args, model, tokenizer, test_examples, test_features, device):
    """Run inference over the test set and write n-best predictions to disk.

    The model reads each document chunk by chunk for up to
    ``args.max_read_times`` steps. After each chunk it either keeps a fixed
    stride (supervised pretraining) or moves its per-example global pointer
    by a stride chosen by the learned policy. Raw logits from every chunk
    are accumulated in ``all_results`` and passed to ``write_predictions``,
    which writes ``predictions.json`` and ``nbest_predictions.json`` into
    ``args.output_dir``.

    Args:
        args: namespace with at least predict_batch_size, max_seq_length,
            max_read_times, supervised_pretraining, doc_stride, output_dir,
            n_best_size, max_answer_length, do_lower_case, verbose_logging.
        model: the chunking reader; returns stop/stride/start/end logits and
            a recurrent hidden state.
        tokenizer: tokenizer passed through to ``gen_model_features``.
        test_examples: examples consumed by ``write_predictions``.
        test_features: per-example features (query_tokens, doc_tokens, ...).
        device: torch device for input tensors.
    """
    test_indices = torch.arange(len(test_features), dtype=torch.long)
    test_sampler = SequentialSampler(test_indices)
    test_dataloader = DataLoader(test_indices,
                                 sampler=test_sampler,
                                 batch_size=args.predict_batch_size)
    model.eval()
    all_results = []
    for step, batch_test_indices in enumerate(
            tqdm(test_dataloader, desc="Evaluating")):
        batch_test_features = [test_features[ind] for ind in batch_test_indices]
        batch_query_tokens = [f.query_tokens for f in batch_test_features]
        batch_doc_tokens = [f.doc_tokens for f in batch_test_features]
        batch_size = len(batch_test_features)
        # Global document position of each example's reading pointer.
        cur_global_pointers = [0] * batch_size
        # The document chunk shares the sequence budget with the query and
        # 3 special tokens — presumably [CLS]/[SEP]/[SEP]; confirm against
        # gen_model_features.
        batch_max_doc_length = [
            args.max_seq_length - 3 - len(query_tokens)
            for query_tokens in batch_query_tokens
        ]
        prev_hidden_states = None
        for t in range(args.max_read_times):
            chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, _, _, _ = \
                gen_model_features(cur_global_pointers, batch_query_tokens,
                                   batch_doc_tokens, None, None,
                                   batch_max_doc_length, args.max_seq_length,
                                   tokenizer, is_train=False)
            chunk_input_ids = torch.tensor(chunk_input_ids, dtype=torch.long, device=device)
            chunk_input_mask = torch.tensor(chunk_input_mask, dtype=torch.long, device=device)
            chunk_segment_ids = torch.tensor(chunk_segment_ids, dtype=torch.long, device=device)
            with torch.no_grad():
                chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                    chunk_start_logits, chunk_end_logits, \
                    prev_hidden_states = model(chunk_input_ids, chunk_segment_ids,
                                               chunk_input_mask, prev_hidden_states)
            # stop logits: score for whether the current chunk contains the answer
            chunk_stop_logits = chunk_stop_logits.detach().cpu().tolist()
            # Collect raw per-example results for the current chunk.
            for i, example_index in enumerate(batch_test_indices):
                stop_logits = chunk_stop_logits[i]
                start_logits = chunk_start_logits[i].detach().cpu().tolist()
                end_logits = chunk_end_logits[i].detach().cpu().tolist()
                id_to_tok_map = id_to_tok_maps[i]
                example_index = example_index.item()
                all_results.append(
                    RawResult(example_index=example_index,
                              stop_logits=stop_logits,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              id_to_tok_map=id_to_tok_map))
            # Take the movement action: a fixed stride during supervised
            # pretraining, otherwise the stride index sampled by the policy.
            if args.supervised_pretraining:
                chunk_strides = [args.doc_stride] * batch_size
            else:
                chunk_strides = [
                    stride_action_space[stride_ind]
                    for stride_ind in chunk_stride_inds.tolist()
                ]
            cur_global_pointers = [
                cur_global_pointers[ind] + chunk_strides[ind]
                for ind in range(len(cur_global_pointers))
            ]
            # Clamp each pointer into [0, last document token index].
            cur_global_pointers = [
                min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind]) - 1)
                for ind in range(len(cur_global_pointers))
            ]
    # Write predictions.
    output_prediction_file = os.path.join(args.output_dir, "predictions.json")
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
    write_predictions(test_examples, test_features, all_results, args.n_best_size,
                      args.max_answer_length, args.do_lower_case,
                      output_prediction_file, output_nbest_file,
                      args.verbose_logging)
def validate_model(args, model, tokenizer, dev_examples, dev_features,
                   dev_dataloader, dev_evaluator, best_dev_score, device):
    """Evaluate on the dev set and checkpoint the model when F1 improves.

    Mirrors the chunked-reading inference loop of ``test_model`` but scores
    the resulting predictions with ``dev_evaluator`` instead of writing them
    to disk. When the dev F1 beats ``best_dev_score``, the model weights are
    saved to ``best_RCM_model.bin`` in ``args.output_dir``.

    Args:
        best_dev_score: best dev F1 seen so far.

    Returns:
        The updated best dev F1 score.
    """
    all_results = []
    for dev_step, batch_dev_indices in enumerate(
            tqdm(dev_dataloader, desc="Evaluating")):
        batch_dev_features = [dev_features[ind] for ind in batch_dev_indices]
        batch_query_tokens = [f.query_tokens for f in batch_dev_features]
        batch_doc_tokens = [f.doc_tokens for f in batch_dev_features]
        batch_size = len(batch_dev_features)
        # Global document position of each example's reading pointer.
        cur_global_pointers = [0] * batch_size
        # 3 accounts for the special tokens sharing the sequence budget.
        batch_max_doc_length = [
            args.max_seq_length - 3 - len(query_tokens)
            for query_tokens in batch_query_tokens
        ]
        prev_hidden_states = None
        for t in range(args.max_read_times):
            chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, _, _, _ = \
                gen_model_features(cur_global_pointers, batch_query_tokens,
                                   batch_doc_tokens, None, None,
                                   batch_max_doc_length, args.max_seq_length,
                                   tokenizer, is_train=False)
            chunk_input_ids = torch.tensor(chunk_input_ids, dtype=torch.long, device=device)
            chunk_input_mask = torch.tensor(chunk_input_mask, dtype=torch.long, device=device)
            chunk_segment_ids = torch.tensor(chunk_segment_ids, dtype=torch.long, device=device)
            with torch.no_grad():
                chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                    chunk_start_logits, chunk_end_logits, \
                    prev_hidden_states = model(chunk_input_ids, chunk_segment_ids,
                                               chunk_input_mask, prev_hidden_states)
            # stop logits: score for whether the current chunk contains the answer
            chunk_stop_logits = chunk_stop_logits.detach().cpu().tolist()
            # Collect raw per-example results for the current chunk.
            for i, example_index in enumerate(batch_dev_indices):
                stop_logits = chunk_stop_logits[i]
                start_logits = chunk_start_logits[i].detach().cpu().tolist()
                end_logits = chunk_end_logits[i].detach().cpu().tolist()
                id_to_tok_map = id_to_tok_maps[i]
                example_index = example_index.item()
                all_results.append(
                    RawResult(example_index=example_index,
                              stop_logits=stop_logits,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              id_to_tok_map=id_to_tok_map))
            # Take the movement action: fixed stride in pretraining, else
            # the stride index sampled by the policy.
            if args.supervised_pretraining:
                chunk_strides = [args.doc_stride] * batch_size
            else:
                chunk_strides = [
                    stride_action_space[stride_ind]
                    for stride_ind in chunk_stride_inds.tolist()
                ]
            cur_global_pointers = [
                cur_global_pointers[ind] + chunk_strides[ind]
                for ind in range(len(cur_global_pointers))
            ]
            # Clamp each pointer into [0, last document token index].
            cur_global_pointers = [
                min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind]) - 1)
                for ind in range(len(cur_global_pointers))
            ]
    dev_predictions = make_predictions(dev_examples, dev_features, all_results,
                                       args.n_best_size, args.max_answer_length,
                                       args.do_lower_case, args.verbose_logging,
                                       validate_flag=True)
    dev_scores = dev_evaluator.eval_fn(dev_predictions)
    dev_score = dev_scores['f1']
    logger.info('dev score: {}'.format(dev_score))
    if dev_score > best_dev_score:
        best_model_to_save = model.module if hasattr(model, 'module') else model
        best_output_model_file = os.path.join(args.output_dir, "best_RCM_model.bin")
        torch.save(best_model_to_save.state_dict(), best_output_model_file)
        # Fix: only claim a checkpoint was saved when it actually was; the
        # original logged "Best dev score: ... saved" unconditionally, even
        # when dev_score did not improve.
        logger.info("Best dev score: {}, saved to best_RCM_model.bin".format(
            dev_score))
    best_dev_score = max(best_dev_score, dev_score)
    return best_dev_score
def train_model(args, model, tokenizer, optimizer, train_examples,
                train_features, dev_examples, dev_features, dev_evaluator,
                device, n_gpu, t_total):
    """Train the recurrent-chunking reader.

    Each batch is read chunk by chunk for ``args.max_read_times`` steps.
    Per chunk, the model returns stop/stride/span logits plus supervised
    stop and answer losses. In supervised pretraining the pointer advances
    by a fixed ``args.doc_stride``; otherwise the sampled stride actions are
    additionally trained with a REINFORCE loss whose returns come from
    ``reward_estimation`` / ``reward_estimation_for_stop``. Supports fp16,
    gradient accumulation, linear LR warmup, and (optionally) mid-epoch
    validation every 500 steps. The final model is written to
    ``pytorch_model.bin`` when ``args.do_train`` is set.

    Args:
        optimizer: must expose ``.backward(loss)`` when ``args.fp16`` is on
            (e.g. apex-style FP16 optimizer); otherwise a standard optimizer.
        t_total: total number of optimization steps, used for LR warmup.
        n_gpu: number of GPUs; loss is averaged across them when > 1.
    """
    train_indices = torch.arange(len(train_features), dtype=torch.long)
    # Distributed sampler only when a local_rank is assigned.
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_indices)
    else:
        train_sampler = DistributedSampler(train_indices)
    train_dataloader = DataLoader(train_indices,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.do_validate:
        dev_indices = torch.arange(len(dev_features), dtype=torch.long)
        dev_sampler = SequentialSampler(dev_indices)
        dev_dataloader = DataLoader(dev_indices,
                                    sampler=dev_sampler,
                                    batch_size=args.predict_batch_size)
    best_dev_score = 0.0
    epoch = 0
    global_step = 0
    model.train()
    for _ in trange(int(args.num_train_epochs), desc="Epoch"):
        training_loss = 0.0
        for step, batch_indices in enumerate(
                tqdm(train_dataloader, desc="Iteration")):
            batch_features = [train_features[ind] for ind in batch_indices]
            batch_query_tokens = [f.query_tokens for f in batch_features]
            batch_doc_tokens = [f.doc_tokens for f in batch_features]
            batch_start_positions = [f.start_position for f in batch_features]
            batch_end_positions = [f.end_position for f in batch_features]
            #batch_yes_no_flags = [f.yes_no_flag for f in batch_features]
            #batch_yes_no_answers = [f.yes_no_ans for f in batch_features]
            batch_size = len(batch_features)
            # Global position of the current pointer within each document.
            cur_global_pointers = [0] * batch_size
            # 3 accounts for the special tokens sharing the sequence budget
            # with the query and the document chunk.
            batch_max_doc_length = [
                args.max_seq_length - 3 - len(query_tokens)
                for query_tokens in batch_query_tokens
            ]
            # Per-chunk bookkeeping for the REINFORCE loss.
            stride_log_probs = []
            stop_rewards = []
            stop_probs = []
            stop_loss = None
            answer_loss = None
            prev_hidden_states = None
            for t in range(args.max_read_times):
                # Features at the current chunk.
                chunk_input_ids, chunk_input_mask, chunk_segment_ids, id_to_tok_maps, \
                    chunk_start_positions, chunk_end_positions, chunk_stop_flags = \
                    gen_model_features(cur_global_pointers, batch_query_tokens,
                                       batch_doc_tokens, batch_start_positions,
                                       batch_end_positions, batch_max_doc_length,
                                       args.max_seq_length, tokenizer,
                                       is_train=True)
                chunk_input_ids = torch.tensor(chunk_input_ids, dtype=torch.long, device=device)
                chunk_input_mask = torch.tensor(chunk_input_mask, dtype=torch.long, device=device)
                chunk_segment_ids = torch.tensor(chunk_segment_ids, dtype=torch.long, device=device)
                chunk_start_positions = torch.tensor(chunk_start_positions, dtype=torch.long, device=device)
                chunk_end_positions = torch.tensor(chunk_end_positions, dtype=torch.long, device=device)
                #chunk_yes_no_flags = torch.tensor(batch_yes_no_flags, dtype=torch.long, device=device)
                #chunk_yes_no_answers = torch.tensor(batch_yes_no_answers, dtype=torch.long, device=device)
                chunk_stop_flags = torch.tensor(chunk_stop_flags, dtype=torch.long, device=device)
                # Model forward pass: in training mode it also returns the
                # supervised stop and answer-span losses.
                chunk_stop_logits, chunk_stride_inds, chunk_stride_log_probs, \
                    chunk_start_logits, chunk_end_logits, \
                    prev_hidden_states, chunk_stop_loss, chunk_answer_loss = \
                    model(chunk_input_ids, chunk_segment_ids, chunk_input_mask,
                          prev_hidden_states, chunk_stop_flags,
                          chunk_start_positions, chunk_end_positions)
                chunk_stop_logits = chunk_stop_logits.detach()
                # Probability that the current chunk contains the answer
                # (class index 1 of the stop head).
                chunk_stop_probs = F.softmax(chunk_stop_logits, dim=1)
                chunk_stop_probs = chunk_stop_probs[:, 1]
                stop_probs.append(chunk_stop_probs.tolist())
                chunk_stop_logits = chunk_stop_logits.tolist()
                # Accumulate supervised losses over the read steps.
                if stop_loss is None:
                    stop_loss = chunk_stop_loss
                else:
                    stop_loss += chunk_stop_loss
                if answer_loss is None:
                    answer_loss = chunk_answer_loss
                else:
                    answer_loss += chunk_answer_loss
                if args.supervised_pretraining:
                    chunk_strides = [args.doc_stride] * batch_size
                else:
                    # Take the movement action sampled by the policy.
                    chunk_strides = [
                        stride_action_space[stride_ind]
                        for stride_ind in chunk_stride_inds.tolist()
                    ]
                cur_global_pointers = [
                    cur_global_pointers[ind] + chunk_strides[ind]
                    for ind in range(len(cur_global_pointers))
                ]
                # Clamp each pointer into [0, last document token index].
                cur_global_pointers = [min(max(0, cur_global_pointers[ind]), len(batch_doc_tokens[ind])-1) \
                                       for ind in range(len(cur_global_pointers))]
                if not args.supervised_pretraining:
                    # Reward estimation for reinforcement learning.
                    chunk_start_probs = F.softmax(chunk_start_logits.detach(), dim=1).tolist()
                    chunk_end_probs = F.softmax(chunk_end_logits.detach(), dim=1).tolist()
                    #chunk_yes_no_flag_probs = F.softmax(chunk_yes_no_flag_logits.detach(), dim=1).tolist()
                    #chunk_yes_no_ans_probs = F.softmax(chunk_yes_no_ans_logits.detach(), dim=1).tolist()
                    # Rewards if stopping at the current chunk.
                    chunk_stop_rewards = reward_estimation_for_stop(
                        chunk_start_probs, chunk_end_probs,
                        chunk_start_positions.tolist(),
                        chunk_end_positions.tolist(),
                        chunk_stop_flags.tolist())
                    stop_rewards.append(chunk_stop_rewards)
                    # Save history, excluding the log-prob of the last read:
                    # the final action is never evaluated.
                    if (t < args.max_read_times - 1):
                        stride_log_probs.append(chunk_stride_log_probs)
            if args.supervised_pretraining:
                loss = (stop_loss * args.stop_loss_weight + answer_loss) / args.max_read_times
            else:
                # stride_log_probs: (bsz, max_read_times-1)
                stride_log_probs = torch.stack(stride_log_probs).transpose(1, 0)
                # q_vals: (bsz, max_read_times-1)
                q_vals = reward_estimation(stop_rewards, stop_probs)
                q_vals = torch.tensor(q_vals, dtype=stride_log_probs.dtype, device=device)
                #logger.info("q_vals: {}".format(q_vals))
                # REINFORCE objective: maximize expected return of strides.
                reinforce_loss = torch.sum(-stride_log_probs * q_vals, dim=1)
                reinforce_loss = torch.mean(reinforce_loss, dim=0)
                loss = (stop_loss * args.stop_loss_weight + answer_loss) / args.max_read_times + reinforce_loss
            # Compute gradients.
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                # NOTE(review): assumes an apex-style optimizer exposing
                # .backward(loss) under fp16 — confirm against caller.
                optimizer.backward(loss)
            else:
                loss.backward()
            # Log training loss every 500 steps.
            training_loss += loss.item()
            if (step % 500 == 499):
                logger.info('step: {}, train loss: {}\n'.format(
                    step, training_loss / 500.0))
                if not args.supervised_pretraining:
                    # q_vals is always defined here: it is set in the
                    # non-pretraining branch above within this iteration.
                    logger.info('q_vals: {}\n'.format(q_vals))
                training_loss = 0.0
            # Validation on dev data every 500 steps.
            if args.do_validate and step % 500 == 499:
                model.eval()
                best_dev_score = validate_model(args, model, tokenizer,
                                                dev_examples, dev_features,
                                                dev_dataloader, dev_evaluator,
                                                best_dev_score, device)
                model.train()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Modify learning rate with the special warm-up BERT uses.
                lr_this_step = args.learning_rate * warmup_linear(
                    global_step / t_total, args.warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
        epoch += 1
    # Save a trained model.
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if args.do_train:
        torch.save(model_to_save.state_dict(), output_model_file)