def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1].text, title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in batch]
        extra_info = []
        if len(batch[0]) > 3:
            extra_info = [r[3:] for r in batch]

        assert len(ctx_ids) == out.size(0)
        total += len(ctx_ids)

        # TODO: refactor to avoid 'if'
        if extra_info:
            results.extend(
                [(ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i]) for i in range(out.size(0))]
            )
        else:
            results.extend([(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)
    return results
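
# A tiny self-contained illustration (not from the original source) of the
# `extra_info` handling above: rows with more than 3 fields carry their tail
# fields through to the result tuple via star-unpacking.
row = ("passage_42", "text", "title", "meta_a", "meta_b")  # len > 3 -> has extra info
ctx_id, extra = row[0], row[3:]
result = (ctx_id, "vec_placeholder", *extra)
assert result == ("passage_42", "vec_placeholder", "meta_a", "meta_b")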
def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,  # unused in this variant; title inclusion is driven by ctx[2]
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # This variant flattens each passage into an explicit
        # 'title: ... context: ...' string instead of passing the title
        # to the tensorizer (the original behavior, kept below for reference).
        all_txt = []
        for ctx in ctx_rows[batch_start : batch_start + bsz]:
            if ctx[2]:
                txt = ['title:', ctx[2], 'context:', ctx[1]]
            else:
                txt = ['context:', ctx[1]]
            all_txt.append(' '.join(txt))
        batch_token_tensors = [tensorizer.text_to_tensor(txt, max_length=250) for txt in all_txt]
        # original:
        # batch_token_tensors = [
        #     tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
        #     for ctx in ctx_rows[batch_start : batch_start + bsz]
        # ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start : batch_start + bsz]]
        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        # original:
        # results.extend([
        #     (ctx_ids[i], out[i].view(-1).numpy())
        #     for i in range(out.size(0))
        # ])
        results.extend([(ctx_ids[i], out[i].numpy()) for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)
    return results
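
# A self-contained check (not from the original source) of the prompt format this
# variant builds, versus letting the tensorizer handle the title as in the other
# variants: passages are flattened to 'title: ... context: ...' strings.
def format_ctx(text, title):
    parts = ['title:', title, 'context:', text] if title else ['context:', text]
    return ' '.join(parts)

assert format_ctx("Paris is the capital of France.", "Paris") == \
    "title: Paris context: Paris is the capital of France."
assert format_ctx("Paris is the capital of France.", "") == \
    "context: Paris is the capital of France."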
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos=rep_positions,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps
    return loss, is_correct
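
# A minimal self-contained sketch (not from the original source) of why the loss
# is divided by gradient_accumulation_steps above: accumulating k scaled
# sub-batch losses before optimizer.step() reproduces the gradient of one
# k-times-larger batch.
import torch

x = torch.randn(8, 4)
w = torch.zeros(4, requires_grad=True)

# one full batch
(x @ w - 1).pow(2).mean().backward()
g_full = w.grad.clone()

# same data split into k accumulation sub-batches, each loss scaled by 1/k
w.grad = None
k = 2
for chunk in x.split(4, dim=0):  # gradients accumulate across .backward() calls
    ((chunk @ w - 1).pow(2).mean() / k).backward()

assert torch.allclose(w.grad, g_full, atol=1e-6)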
def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
    args = self.args
    input = ReaderBatch(**move_to_device(input._asdict(), args.device))
    attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
    questions_num, passages_per_question, _ = input.input_ids.size()

    if self.reader.training:
        # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
        loss = self.reader(
            input.input_ids,
            attn_mask,
            input.start_positions,
            input.end_positions,
            input.answers_mask,
        )
    else:
        # TODO: remove?
        with torch.no_grad():
            start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
        loss = compute_loss(
            input.start_positions,
            input.end_positions,
            input.answers_mask,
            start_logits,
            end_logits,
            rank_logits,
            questions_num,
            passages_per_question,
        )

    if args.n_gpu > 1:
        loss = loss.mean()
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
    return loss
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    args,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        args,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
    )

    is_correct = is_correct.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
    return loss, is_correct
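
# BiEncoderNllLoss is not defined in this section; a minimal self-contained
# sketch of the in-batch-negative NLL idea behind it (per the DPR paper):
# softmax over question-to-context dot products, with the loss taken at each
# question's positive context index.
import torch
import torch.nn.functional as F

q = torch.randn(3, 8)    # 3 question vectors
c = torch.randn(3, 8)    # their positive contexts; positives sit on the diagonal
scores = torch.matmul(q, c.t())                      # (3, 3) similarity matrix
positive_idx = torch.arange(3)
loss = F.nll_loss(F.log_softmax(scores, dim=1), positive_idx, reduction="mean")
is_correct = (scores.argmax(dim=1) == positive_idx).sum()  # correct-prediction count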
def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
            for ctx in ctx_rows[batch_start : batch_start + bsz]
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start : batch_start + bsz]]
        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend([(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)
    return results
def validate(self):
    logger.info('Validation ...')
    args = self.args
    self.reader.eval()
    data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, False, shuffle=False)

    log_result_step = args.log_batch_step
    all_results = []

    eval_top_docs = args.eval_top_docs
    for i, samples_batch in enumerate(data_iterator.iterate_data()):
        input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            args.passages_per_question_predict,
            args.sequence_length,
            args.max_n_answers,
            is_train=False,
            shuffle=False,
        )

        input = ReaderBatch(**move_to_device(input._asdict(), args.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask)

        batch_predictions = self._get_best_prediction(
            start_logits,
            end_logits,
            relevance_logits,
            samples_batch,
            passage_thresholds=eval_top_docs,
        )

        all_results.extend(batch_predictions)

        if (i + 1) % log_result_step == 0:
            logger.info('Eval step: %d ', i)

    ems = defaultdict(list)

    for q_predictions in all_results:
        gold_answers = q_predictions.gold_answers
        span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
        for (n, span_prediction) in span_predictions.items():
            em_hit = max([exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers])
            ems[n].append(em_hit)

    em = 0
    for n in sorted(ems.keys()):
        em = np.mean(ems[n])
        logger.info("n=%d\tEM %.2f" % (n, em * 100))

    if args.prediction_results_file:
        self._save_predictions(args.prediction_results_file, all_results)

    return em
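
# exact_match_score is not defined in this section; a minimal sketch assuming
# the standard SQuAD-style normalization (lowercase, strip punctuation and
# articles, collapse whitespace) that DPR's evaluation follows.
import re
import string

def normalize_answer(s):
    s = s.lower()
    s = "".join(ch for ch in s if ch not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match_score_sketch(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

assert exact_match_score_sketch("The Eiffel Tower!", "eiffel tower")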
def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
    cfg = self.cfg
    input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
    attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
    questions_num, passages_per_question, _ = input.input_ids.size()

    if self.reader.training:
        # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
        loss = self.reader(
            input.input_ids,
            attn_mask,
            input.start_positions,
            input.end_positions,
            input.answers_mask,
            use_simple_loss=getattr(cfg.train, "use_simple_loss", False),
            average_loss=getattr(cfg.train, "average_loss", False),
        )
    else:
        # TODO: remove?
        with torch.no_grad():
            start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
        loss = compute_loss(
            input.start_positions,
            input.end_positions,
            input.answers_mask,
            start_logits,
            end_logits,
            rank_logits,
            questions_num,
            passages_per_question,
            use_simple_loss=getattr(cfg.train, "use_simple_loss", False),
            average=getattr(cfg.train, "average_loss", False),
        )

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps
    return loss
def validate_average_rank(self) -> float:
    """
    Validates the biencoder model using each question's gold passage rank across the
    set of passages from the dataset. It generates vectors for a specified number of
    negative passages from each question (see --val_av_rank_xxx params) and stores
    them in RAM, together with the question vectors. The similarity scores are then
    calculated for the entire num_questions x (num_questions x num_passages_per_question)
    matrix and sorted per question. Each question's gold passage rank in that sorted
    list of scores is averaged across all questions.
    :return: averaged rank number
    """
    logger.info('Average rank validation ...')

    args = self.args
    self.biencoder.eval()
    distributed_factor = self.distributed_factor

    data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, shuffle=False)

    sub_batch_size = args.val_av_rank_bsz
    sim_score_f = BiEncoderNllLoss.get_similarity_function()
    q_representations = []
    ctx_representations = []
    positive_idx_per_question = []

    num_hard_negatives = args.val_av_rank_hard_neg
    num_other_negatives = args.val_av_rank_other_neg

    log_result_step = args.log_batch_step

    for i, samples_batch in enumerate(data_iterator.iterate_data()):
        if len(q_representations) > args.val_av_rank_max_qs / distributed_factor:
            break

        biencoder_input = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=False,
        )
        total_ctxs = len(ctx_representations)
        ctxs_ids = biencoder_input.context_ids
        ctxs_segments = biencoder_input.ctx_segments
        bsz = ctxs_ids.size(0)

        # split the contexts batch into sub-batches since it may be too large
        # to be processed in one pass
        for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):
            # question vectors only need to be computed for the first sub-batch;
            # this variant moves them to the device before the DP check below
            q_ids, q_segments = (
                (biencoder_input.question_ids, biencoder_input.question_segments)
                if j == 0
                else (None, None)
            )
            q_ids = move_to_device(q_ids, args.device)
            q_segments = move_to_device(q_segments, args.device)

            if j == 0 and args.n_gpu > 1 and q_ids.size(0) == 1:
                # if we are in DP (but not in DDP) mode, all model input tensors
                # should have batch size >1 or 0; otherwise the other input tensors
                # will be split but only the first split will be called
                continue

            ctx_ids_batch = move_to_device(ctxs_ids[batch_start : batch_start + sub_batch_size], args.device)
            ctx_seg_batch = move_to_device(ctxs_segments[batch_start : batch_start + sub_batch_size], args.device)

            q_attn_mask = move_to_device(self.tensorizer.get_attn_mask(q_ids), args.device)
            ctx_attn_mask = move_to_device(self.tensorizer.get_attn_mask(ctx_ids_batch), args.device)
            with torch.no_grad():
                q_dense, ctx_dense = self.biencoder(
                    q_ids, q_segments, q_attn_mask, ctx_ids_batch, ctx_seg_batch, ctx_attn_mask
                )

            if q_dense is not None:
                q_representations.extend(q_dense.cpu().split(1, dim=0))

            ctx_representations.extend(ctx_dense.cpu().split(1, dim=0))

        batch_positive_idxs = biencoder_input.is_positive
        positive_idx_per_question.extend([total_ctxs + v for v in batch_positive_idxs])

        if (i + 1) % log_result_step == 0:
            logger.info(
                'Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d',
                i,
                len(ctx_representations),
                len(q_representations),
            )

    ctx_representations = torch.cat(ctx_representations, dim=0)
    q_representations = torch.cat(q_representations, dim=0)

    logger.info('Av.rank validation: total q_vectors size=%s', q_representations.size())
    logger.info('Av.rank validation: total ctx_vectors size=%s', ctx_representations.size())

    q_num = q_representations.size(0)
    assert q_num == len(positive_idx_per_question)

    scores = sim_score_f(q_representations, ctx_representations)
    values, indices = torch.sort(scores, dim=1, descending=True)

    rank = 0
    for i, idx in enumerate(positive_idx_per_question):
        # aggregate the rank of the known gold passage in the sorted results for each question
        gold_idx = (indices[i] == idx).nonzero()
        rank += gold_idx.item()

    if distributed_factor > 1:
        # each node calculated its own rank; exchange the information between nodes
        # and compute the "global" average rank
        # NOTE: the set of passages is still unique for every node
        eval_stats = all_gather_list([rank, q_num], max_size=100)
        for i, item in enumerate(eval_stats):
            remote_rank, remote_q_num = item
            if i != args.local_rank:
                rank += remote_rank
                q_num += remote_q_num

    av_rank = float(rank / q_num)
    logger.info('Av.rank validation: average rank %s, total questions=%d', av_rank, q_num)
    return av_rank
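
# A self-contained toy version (not from the original source) of the rank
# aggregation above: sort each question's scores, then find the sorted position
# of the gold passage. Scores here are made up.
import torch

toy_scores = torch.tensor([[0.9, 0.1, 0.4],   # gold idx 0 -> rank 0
                           [0.2, 0.3, 0.8]])  # gold idx 1 -> rank 1
gold = [0, 1]
_, order = torch.sort(toy_scores, dim=1, descending=True)
total_rank = sum((order[i] == idx).nonzero().item() for i, idx in enumerate(gold))
assert total_rank / len(gold) == 0.5  # average rank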
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    loss_function,
    cfg,
    encoder_type: str,
    rep_positions_q=0,
    rep_positions_c=0,
    loss_scale: float = None,
    clustering: bool = False,
) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos_q=rep_positions_q,
            representation_token_pos_c=rep_positions_c,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos_q=rep_positions_q,
                representation_token_pos_c=rep_positions_c,
            )

    local_q_vector, local_ctx_vectors = model_out

    if cfg.others.is_matching:  # MatchBiEncoder model
        loss, ml_is_correct, matching_is_correct = _calc_loss_matching(
            cfg,
            model,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        ml_is_correct = ml_is_correct.sum().item()
        matching_is_correct = matching_is_correct.sum().item()
    else:
        loss, is_correct = calc_loss(
            cfg,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps

    if clustering:
        assert not cfg.others.is_matching
        return loss, is_correct, model_out
    elif cfg.others.is_matching:
        return loss, ml_is_correct, matching_is_correct
    else:
        return loss, is_correct
def validate(self):
    logger.info("Validation ...")
    cfg = self.cfg
    self.reader.eval()
    if self.dev_iterator is None:
        self.dev_iterator = self.get_data_iterator(cfg.dev_files, cfg.train.dev_batch_size, False, shuffle=False)

    log_result_step = cfg.train.log_batch_step // 4  # validation needs to be more verbose
    all_results = []

    eval_top_docs = cfg.eval_top_docs
    for i, samples_batch in enumerate(self.dev_iterator.iterate_ds_data()):
        input = create_reader_input(
            self.wiki_data,
            self.tensorizer,
            samples_batch,
            cfg.passages_per_question_predict,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=False,
            shuffle=False,
        )

        input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask)

        batch_predictions = get_best_prediction(
            self.cfg.max_answer_length,
            self.tensorizer,
            start_logits,
            end_logits,
            relevance_logits,
            samples_batch,
            passage_thresholds=eval_top_docs,
        )

        all_results.extend(batch_predictions)

        if (i + 1) % log_result_step == 0:
            logger.info("Eval step: %d ", i)

    ems = defaultdict(list)
    f1s = defaultdict(list)

    for q_predictions in all_results:
        gold_answers = q_predictions.gold_answers
        span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
        for (n, span_prediction) in span_predictions.items():
            # Exact match
            em_hit = max([exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers])
            ems[n].append(em_hit)

            # F1 score
            f1_hit = max([f1_score(span_prediction.prediction_text, ga) for ga in gold_answers])
            f1s[n].append(f1_hit)

    # Sync between GPUs
    ems, f1s = gather(self.cfg, [ems, f1s])

    em = 0
    for n in sorted(ems[0].keys()):
        ems_n = sum([em_i[n] for em_i in ems], [])  # gather and concatenate
        em = np.mean(ems_n)
        logger.info("n=%d\tEM %.2f" % (n, em * 100))

    for n in sorted(f1s[0].keys()):
        f1s_n = sum([f1_i[n] for f1_i in f1s], [])  # gather and concatenate
        f1 = np.mean(f1s_n)
        logger.info("n=%d\tF1 %.2f" % (n, f1 * 100))

    if cfg.prediction_results_file:
        self._save_predictions(cfg.prediction_results_file, all_results)

    return em
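
# A small self-contained sketch (values illustrative, not from the original
# source) of the post-gather aggregation above: each worker contributes
# {top-k threshold -> per-question EM hits}, and the per-threshold lists are
# concatenated before averaging.
import numpy as np

gathered = [{10: [1.0, 0.0]}, {10: [1.0, 1.0]}]  # results from two workers
hits_10 = sum([d[10] for d in gathered], [])     # -> [1.0, 0.0, 1.0, 1.0]
assert np.mean(hits_10) == 0.75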
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    from dpr.data.biencoder_data import DEFAULT_SELECTOR

    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Passage preprocessing
        # TODO: max seq length check
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1].text, title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]
        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device)

        # Question preprocessing
        q_batch = q_rows[batch_start : batch_start + bsz]
        q_batch_token_tensors = [tensorizer.text_to_tensor(qq) for qq in q_batch]
        q_ids_batch = move_to_device(torch.stack(q_batch_token_tensors, dim=0), cfg.device)
        q_seg_batch = move_to_device(torch.zeros_like(q_ids_batch), cfg.device)
        q_attn_mask = move_to_device(tensorizer.get_attn_mask(q_ids_batch), cfg.device)

        # Representation token selector
        selector = DEFAULT_SELECTOR
        rep_positions = selector.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()

        ctx_ids = [r[0] for r in batch]

        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)
        total += len(ctx_ids)

        # each result carries the question vector, the context vector, and their
        # dot-product similarity score
        results.extend(
            [
                (
                    ctx_ids[i],
                    q_dense[i].numpy(),
                    ctx_dense[i].numpy(),
                    q_dense[i].numpy().dot(ctx_dense[i].numpy()),
                )
                for i in range(q_dense.size(0))
            ]
        )

        if total % 10 == 0:
            logger.info("Encoded questions / passages %d", total)

    return results
def do_ofa_fwd_pass(
    trainer,
    mode: str,
    backward: bool,  # whether to backpropagate the loss
    step: bool,  # whether to perform `optimizer.step()`
    biencoder_input: BiEncoderBatch,
    biencoder_config: BiEncoderTrainingConfig,
    reader_inputs: List[ReaderBatch],
    reader_config: ReaderTrainingConfig,
    inference_only: bool = False,
) -> Union[
    ForwardPassOutputsTrain,
    BiEncoderPredictionBatch,
    List[ReaderPredictionBatch],
    Tuple[BiEncoderPredictionBatch, List[ReaderPredictionBatch]],
]:
    """
    Note: if `inference_only` is set to True:
        1. No loss is computed.
        2. No backward pass is performed.
        3. All predictions are moved to CPU to save memory.
    """
    assert mode in ["retriever", "reader", "both"], f"Invalid mode: {mode}"
    if inference_only:
        assert (not backward) and (not step) and (not trainer.model.training)

    biencoder_is_correct = None
    biencoder_preds = None
    reader_input_tot = None
    reader_preds_tot = None

    # Forward pass and backward pass for the biencoder
    if mode in ["retriever", "both"]:
        biencoder_input = BiEncoderBatch(**move_to_device(biencoder_input._asdict(), trainer.cfg.device))

        if trainer.model.training:
            biencoder_preds: BiEncoderPredictionBatch = trainer.model(
                mode="retriever",
                biencoder_batch=biencoder_input,
                biencoder_config=biencoder_config,
                reader_batch=None,
                reader_config=None,
            )
        else:
            with torch.no_grad():
                biencoder_preds: BiEncoderPredictionBatch = trainer.model(
                    mode="retriever",
                    biencoder_batch=biencoder_input,
                    biencoder_config=biencoder_config,
                    reader_batch=None,
                    reader_config=None,
                )

        if not inference_only:
            # Calculate the biencoder loss
            biencoder_loss, biencoder_is_correct = calc_loss_biencoder(
                cfg=trainer.cfg,
                loss_function=trainer.biencoder_loss_function,
                local_q_vector=biencoder_preds.question_vector,
                local_ctx_vectors=biencoder_preds.context_vector,
                local_positive_idxs=biencoder_input.is_positive,
                local_hard_negatives_idxs=biencoder_input.hard_negatives,
                loss_scale=None,
            )
            biencoder_is_correct = biencoder_is_correct.sum().item()
            biencoder_input = BiEncoderBatch(**move_to_device(biencoder_input._asdict(), "cpu"))

            # Re-calibrate the loss
            if trainer.cfg.n_gpu > 1:
                biencoder_loss = biencoder_loss.mean()
            if trainer.cfg.train.gradient_accumulation_steps > 1:
                biencoder_loss = biencoder_loss / trainer.cfg.train.gradient_accumulation_steps

            if backward:
                assert trainer.model.training, "Model is not in training mode!"
                trainer.backward(
                    loss=biencoder_loss,
                    optimizer=trainer.biencoder_optimizer,
                    scheduler=trainer.biencoder_scheduler,
                    step=step,
                )
        else:
            biencoder_input = BiEncoderBatch(**move_to_device(biencoder_input._asdict(), "cpu"))
            biencoder_preds = BiEncoderPredictionBatch(**move_to_device(biencoder_preds._asdict(), "cpu"))

    # Forward and backward pass for the reader
    if mode in ["reader", "both"]:
        reader_total_loss = 0
        reader_input_tot: List[ReaderBatch] = []
        reader_preds_tot: List[ReaderPredictionBatch] = []

        for reader_input in reader_inputs:
            reader_input = ReaderBatch(**move_to_device(reader_input._asdict(), trainer.cfg.device))

            if trainer.model.training:
                reader_preds: ReaderPredictionBatch = trainer.model(
                    mode="reader",
                    biencoder_batch=None,
                    biencoder_config=None,
                    reader_batch=reader_input,
                    reader_config=reader_config,
                )
                reader_loss = reader_preds.total_loss / len(reader_inputs)  # scale by the number of sub-batches
                reader_total_loss += reader_loss

                # Re-calibrate the loss
                if trainer.cfg.n_gpu > 1:
                    reader_loss = reader_loss.mean()
                if trainer.cfg.train.gradient_accumulation_steps > 1:
                    reader_loss = reader_loss / trainer.cfg.train.gradient_accumulation_steps

                if backward:
                    assert trainer.model.training, "Model is not in training mode!"
                    trainer.backward(
                        loss=reader_loss,
                        optimizer=trainer.reader_optimizer,
                        scheduler=trainer.reader_scheduler,
                        step=step,
                    )
            else:
                with torch.no_grad():
                    reader_preds: ReaderPredictionBatch = trainer.model(
                        mode="reader",
                        biencoder_batch=None,
                        biencoder_config=None,
                        reader_batch=reader_input,
                        reader_config=reader_config,
                    )

                if not inference_only:
                    questions_num, passages_per_question, _ = reader_input.input_ids.size()
                    reader_total_loss = calc_loss_reader(
                        start_positions=reader_input.start_positions,
                        end_positions=reader_input.end_positions,
                        answers_mask=reader_input.answers_mask,
                        start_logits=reader_preds.start_logits,
                        end_logits=reader_preds.end_logits,
                        relevance_logits=reader_preds.relevance_logits,
                        N=questions_num,
                        M=passages_per_question,
                        use_simple_loss=reader_config.use_simple_loss,
                        average=reader_config.average_loss,
                    )

            reader_input = ReaderBatch(**move_to_device(reader_input._asdict(), "cpu"))
            reader_input_tot.append(reader_input)
            reader_preds = ReaderPredictionBatch(**move_to_device(reader_preds._asdict(), "cpu"))
            reader_preds_tot.append(reader_preds)

    if inference_only:
        if mode == "retriever":
            return biencoder_preds
        elif mode == "reader":
            return reader_preds_tot
        else:
            return biencoder_preds, reader_preds_tot
    else:
        # Total loss; for now use 1:1 weights
        if mode == "retriever":
            loss = biencoder_loss
        elif mode == "reader":
            loss = reader_total_loss
        else:
            loss = biencoder_loss + reader_total_loss

        outputs = ForwardPassOutputsTrain(
            loss=loss,
            biencoder_is_correct=biencoder_is_correct,
            biencoder_input=biencoder_input,
            biencoder_preds=biencoder_preds,
            reader_input=reader_input_tot,
            reader_preds=reader_preds_tot,
        )
        return outputs
def validate(self):
    logger.info('Validation ...')
    args = self.args
    self.reader.eval()
    data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, False, shuffle=False)

    log_result_step = args.log_batch_step
    all_results = []

    eval_top_docs = args.eval_top_docs
    for i, samples_batch in enumerate(data_iterator.iterate_data()):
        input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            args.passages_per_question_predict,
            args.sequence_length,
            args.max_n_answers,
            is_train=False,
            shuffle=False,
        )

        input = ReaderBatch(**move_to_device(input._asdict(), args.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask)

        batch_predictions = self._get_best_prediction(
            start_logits,
            end_logits,
            relevance_logits,
            samples_batch,
            passage_thresholds=eval_top_docs,
        )

        all_results.extend(batch_predictions)

        if (i + 1) % log_result_step == 0:
            logger.info('Eval step: %d ', i)

    if args.prediction_results_file:
        self._save_predictions(args.prediction_results_file, all_results)

    em = 0  # exact match
    cm = 0  # char match
    rouge_scorer = Rouge()
    bleu_scorer = Bleu()

    if not args.test_only:
        ems = defaultdict(list)
        cms = defaultdict(list)
        gts = defaultdict(list)
        preds = defaultdict(list)
        top1 = defaultdict(list)

        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
            for (n, span_prediction) in span_predictions.items():
                em_hit = max([exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers])
                cm_hit = max([char_match_score(span_prediction.prediction_text, ga) for ga in gold_answers])
                ems[n].append(em_hit)
                cms[n].append(cm_hit)

                # for bleu/rouge later
                gts[n].append(gold_answers)
                preds[n].append(span_prediction.prediction_text)

                # top-1 for qa_classify
                has_answer = q_predictions.passages_has_answer[span_prediction.passage_index]
                top1[n].append(float(has_answer))

        for n in sorted(ems.keys()):
            em = np.mean(ems[n])
            cm = np.mean(cms[n])
            bleu = bleu_scorer.compute_score(gts[n], preds[n])
            rouge = rouge_scorer.compute_score(gts[n], preds[n])
            t1 = np.mean(top1[n])
            mean_score = (em + cm) / 2
            logger.info(
                "n=%d\tEM %.2f\tCM %.2f\tScore %.2f\tTop-1 %.2f\n"
                % (n, em * 100, cm * 100, mean_score * 100, t1 * 100)
            )
            # logger.info("n=%d\tEM %.2f\tCM %.2f\tRouge-L %.2f\tBLEU-4 %.2f\tTop-1 %.2f\n"
            #             % (n, em * 100, cm * 100, rouge * 100, bleu * 100, t1 * 100))

    return em