Esempio n. 1
0
def create_ofa_input(
    mode: str,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[Tuple[BiEncoderSampleTokenized, ReaderSample]],
    biencoder_config: BiEncoderDataConfig,
    reader_config: ReaderDataConfig,
) -> Union[BiEncoderBatch, List[ReaderBatch], Tuple[BiEncoderBatch,
                                                    List[ReaderBatch]], ]:
    """Build model inputs for the retriever, the reader, or both.

    Each element of `samples` pairs a tokenized bi-encoder sample with its
    corresponding reader sample; the two streams are separated and handed to
    their respective batch builders.

    Returns the bi-encoder batch (mode "retriever"), the list of reader
    sub-batches (mode "reader"), or a (biencoder_batch, reader_batches)
    tuple (mode "both").
    """
    assert mode in ["retriever", "reader", "both"], f"Invalid mode: {mode}"
    retriever_samples, reader_samples = zip(*samples)

    # Retriever (bi-encoder) input
    biencoder_batch = None
    if mode != "reader":
        biencoder_batch = BiEncoder.create_biencoder_input(
            samples=retriever_samples,
            tensorizer=tensorizer,
            insert_title=biencoder_config.insert_title,
            num_hard_negatives=biencoder_config.num_hard_negatives,
            num_other_negatives=biencoder_config.num_other_negatives,
            shuffle=biencoder_config.shuffle,
            shuffle_positives=biencoder_config.shuffle_positives,
            hard_neg_fallback=biencoder_config.hard_neg_fallback,
            query_token=biencoder_config.query_token,
        )

    # Reader input, chunked into `num_sub_batches` sub-batches
    reader_batches: List[ReaderBatch] = []
    if mode != "retriever":
        num_samples = len(samples)
        num_sub_batches = reader_config.num_sub_batches
        assert num_sub_batches > 0
        sub_batch_size = math.ceil(num_samples / num_sub_batches)

        for sub_i in range(num_sub_batches):
            lo = sub_i * sub_batch_size
            if lo >= num_samples:
                # ceil rounding can leave trailing empty sub-batches; stop early
                break
            hi = min(lo + sub_batch_size, num_samples)

            reader_batches.append(create_reader_input(
                wiki_data=wiki_data,
                tensorizer=tensorizer,
                samples=reader_samples[lo:hi],
                passages_per_question=reader_config.passages_per_question,
                max_length=reader_config.max_length,
                max_n_answers=reader_config.max_n_answers,
                is_train=reader_config.is_train,
                shuffle=reader_config.shuffle,
            ))

    if mode == "retriever":
        return biencoder_batch
    if mode == "reader":
        return reader_batches
    return biencoder_batch, reader_batches
Esempio n. 2
0
    def validate(self):
        """Evaluate the reader on the dev file.

        Computes the exact-match (EM) score per top-docs threshold, logs each
        one, optionally saves predictions, and returns the EM of the largest
        threshold evaluated (0 if there were no results).
        """
        logger.info('Validation ...')
        args = self.args
        self.reader.eval()
        dev_iterator = self.get_data_iterator(
            args.dev_file, args.dev_batch_size, False, shuffle=False)

        log_result_step = args.log_batch_step
        all_results = []
        eval_top_docs = args.eval_top_docs

        for step, samples_batch in enumerate(dev_iterator.iterate_data()):
            batch = create_reader_input(self.tensorizer.get_pad_id(),
                                        samples_batch,
                                        args.passages_per_question_predict,
                                        args.sequence_length,
                                        args.max_n_answers,
                                        is_train=False, shuffle=False)

            batch = ReaderBatch(**move_to_device(batch._asdict(), args.device))
            attn_mask = self.tensorizer.get_attn_mask(batch.input_ids)

            # inference only — no gradients needed
            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(
                    batch.input_ids, attn_mask)

            all_results.extend(
                self._get_best_prediction(start_logits, end_logits,
                                          relevance_logits, samples_batch,
                                          passage_thresholds=eval_top_docs))

            if (step + 1) % log_result_step == 0:
                logger.info('Eval step: %d ', step)

        # {top docs threshold -> list of per-question best EM hits}
        ems = defaultdict(list)
        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            for n, span_prediction in q_predictions.predictions.items():
                hits = [exact_match_score(span_prediction.prediction_text, ga)
                        for ga in gold_answers]
                ems[n].append(max(hits))

        em = 0
        for n in sorted(ems):
            em = np.mean(ems[n])
            logger.info("n=%d\tEM %.2f" % (n, em * 100))

        if args.prediction_results_file:
            self._save_predictions(args.prediction_results_file, all_results)

        return em
Esempio n. 3
0
    def _train_epoch(
        self,
        scheduler,
        epoch: int,
        eval_step: int,
        train_data_iterator: ShardedDataIterator,
        global_step: int,
    ):
        """Train the reader for one epoch.

        Args:
            scheduler: learning-rate scheduler, stepped after every optimizer
                step.
            epoch: current epoch index, used for logging and passed to the
                data iterator (presumably to control per-epoch shuffling —
                verify against the iterator implementation).
            eval_step: run validation every `eval_step` global steps.
            train_data_iterator: sharded iterator yielding training batches.
            global_step: running optimizer-step count carried across epochs.

        Returns:
            The updated global_step at the end of the epoch.
        """
        cfg = self.cfg
        rolling_train_loss = 0.0
        epoch_loss = 0
        log_result_step = cfg.train.log_batch_step
        rolling_loss_step = cfg.train.train_rolling_loss_step

        self.reader.train()
        epoch_batches = train_data_iterator.max_iterations

        for i, samples_batch in enumerate(
                train_data_iterator.iterate_ds_data(epoch=epoch)):

            data_iteration = train_data_iterator.get_iteration()

            # enables to resume to exactly same train state:
            # reseed all RNGs from the global step so a restarted run
            # reproduces the same shuffles/dropout from this point on
            if cfg.fully_resumable:
                np.random.seed(cfg.seed + global_step)
                torch.manual_seed(cfg.seed + global_step)
                if cfg.n_gpu > 0:
                    torch.cuda.manual_seed_all(cfg.seed + global_step)

            input = create_reader_input(
                self.tensorizer.get_pad_id(),
                samples_batch,
                cfg.passages_per_question,
                cfg.encoder.sequence_length,
                cfg.max_n_answers,
                is_train=True,
                shuffle=True,
            )

            loss = self._calc_loss(input)

            epoch_loss += loss.item()
            rolling_train_loss += loss.item()

            max_grad_norm = cfg.train.max_grad_norm
            if cfg.fp16:
                # apex mixed-precision path: scale the loss before backward,
                # then clip the fp32 master copies of the parameters
                from apex import amp

                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                if max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), max_grad_norm)
            else:
                loss.backward()
                if max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.reader.parameters(),
                                                   max_grad_norm)

            # gradient accumulation: only step/zero every
            # `gradient_accumulation_steps` batches
            if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
                self.optimizer.step()
                scheduler.step()
                self.reader.zero_grad()
                global_step += 1

            if i % log_result_step == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    "Epoch: %d: Step: %d/%d, global_step=%d, lr=%f",
                    epoch,
                    data_iteration,
                    epoch_batches,
                    global_step,
                    lr,
                )

            if (i + 1) % rolling_loss_step == 0:
                logger.info("Train batch %d", data_iteration)
                latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
                logger.info(
                    "Avg. loss per last %d batches: %f",
                    rolling_loss_step,
                    latest_rolling_train_av_loss,
                )
                rolling_train_loss = 0.0

            if global_step % eval_step == 0:
                logger.info(
                    "Validation: Epoch: %d Step: %d/%d",
                    epoch,
                    data_iteration,
                    epoch_batches,
                )
                self.validate_and_save(epoch,
                                       train_data_iterator.get_iteration(),
                                       scheduler)
                # NOTE(review): validate_and_save presumably switches the
                # model to eval mode — restore training mode here
                self.reader.train()

        epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
        logger.info("Av Loss per epoch=%f", epoch_loss)
        return global_step
Esempio n. 4
0
    def validate(self):
        """Evaluate the reader on the dev files.

        Computes exact-match (EM) and F1 per top-docs threshold, syncs the
        per-process metric dicts across GPUs via `gather`, logs every
        threshold's scores, optionally saves predictions, and returns the EM
        of the largest threshold evaluated (0 if there were none).
        """
        logger.info("Validation ...")
        cfg = self.cfg
        self.reader.eval()
        # lazily create the dev iterator once and reuse it on later calls
        if self.dev_iterator is None:
            self.dev_iterator = self.get_data_iterator(
                cfg.dev_files, cfg.train.dev_batch_size, False, shuffle=False)

        log_result_step = cfg.train.log_batch_step // 4  # validation needs to be more verbose
        all_results = []

        eval_top_docs = cfg.eval_top_docs
        for i, samples_batch in enumerate(self.dev_iterator.iterate_ds_data()):
            input = create_reader_input(
                self.wiki_data,
                self.tensorizer,
                samples_batch,
                cfg.passages_per_question_predict,
                cfg.encoder.sequence_length,
                cfg.max_n_answers,
                is_train=False,
                shuffle=False,
            )

            input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
            attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

            # inference only — no gradients needed
            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(
                    input.input_ids, attn_mask)

            batch_predictions = get_best_prediction(
                self.cfg.max_answer_length,
                self.tensorizer,
                start_logits,
                end_logits,
                relevance_logits,
                samples_batch,
                passage_thresholds=eval_top_docs,
            )

            all_results.extend(batch_predictions)

            if (i + 1) % log_result_step == 0:
                logger.info("Eval step: %d ", i)

        # per-threshold metric accumulators: {top-docs threshold n -> scores}
        ems = defaultdict(list)
        f1s = defaultdict(list)

        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            span_predictions = (q_predictions.predictions
                                )  # {top docs threshold -> SpanPrediction()}
            for (n, span_prediction) in span_predictions.items():
                # Exact match: best hit over all gold answers
                em_hit = max([
                    exact_match_score(span_prediction.prediction_text, ga)
                    for ga in gold_answers
                ])
                ems[n].append(em_hit)

                # F1 score: best over all gold answers
                f1_hit = max([
                    f1_score(span_prediction.prediction_text, ga)
                    for ga in gold_answers
                ])
                f1s[n].append(f1_hit)

        # Sync between GPUs; afterwards ems/f1s are lists with one dict per process
        ems, f1s = gather(self.cfg, [ems, f1s])

        em = 0
        for n in sorted(ems[0].keys()):
            # NOTE: the `em` variable inside the comprehension below is scoped
            # to the comprehension and does not clobber the outer `em`
            ems_n = sum([em[n] for em in ems], [])  # gather and concatenate
            em = np.mean(ems_n)
            logger.info("n=%d\tEM %.2f" % (n, em * 100))

        for n in sorted(f1s[0].keys()):
            f1s_n = sum([f1[n] for f1 in f1s], [])  # gather and concatenate
            f1 = np.mean(f1s_n)
            logger.info("n=%d\tF1 %.2f" % (n, f1 * 100))

        if cfg.prediction_results_file:
            self._save_predictions(cfg.prediction_results_file, all_results)

        return em
Esempio n. 5
0
    def validate(self):
        """Evaluate the reader on the dev file.

        Saves predictions when configured, then (unless args.test_only) scores
        them per top-docs threshold: exact match, char match, their mean, and
        top-1 passage-selection accuracy. Returns the EM of the largest
        threshold evaluated (0 when test_only or when there were no results).
        """
        logger.info('Validation ...')
        args = self.args
        self.reader.eval()
        dev_iterator = self.get_data_iterator(args.dev_file,
                                              args.dev_batch_size,
                                              False,
                                              shuffle=False)

        log_result_step = args.log_batch_step
        all_results = []
        eval_top_docs = args.eval_top_docs

        for step, samples_batch in enumerate(dev_iterator.iterate_data()):
            batch = create_reader_input(self.tensorizer.get_pad_id(),
                                        samples_batch,
                                        args.passages_per_question_predict,
                                        args.sequence_length,
                                        args.max_n_answers,
                                        is_train=False,
                                        shuffle=False)

            batch = ReaderBatch(**move_to_device(batch._asdict(), args.device))
            attn_mask = self.tensorizer.get_attn_mask(batch.input_ids)

            # inference only — no gradients needed
            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(
                    batch.input_ids, attn_mask)

            all_results.extend(
                self._get_best_prediction(start_logits,
                                          end_logits,
                                          relevance_logits,
                                          samples_batch,
                                          passage_thresholds=eval_top_docs))

            if (step + 1) % log_result_step == 0:
                logger.info('Eval step: %d ', step)

        if args.prediction_results_file:
            self._save_predictions(args.prediction_results_file, all_results)

        em = 0  # exact match
        cm = 0  # char match
        rouge_scorer = Rouge()
        bleu_scorer = Bleu()
        if not args.test_only:
            # per-threshold accumulators: {top docs threshold n -> list}
            ems = defaultdict(list)
            cms = defaultdict(list)
            gts = defaultdict(list)
            preds = defaultdict(list)
            top1 = defaultdict(list)

            for q_predictions in all_results:
                gold_answers = q_predictions.gold_answers
                # {top docs threshold -> SpanPrediction()}
                for n, span_prediction in q_predictions.predictions.items():
                    text = span_prediction.prediction_text
                    ems[n].append(
                        max(exact_match_score(text, ga) for ga in gold_answers))
                    cms[n].append(
                        max(char_match_score(text, ga) for ga in gold_answers))
                    # kept for bleu/rouge scoring below
                    gts[n].append(gold_answers)
                    preds[n].append(text)
                    # top-1: did the selected passage actually contain an answer
                    has_answer = q_predictions.passages_has_answer[
                        span_prediction.passage_index]
                    top1[n].append(float(has_answer))

            for n in sorted(ems):
                em = np.mean(ems[n])
                cm = np.mean(cms[n])
                bleu = bleu_scorer.compute_score(gts[n], preds[n])
                rouge = rouge_scorer.compute_score(gts[n], preds[n])
                t1 = np.mean(top1[n])
                mean_score = (em + cm) / 2
                logger.info(
                    "n=%d\tEM %.2f\tCM %.2f\tScore %.2f\tTop-1 %.2f\n" %
                    (n, em * 100, cm * 100, mean_score * 100, t1 * 100))

        return em