def create_ofa_input(
    mode: str,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[Tuple[BiEncoderSampleTokenized, ReaderSample]],
    biencoder_config: BiEncoderDataConfig,
    reader_config: ReaderDataConfig,
) -> Union[
    BiEncoderBatch,
    List[ReaderBatch],
    Tuple[BiEncoderBatch, List[ReaderBatch]],
]:
    """Build model inputs for the retriever, the reader, or both.

    :param mode: one of "retriever", "reader", "both"; selects which
        batch(es) to build and what is returned.
    :param wiki_data: tokenized passage store handed to the reader input builder.
    :param tensorizer: text-to-tensor converter shared by both models.
    :param samples: paired (retriever sample, reader sample) tuples.
    :param biencoder_config: options forwarded to the bi-encoder batcher.
    :param reader_config: options forwarded to the reader batcher, including
        the number of sub-batches to split the reader input into.
    :return: a BiEncoderBatch, a list of ReaderBatch, or a tuple of both,
        depending on `mode`.
    """
    assert mode in ["retriever", "reader", "both"], f"Invalid mode: {mode}"
    retriever_samples, reader_samples = zip(*samples)

    build_retriever = mode in ["retriever", "both"]
    build_reader = mode in ["reader", "both"]

    if build_retriever:
        # One bi-encoder batch covering every retriever sample.
        biencoder_batch = BiEncoder.create_biencoder_input(
            samples=retriever_samples,
            tensorizer=tensorizer,
            insert_title=biencoder_config.insert_title,
            num_hard_negatives=biencoder_config.num_hard_negatives,
            num_other_negatives=biencoder_config.num_other_negatives,
            shuffle=biencoder_config.shuffle,
            shuffle_positives=biencoder_config.shuffle_positives,
            hard_neg_fallback=biencoder_config.hard_neg_fallback,
            query_token=biencoder_config.query_token,
        )

    if build_reader:
        total = len(samples)
        num_sub_batches = reader_config.num_sub_batches
        assert num_sub_batches > 0
        # Ceil-divide so every sample lands in exactly one sub-batch; the
        # max(1, ...) guard keeps range() valid even for an empty sample list.
        chunk = max(1, math.ceil(total / num_sub_batches))

        reader_batches: List[ReaderBatch] = []
        for start in range(0, total, chunk):
            end = min(start + chunk, total)
            reader_batches.append(
                create_reader_input(
                    wiki_data=wiki_data,
                    tensorizer=tensorizer,
                    samples=reader_samples[start:end],
                    passages_per_question=reader_config.passages_per_question,
                    max_length=reader_config.max_length,
                    max_n_answers=reader_config.max_n_answers,
                    is_train=reader_config.is_train,
                    shuffle=reader_config.shuffle,
                )
            )

    if mode == "retriever":
        return biencoder_batch
    if mode == "reader":
        return reader_batches
    return biencoder_batch, reader_batches
def validate(self):
    """Evaluate the reader on the dev set and log exact-match accuracy.

    Runs inference over the dev iterator, collects best-span predictions per
    top-docs threshold, and logs EM per threshold. Optionally saves the raw
    predictions. Returns the EM for the last (largest) threshold processed.
    """
    logger.info('Validation ...')
    args = self.args
    self.reader.eval()
    data_iterator = self.get_data_iterator(
        args.dev_file, args.dev_batch_size, False, shuffle=False)
    log_every = args.log_batch_step
    top_docs = args.eval_top_docs

    all_results = []
    for step, samples_batch in enumerate(data_iterator.iterate_data()):
        reader_input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            args.passages_per_question_predict,
            args.sequence_length,
            args.max_n_answers,
            is_train=False,
            shuffle=False,
        )
        reader_input = ReaderBatch(
            **move_to_device(reader_input._asdict(), args.device))
        attn_mask = self.tensorizer.get_attn_mask(reader_input.input_ids)

        # Inference only — no gradients needed.
        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(
                reader_input.input_ids, attn_mask)

        all_results.extend(
            self._get_best_prediction(
                start_logits,
                end_logits,
                relevance_logits,
                samples_batch,
                passage_thresholds=top_docs,
            ))

        if (step + 1) % log_every == 0:
            logger.info('Eval step: %d ', step)

    # Bucket EM hits by top-docs threshold n.
    ems = defaultdict(list)
    for q_predictions in all_results:
        gold_answers = q_predictions.gold_answers
        # predictions: {top docs threshold -> SpanPrediction()}
        for n, span in q_predictions.predictions.items():
            ems[n].append(
                max(exact_match_score(span.prediction_text, ga)
                    for ga in gold_answers))

    em = 0
    for n in sorted(ems):
        em = np.mean(ems[n])
        logger.info("n=%d\tEM %.2f" % (n, em * 100))

    if args.prediction_results_file:
        self._save_predictions(args.prediction_results_file, all_results)
    return em
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
    global_step: int,
) -> int:
    """Train the reader for one epoch over `train_data_iterator`.

    Handles optional fp16 (apex amp) backward, gradient clipping, gradient
    accumulation, periodic logging, and periodic validation/checkpointing.
    Returns the updated `global_step` (incremented once per optimizer step,
    not per batch).
    """
    cfg = self.cfg
    rolling_train_loss = 0.0
    epoch_loss = 0
    log_result_step = cfg.train.log_batch_step
    rolling_loss_step = cfg.train.train_rolling_loss_step

    self.reader.train()
    epoch_batches = train_data_iterator.max_iterations

    for i, samples_batch in enumerate(
            train_data_iterator.iterate_ds_data(epoch=epoch)):

        data_iteration = train_data_iterator.get_iteration()

        # enables to resume to exactly same train state: re-seed all RNGs
        # from (seed + global_step) so restarted runs replay identically
        if cfg.fully_resumable:
            np.random.seed(cfg.seed + global_step)
            torch.manual_seed(cfg.seed + global_step)
            if cfg.n_gpu > 0:
                torch.cuda.manual_seed_all(cfg.seed + global_step)

        # Build the training batch (passages sampled/shuffled per question).
        input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            cfg.passages_per_question,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=True,
            shuffle=True,
        )

        loss = self._calc_loss(input)

        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        max_grad_norm = cfg.train.max_grad_norm
        if cfg.fp16:
            # apex is imported lazily so non-fp16 runs don't require it
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if max_grad_norm > 0:
                # clip the fp32 master copies of the params, not the fp16 ones
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), max_grad_norm)
        else:
            loss.backward()
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.reader.parameters(),
                                               max_grad_norm)

        # Step/zero only every `gradient_accumulation_steps` batches;
        # global_step therefore counts optimizer steps, not batches.
        if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.reader.zero_grad()
            global_step += 1

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, global_step=%d, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                global_step,
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if global_step % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(epoch, train_data_iterator.get_iteration(),
                                   scheduler)
            # validate_and_save switches the model to eval mode; restore train
            self.reader.train()

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    return global_step
def validate(self):
    """Evaluate the reader on the dev set, logging EM and F1 per threshold.

    Lazily creates the dev iterator on first call, runs inference batch by
    batch, scores each question's best span against its gold answers, and
    synchronizes per-threshold scores across GPUs before averaging.
    Returns the EM for the last (largest) top-docs threshold processed.
    """
    logger.info("Validation ...")
    cfg = self.cfg
    self.reader.eval()
    # Lazily construct the dev iterator and reuse it across validations.
    if self.dev_iterator is None:
        self.dev_iterator = self.get_data_iterator(
            cfg.dev_files, cfg.train.dev_batch_size, False, shuffle=False)
    log_result_step = cfg.train.log_batch_step // 4  # validation needs to be more verbose

    all_results = []
    eval_top_docs = cfg.eval_top_docs
    for i, samples_batch in enumerate(self.dev_iterator.iterate_ds_data()):
        input = create_reader_input(
            self.wiki_data,
            self.tensorizer,
            samples_batch,
            cfg.passages_per_question_predict,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=False,
            shuffle=False,
        )
        input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

        # Inference only — no gradients needed.
        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(
                input.input_ids, attn_mask)

        batch_predictions = get_best_prediction(
            self.cfg.max_answer_length,
            self.tensorizer,
            start_logits,
            end_logits,
            relevance_logits,
            samples_batch,
            passage_thresholds=eval_top_docs,
        )
        all_results.extend(batch_predictions)

        if (i + 1) % log_result_step == 0:
            logger.info("Eval step: %d ", i)

    # Per-threshold score buckets: threshold n -> list of per-question hits.
    ems = defaultdict(list)
    f1s = defaultdict(list)

    for q_predictions in all_results:
        gold_answers = q_predictions.gold_answers
        span_predictions = (q_predictions.predictions
                            )  # {top docs threshold -> SpanPrediction()}
        for (n, span_prediction) in span_predictions.items():
            # Exact match: best over all gold answers for this question.
            em_hit = max([
                exact_match_score(span_prediction.prediction_text, ga)
                for ga in gold_answers
            ])
            ems[n].append(em_hit)

            # F1 score: likewise best over all gold answers.
            f1_hit = max([
                f1_score(span_prediction.prediction_text, ga)
                for ga in gold_answers
            ])
            f1s[n].append(f1_hit)

    # Sync between GPUs: after gather, `ems`/`f1s` are lists of per-rank dicts.
    ems, f1s = gather(self.cfg, [ems, f1s])

    em = 0
    for n in sorted(ems[0].keys()):
        # NOTE: the comprehension variable `em` shadows the outer `em`; the
        # outer value is overwritten by np.mean right after, so the returned
        # `em` is the mean for the largest n.
        ems_n = sum([em[n] for em in ems], [])  # gather and concatenate
        em = np.mean(ems_n)
        logger.info("n=%d\tEM %.2f" % (n, em * 100))
    for n in sorted(f1s[0].keys()):
        f1s_n = sum([f1[n] for f1 in f1s], [])  # gather and concatenate
        f1 = np.mean(f1s_n)
        logger.info("n=%d\tF1 %.2f" % (n, f1 * 100))

    if cfg.prediction_results_file:
        self._save_predictions(cfg.prediction_results_file, all_results)
    return em
def validate(self):
    """Evaluate the reader on the dev set with EM, char-match and Top-1 metrics.

    Runs inference, optionally saves predictions, then (unless
    `args.test_only`) scores each question's best span per top-docs
    threshold. Returns the EM for the last threshold processed (0 when
    scoring is skipped).
    """
    logger.info('Validation ...')
    args = self.args
    self.reader.eval()
    data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size,
                                           False, shuffle=False)
    log_result_step = args.log_batch_step
    all_results = []
    eval_top_docs = args.eval_top_docs
    for i, samples_batch in enumerate(data_iterator.iterate_data()):
        input = create_reader_input(self.tensorizer.get_pad_id(),
                                    samples_batch,
                                    args.passages_per_question_predict,
                                    args.sequence_length,
                                    args.max_n_answers,
                                    is_train=False,
                                    shuffle=False)
        input = ReaderBatch(**move_to_device(input._asdict(), args.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
        # Inference only — no gradients needed.
        with torch.no_grad():
            start_logits, end_logits, relevance_logits = self.reader(
                input.input_ids, attn_mask)
        batch_predictions = self._get_best_prediction(
            start_logits,
            end_logits,
            relevance_logits,
            samples_batch,
            passage_thresholds=eval_top_docs)
        all_results.extend(batch_predictions)
        if (i + 1) % log_result_step == 0:
            logger.info('Eval step: %d ', i)

    # Predictions are saved before scoring so they exist even in test-only mode.
    if args.prediction_results_file:
        self._save_predictions(args.prediction_results_file, all_results)

    em = 0  # exact match
    cm = 0  # char match
    # NOTE(review): both scorers are instantiated even when args.test_only
    # skips scoring; their outputs feed only the commented-out log line below.
    rouge_scorer = Rouge()
    bleu_scorer = Bleu()
    if not args.test_only:
        # Per-threshold buckets: threshold n -> list of per-question values.
        ems = defaultdict(list)
        cms = defaultdict(list)
        gts = defaultdict(list)
        preds = defaultdict(list)
        top1 = defaultdict(list)
        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
            for (n, span_prediction) in span_predictions.items():
                # Best score over all gold answers for this question.
                em_hit = max([
                    exact_match_score(span_prediction.prediction_text, ga)
                    for ga in gold_answers
                ])
                cm_hit = max([
                    char_match_score(span_prediction.prediction_text, ga)
                    for ga in gold_answers
                ])
                ems[n].append(em_hit)
                cms[n].append(cm_hit)
                # for bleu/rouge later
                gts[n].append(gold_answers)
                preds[n].append(span_prediction.prediction_text)
                # for qa_classify top1: whether the chosen passage has the answer
                has_answer = q_predictions.passages_has_answer[
                    span_prediction.passage_index]
                top1[n].append(float(has_answer))
        for n in sorted(ems.keys()):
            em = np.mean(ems[n])
            cm = np.mean(cms[n])
            # NOTE(review): bleu/rouge are computed but unused unless the
            # commented-out log line below is re-enabled.
            bleu = bleu_scorer.compute_score(gts[n], preds[n])
            rouge = rouge_scorer.compute_score(gts[n], preds[n])
            t1 = np.mean(top1[n])
            mean_score = (em + cm) / 2
            logger.info(
                "n=%d\tEM %.2f\tCM %.2f\tScore %.2f\tTop-1 %.2f\n" %
                (n, em * 100, cm * 100, mean_score * 100, t1 * 100))
            # logger.info("n=%d\tEM %.2f\tCM %.2f\tRouge-L %.2f\tBLEU-4 %.2f\tTop-1 %.2f\n" % (n, em * 100, cm * 100, rouge * 100, bleu * 100, t1 * 100))
    return em