Example #1
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        batch = ctx_rows[batch_start : batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None
            )
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(
            torch.stack(batch_token_tensors, dim=0), cfg.device
        )
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device)
        ctx_attn_mask = move_to_device(
            tensorizer.get_attn_mask(ctx_ids_batch), cfg.device
        )
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in batch]
        extra_info = []
        if len(batch[0]) > 3:
            extra_info = [r[3:] for r in batch]

        assert len(ctx_ids) == out.size(0)
        total += len(ctx_ids)

        # TODO: refactor to avoid 'if'
        if extra_info:
            results.extend(
                [
                    (ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i])
                    for i in range(out.size(0))
                ]
            )
        else:
            results.extend(
                [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]
            )

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)
    return results
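
The heart of this function is a simple pattern: slice the rows into fixed-size batches, stack the token tensors, run the encoder under torch.no_grad(), and collect (id, vector) pairs on CPU. Below is a minimal, self-contained sketch of that pattern; TinyEncoder and the random data are illustrative stand-ins, not part of DPR.

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Illustrative stand-in for the DPR context encoder."""
    def __init__(self, vocab_size: int = 100, dim: int = 8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Mean-pool token embeddings into one vector per passage.
        return self.emb(ids).mean(dim=1)

rows = [(i, torch.randint(0, 100, (16,))) for i in range(25)]  # (id, token ids)
model, bsz, results = TinyEncoder().eval(), 10, []
for batch_start in range(0, len(rows), bsz):
    batch = rows[batch_start : batch_start + bsz]
    ids_batch = torch.stack([r[1] for r in batch], dim=0)
    with torch.no_grad():  # inference only: no autograd graph is built
        out = model(ids_batch).cpu()
    results.extend((batch[i][0], out[i].numpy()) for i in range(out.size(0)))
print(len(results))  # 25 (id, vector) pairs
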
Example #2
def gen_ctx_vectors(
        ctx_rows: List[Tuple[object, str, str]],
        model: nn.Module,
        tensorizer: Tensorizer,
        insert_title: bool = True) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):

        all_txt = []
        for ctx in ctx_rows[batch_start:batch_start + bsz]:
            if ctx[2]:
                txt = ['title:', ctx[2], 'context:', ctx[1]]
            else:
                txt = ['context:', ctx[1]]
            txt = ' '.join(txt)
            all_txt.append(txt)
        batch_token_tensors = [
            tensorizer.text_to_tensor(txt, max_length=250) for txt in all_txt
        ]
        #batch_token_tensors = [tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None) for ctx in #original
        #                       ctx_rows[batch_start:batch_start + bsz]]                                             #original

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),
                                       args.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),
                                       args.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),
                                       args.device)
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start:batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        #results.extend([
        #    (ctx_ids[i], out[i].view(-1).numpy())
        #    for i in range(out.size(0))
        #])

        results.extend([(ctx_ids[i], out[i].numpy())
                        for i in range(out.size(0))])

        if total % 10 == 0:
            logger.info('Encoded passages %d', total)

    return results
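
The fork's change relative to the original DPR code (preserved in the #original comments above) is the text layout: instead of passing the title through the tensorizer's title keyword, it splices plain "title:" / "context:" markers into a single string. An illustrative sketch with a made-up row:

# Hypothetical (id, text, title) row as used in this example.
row = (0, "The tower is 324 metres tall.", "Eiffel Tower")

# Variant used above: plain-text markers, one string handed to the tensorizer.
parts = ["title:", row[2], "context:", row[1]] if row[2] else ["context:", row[1]]
txt = " ".join(parts)
print(txt)  # title: Eiffel Tower context: The tower is 324 metres tall.

# Original DPR variant (commented out above): the title is passed separately,
# letting the tensorizer insert its own separator between title and text:
# tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
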
Example #3
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    cfg,
    encoder_type: str,
    rep_positions=0,
    loss_scale: float = None,
) -> Tuple[torch.Tensor, int]:

    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos=rep_positions,
            )

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

    is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps
    return loss, is_correct
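
The two adjustments at the end are easy to miss: under nn.DataParallel the model returns one loss per GPU, so the losses are averaged into a scalar, and under gradient accumulation the loss is divided so the summed gradients match a single large batch. A tiny sketch with made-up numbers:

import torch

loss = torch.tensor([0.9, 1.1])  # one loss per GPU under nn.DataParallel
n_gpu, grad_accum_steps = 2, 4

if n_gpu > 1:
    loss = loss.mean()               # collapse per-GPU losses to a scalar
if grad_accum_steps > 1:
    loss = loss / grad_accum_steps   # keep the effective gradient magnitude
print(loss.item())  # 0.25
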
Example #4
    def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
        args = self.args
        input = ReaderBatch(**move_to_device(input._asdict(), args.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
        questions_num, passages_per_question, _ = input.input_ids.size()

        if self.reader.training:
            # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
            loss = self.reader(input.input_ids, attn_mask,
                               input.start_positions, input.end_positions,
                               input.answers_mask)

        else:
            # TODO: remove?
            with torch.no_grad():
                start_logits, end_logits, rank_logits = self.reader(
                    input.input_ids, attn_mask)

            loss = compute_loss(input.start_positions, input.end_positions,
                                input.answers_mask, start_logits, end_logits,
                                rank_logits, questions_num,
                                passages_per_question)
        if args.n_gpu > 1:
            loss = loss.mean()
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        return loss
Example #5
def _do_biencoder_fwd_pass(model: nn.Module, input: BiEncoderBatch,
                           tensorizer: Tensorizer,
                           args) -> Tuple[torch.Tensor, int]:
    input = BiEncoderBatch(**move_to_device(input._asdict(), args.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(input.question_ids, input.question_segments,
                          q_attn_mask, input.context_ids, input.ctx_segments,
                          ctx_attn_mask)
    else:
        with torch.no_grad():
            model_out = model(input.question_ids, input.question_segments,
                              q_attn_mask, input.context_ids,
                              input.ctx_segments, ctx_attn_mask)

    local_q_vector, local_ctx_vectors = model_out

    loss_function = BiEncoderNllLoss()

    loss, is_correct = _calc_loss(args, loss_function, local_q_vector,
                                  local_ctx_vectors, input.is_positive,
                                  input.hard_negatives)

    is_correct = is_correct.sum().item()

    if args.n_gpu > 1:
        loss = loss.mean()
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps

    return loss, is_correct
Example #6
def gen_ctx_vectors(
    ctx_rows: List[Tuple[object, str, str]],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = args.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):

        batch_token_tensors = [
            tensorizer.text_to_tensor(ctx[1], title=ctx[2] if insert_title else None)
            for ctx in ctx_rows[batch_start : batch_start + bsz]
        ]

        ctx_ids_batch = move_to_device(
            torch.stack(batch_token_tensors, dim=0), args.device
        )
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), args.device)
        ctx_attn_mask = move_to_device(
            tensorizer.get_attn_mask(ctx_ids_batch), args.device
        )
        with torch.no_grad():
            _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask)
        out = out.cpu()

        ctx_ids = [r[0] for r in ctx_rows[batch_start : batch_start + bsz]]

        assert len(ctx_ids) == out.size(0)

        total += len(ctx_ids)

        results.extend(
            [(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]
        )

        if total % 10 == 0:
            logger.info("Encoded passages %d", total)

    return results
Example #7
    def validate(self):
        logger.info('Validation ...')
        args = self.args
        self.reader.eval()
        data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, False, shuffle=False)

        log_result_step = args.log_batch_step
        all_results = []

        eval_top_docs = args.eval_top_docs
        for i, samples_batch in enumerate(data_iterator.iterate_data()):
            input = create_reader_input(self.tensorizer.get_pad_id(),
                                        samples_batch,
                                        args.passages_per_question_predict,
                                        args.sequence_length,
                                        args.max_n_answers,
                                        is_train=False, shuffle=False)

            input = ReaderBatch(**move_to_device(input._asdict(), args.device))
            attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(input.input_ids, attn_mask)

            batch_predictions = self._get_best_prediction(start_logits, end_logits, relevance_logits, samples_batch,
                                                          passage_thresholds=eval_top_docs)

            all_results.extend(batch_predictions)

            if (i + 1) % log_result_step == 0:
                logger.info('Eval step: %d ', i)

        ems = defaultdict(list)

        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
            for (n, span_prediction) in span_predictions.items():
                em_hit = max([exact_match_score(span_prediction.prediction_text, ga) for ga in gold_answers])
                ems[n].append(em_hit)
        em = 0
        for n in sorted(ems.keys()):
            em = np.mean(ems[n])
            logger.info("n=%d\tEM %.2f" % (n, em * 100))

        if args.prediction_results_file:
            self._save_predictions(args.prediction_results_file, all_results)

        return em
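
The EM bookkeeping above keeps one hit list per top-docs threshold n and averages each list at the end. A self-contained sketch with dummy predictions; the tuples below stand in for ReaderQuestionPredictions objects, and exact_match is a simplified scorer, not DPR's exact_match_score:

from collections import defaultdict
import numpy as np

def exact_match(pred: str, gold: str) -> bool:
    # Simplified scorer; DPR normalizes the strings more aggressively.
    return pred.strip().lower() == gold.strip().lower()

# Each entry: ({top-docs threshold -> predicted text}, gold answers).
fake_results = [
    ({10: "paris", 50: "lyon"}, ["Paris"]),
    ({10: "berlin", 50: "berlin"}, ["Berlin"]),
]
ems = defaultdict(list)
for preds, golds in fake_results:
    for n, text in preds.items():
        ems[n].append(max(exact_match(text, g) for g in golds))
for n in sorted(ems.keys()):
    print("n=%d\tEM %.2f" % (n, np.mean(ems[n]) * 100))  # n=10: 100.00, n=50: 50.00
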
Example #8
    def _calc_loss(self, input: ReaderBatch) -> torch.Tensor:
        cfg = self.cfg
        input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
        attn_mask = self.tensorizer.get_attn_mask(input.input_ids)
        questions_num, passages_per_question, _ = input.input_ids.size()

        if self.reader.training:
            # start_logits, end_logits, rank_logits = self.reader(input.input_ids, attn_mask)
            loss = self.reader(
                input.input_ids,
                attn_mask,
                input.start_positions,
                input.end_positions,
                input.answers_mask,
                use_simple_loss=getattr(cfg.train, "use_simple_loss", False),
                average_loss=getattr(cfg.train, "average_loss", False),
            )

        else:
            # TODO: remove?
            with torch.no_grad():
                start_logits, end_logits, rank_logits = self.reader(
                    input.input_ids, attn_mask)

            loss = compute_loss(
                input.start_positions,
                input.end_positions,
                input.answers_mask,
                start_logits,
                end_logits,
                rank_logits,
                questions_num,
                passages_per_question,
                use_simple_loss=getattr(cfg.train, "use_simple_loss", False),
                average=getattr(cfg.train, "average_loss", False),
            )
        if cfg.n_gpu > 1:
            loss = loss.mean()
        if cfg.train.gradient_accumulation_steps > 1:
            loss = loss / cfg.train.gradient_accumulation_steps

        return loss
Example #9
    def validate_average_rank(self) -> float:
        """
        Validates the biencoder model using each question's gold passage rank across the set of passages from the dataset.
        It generates vectors for a specified number of negative passages from each question (see the --val_av_rank_xxx params)
        and stores them in RAM, along with the question vectors.
        The similarity scores are then calculated for the entire
        num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
        Each question's gold passage rank in that sorted list of scores is averaged across all questions.
        :return: averaged rank number
        """
        logger.info('Average rank validation ...')

        args = self.args
        self.biencoder.eval()
        distributed_factor = self.distributed_factor

        data_iterator = self.get_data_iterator(args.dev_file,
                                               args.dev_batch_size,
                                               shuffle=False)

        sub_batch_size = args.val_av_rank_bsz
        sim_score_f = BiEncoderNllLoss.get_similarity_function()
        q_represenations = []
        ctx_represenations = []
        positive_idx_per_question = []

        num_hard_negatives = args.val_av_rank_hard_neg
        num_other_negatives = args.val_av_rank_other_neg

        log_result_step = args.log_batch_step

        for i, samples_batch in enumerate(data_iterator.iterate_data()):
            # samples += 1
            if len(q_represenations) > args.val_av_rank_max_qs / distributed_factor:
                break

            biencoder_input = BiEncoder.create_biencoder_input(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=False)
            total_ctxs = len(ctx_represenations)
            ctxs_ids = biencoder_input.context_ids
            ctxs_segments = biencoder_input.ctx_segments
            bsz = ctxs_ids.size(0)

            # split the contexts batch into sub-batches since it may be too large to process in one pass
            for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):

                q_ids, q_segments = (biencoder_input.question_ids, biencoder_input.question_segments) if j == 0 \
                    else (None, None)
                # note: q_ids / q_segments are None for all but the first sub-batch
                q_ids = move_to_device(q_ids, args.device)
                q_segments = move_to_device(q_segments, args.device)
                if j == 0 and args.n_gpu > 1 and q_ids.size(0) == 1:
                    # if we are in DP (but not in DDP) mode, all model input tensors should have batch size >1 or 0,
                    # otherwise the other input tensors will be split but only the first split will be called
                    continue

                ctx_ids_batch = move_to_device(
                    ctxs_ids[batch_start:batch_start + sub_batch_size],
                    args.device)
                ctx_seg_batch = move_to_device(
                    ctxs_segments[batch_start:batch_start + sub_batch_size],
                    args.device)

                q_attn_mask = move_to_device(
                    self.tensorizer.get_attn_mask(q_ids), args.device)
                ctx_attn_mask = move_to_device(
                    self.tensorizer.get_attn_mask(ctx_ids_batch), args.device)

                with torch.no_grad():
                    q_dense, ctx_dense = self.biencoder(
                        q_ids, q_segments, q_attn_mask, ctx_ids_batch,
                        ctx_seg_batch, ctx_attn_mask)

                if q_dense is not None:
                    q_represenations.extend(q_dense.cpu().split(1, dim=0))

                ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))

            batch_positive_idxs = biencoder_input.is_positive
            positive_idx_per_question.extend(
                [total_ctxs + v for v in batch_positive_idxs])

            if (i + 1) % log_result_step == 0:
                logger.info(
                    'Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d',
                    i, len(ctx_represenations), len(q_represenations))

        ctx_represenations = torch.cat(ctx_represenations, dim=0)
        q_represenations = torch.cat(q_represenations, dim=0)

        logger.info('Av.rank validation: total q_vectors size=%s',
                    q_represenations.size())
        logger.info('Av.rank validation: total ctx_vectors size=%s',
                    ctx_represenations.size())

        q_num = q_represenations.size(0)
        assert q_num == len(positive_idx_per_question)

        scores = sim_score_f(q_represenations, ctx_represenations)
        values, indices = torch.sort(scores, dim=1, descending=True)

        rank = 0
        for i, idx in enumerate(positive_idx_per_question):
            # aggregate the rank of the known gold passage in the sorted results for each question
            gold_idx = (indices[i] == idx).nonzero()
            rank += gold_idx.item()

        if distributed_factor > 1:
            # each node calculated its own rank; exchange the information between nodes and calculate the "global" average rank
            # NOTE: the set of passages is still unique for every node
            eval_stats = all_gather_list([rank, q_num], max_size=100)
            for i, item in enumerate(eval_stats):
                remote_rank, remote_q_num = item
                if i != args.local_rank:
                    rank += remote_rank
                    q_num += remote_q_num

        av_rank = float(rank / q_num)
        logger.info('Av.rank validation: average rank %s, total questions=%d',
                    av_rank, q_num)
        return av_rank
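
Stripped of batching and the distributed exchange, the rank computation reduces to: sort each question's scores, find the column of the gold passage in the sorted order, and average. A self-contained sketch with a random score matrix in place of encoder output:

import torch

torch.manual_seed(0)
q_num, ctx_num = 4, 12
scores = torch.randn(q_num, ctx_num)       # q x ctx similarity matrix
positive_idx_per_question = [0, 3, 7, 11]  # gold passage column per question

_, indices = torch.sort(scores, dim=1, descending=True)
rank = 0
for i, idx in enumerate(positive_idx_per_question):
    # position of the gold passage in this question's sorted score list
    rank += (indices[i] == idx).nonzero().item()
av_rank = rank / q_num  # 0.0 would mean every gold passage ranked first
print(av_rank)
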
Example #10
def _do_biencoder_fwd_pass(
    model: nn.Module,
    input: BiEncoderBatch,
    tensorizer: Tensorizer,
    loss_function,
    cfg,
    encoder_type: str,
    rep_positions_q=0,
    rep_positions_c=0,
    loss_scale: float = None,
    clustering: bool = False,
) -> Tuple[torch.Tensor, int]:

    input = BiEncoderBatch(**move_to_device(input._asdict(), cfg.device))

    q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
    ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)

    if model.training:
        model_out = model(
            input.question_ids,
            input.question_segments,
            q_attn_mask,
            input.context_ids,
            input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos_q=rep_positions_q,
            representation_token_pos_c=rep_positions_c,
        )
    else:
        with torch.no_grad():
            model_out = model(
                input.question_ids,
                input.question_segments,
                q_attn_mask,
                input.context_ids,
                input.ctx_segments,
                ctx_attn_mask,
                encoder_type=encoder_type,
                representation_token_pos_q=rep_positions_q,
                representation_token_pos_c=rep_positions_c,
            )

    local_q_vector, local_ctx_vectors = model_out

    if cfg.others.is_matching:  # MatchBiEncoder model
        loss, ml_is_correct, matching_is_correct = _calc_loss_matching(
            cfg,
            model,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        ml_is_correct = ml_is_correct.sum().item()
        matching_is_correct = matching_is_correct.sum().item()
    else:
        loss, is_correct = calc_loss(
            cfg,
            loss_function,
            local_q_vector,
            local_ctx_vectors,
            input.is_positive,
            input.hard_negatives,
            loss_scale=loss_scale,
        )
        is_correct = is_correct.sum().item()

    if cfg.n_gpu > 1:
        loss = loss.mean()
    if cfg.train.gradient_accumulation_steps > 1:
        loss = loss / cfg.train.gradient_accumulation_steps

    if clustering:
        assert not cfg.others.is_matching
        return loss, is_correct, model_out
    elif cfg.others.is_matching:
        return loss, ml_is_correct, matching_is_correct
    else:
        return loss, is_correct
Example #11
    def validate(self):
        logger.info("Validation ...")
        cfg = self.cfg
        self.reader.eval()
        if self.dev_iterator is None:
            self.dev_iterator = self.get_data_iterator(
                cfg.dev_files, cfg.train.dev_batch_size, False, shuffle=False)

        log_result_step = cfg.train.log_batch_step // 4  # validation needs to be more verbose
        all_results = []

        eval_top_docs = cfg.eval_top_docs
        for i, samples_batch in enumerate(self.dev_iterator.iterate_ds_data()):
            input = create_reader_input(
                self.wiki_data,
                self.tensorizer,
                samples_batch,
                cfg.passages_per_question_predict,
                cfg.encoder.sequence_length,
                cfg.max_n_answers,
                is_train=False,
                shuffle=False,
            )

            input = ReaderBatch(**move_to_device(input._asdict(), cfg.device))
            attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(
                    input.input_ids, attn_mask)

            batch_predictions = get_best_prediction(
                self.cfg.max_answer_length,
                self.tensorizer,
                start_logits,
                end_logits,
                relevance_logits,
                samples_batch,
                passage_thresholds=eval_top_docs,
            )

            all_results.extend(batch_predictions)

            if (i + 1) % log_result_step == 0:
                logger.info("Eval step: %d ", i)

        ems = defaultdict(list)
        f1s = defaultdict(list)

        for q_predictions in all_results:
            gold_answers = q_predictions.gold_answers
            span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
            for (n, span_prediction) in span_predictions.items():
                # Exact match
                em_hit = max([
                    exact_match_score(span_prediction.prediction_text, ga)
                    for ga in gold_answers
                ])
                ems[n].append(em_hit)

                # F1 score
                f1_hit = max([
                    f1_score(span_prediction.prediction_text, ga)
                    for ga in gold_answers
                ])
                f1s[n].append(f1_hit)

        # Sync between GPUs
        ems, f1s = gather(self.cfg, [ems, f1s])

        em = 0
        for n in sorted(ems[0].keys()):
            ems_n = sum([em[n] for em in ems], [])  # gather and concatenate
            em = np.mean(ems_n)
            logger.info("n=%d\tEM %.2f" % (n, em * 100))

        for n in sorted(f1s[0].keys()):
            f1s_n = sum([f1[n] for f1 in f1s], [])  # gather and concatenate
            f1 = np.mean(f1s_n)
            logger.info("n=%d\tF1 %.2f" % (n, f1 * 100))

        if cfg.prediction_results_file:
            self._save_predictions(cfg.prediction_results_file, all_results)

        return em
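
Relative to Example #7, this version also tracks token-level F1 and syncs results across GPUs. A simplified token-overlap F1 sketch; DPR's f1_score additionally normalizes the strings first:

from collections import Counter

def f1(pred: str, gold: str) -> float:
    # Token-overlap F1 in the SQuAD style, without text normalization.
    p, g = pred.split(), gold.split()
    common = sum((Counter(p) & Counter(g)).values())
    if common == 0:
        return 0.0
    precision, recall = common / len(p), common / len(g)
    return 2 * precision * recall / (precision + recall)

print(f1("the eiffel tower", "eiffel tower"))  # 0.8
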
Example #12
def gen_ctx_vectors(
    cfg: DictConfig,
    ctx_rows: List[Tuple[object, BiEncoderPassage]],
    q_rows: List[object],
    model: nn.Module,
    tensorizer: Tensorizer,
    insert_title: bool = True,
) -> List[Tuple[object, np.ndarray]]:
    n = len(ctx_rows)
    bsz = cfg.batch_size
    total = 0
    results = []
    for j, batch_start in enumerate(range(0, n, bsz)):
        # Passage preprocessing  # TODO: max seq length check
        batch = ctx_rows[batch_start:batch_start + bsz]
        batch_token_tensors = [
            tensorizer.text_to_tensor(
                ctx[1].text, title=ctx[1].title if insert_title else None)
            for ctx in batch
        ]

        ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0),
                                       cfg.device)
        ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch),
                                       cfg.device)
        ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch),
                                       cfg.device)

        # Question preprocess
        q_batch = q_rows[batch_start:batch_start + bsz]
        q_batch_token_tensors = [
            tensorizer.text_to_tensor(qq) for qq in q_batch
        ]

        q_ids_batch = move_to_device(torch.stack(q_batch_token_tensors, dim=0),
                                     cfg.device)
        q_seg_batch = move_to_device(torch.zeros_like(q_ids_batch), cfg.device)
        q_attn_mask = move_to_device(tensorizer.get_attn_mask(q_ids_batch),
                                     cfg.device)

        # Selector
        from dpr.data.biencoder_data import DEFAULT_SELECTOR
        selector = DEFAULT_SELECTOR
        rep_positions = selector.get_positions(q_ids_batch, tensorizer)

        with torch.no_grad():
            q_dense, ctx_dense = model(
                q_ids_batch,
                q_seg_batch,
                q_attn_mask,
                ctx_ids_batch,
                ctx_seg_batch,
                ctx_attn_mask,
                representation_token_pos=rep_positions,
            )
        q_dense = q_dense.cpu()
        ctx_dense = ctx_dense.cpu()
        ctx_ids = [r[0] for r in batch]

        assert len(ctx_ids) == q_dense.size(0) == ctx_dense.size(0)
        total += len(ctx_ids)

        results.extend([(ctx_ids[i], q_dense[i].numpy(), ctx_dense[i].numpy(),
                         q_dense[i].numpy().dot(ctx_dense[i].numpy()))
                        for i in range(q_dense.size(0))])

        if total % 10 == 0:
            logger.info("Encoded questions / passages %d", total)
            # break
    return results
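
This variant returns, per row, the question vector, the passage vector, and their dot product. That last element is just the per-pair inner product, shown here with random vectors in place of encoder output:

import numpy as np

q_dense = np.random.rand(3, 8)    # question vectors
ctx_dense = np.random.rand(3, 8)  # aligned passage vectors
scores = [q_dense[i].dot(ctx_dense[i]) for i in range(3)]
# Equivalent vectorized form of the same per-pair scores:
assert np.allclose(scores, (q_dense * ctx_dense).sum(axis=1))
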
Example #13
def do_ofa_fwd_pass(
    trainer,
    mode: str,
    backward: bool,  # whether to backward loss
    step: bool,  # whether to perform `optimizer.step()`
    biencoder_input: BiEncoderBatch,
    biencoder_config: BiEncoderTrainingConfig,
    reader_inputs: List[ReaderBatch],
    reader_config: ReaderTrainingConfig,
    inference_only: bool = False,
) -> Union[ForwardPassOutputsTrain, BiEncoderPredictionBatch,
           List[ReaderPredictionBatch],
           Tuple[BiEncoderPredictionBatch, List[ReaderPredictionBatch]]]:
    """
    Note: if `inference_only` is set to True:
        1. No loss is computed.
        2. No backward pass is performed.
        3. All predictions are transformed to CPU to save memory.
    """

    assert mode in ["retriever", "reader", "both"], f"Invalid mode: {mode}"
    if inference_only:
        assert (not backward) and (not step) and (not trainer.model.training)
    biencoder_is_correct = None
    biencoder_preds = None
    reader_input_tot = None
    reader_preds_tot = None

    # Forward pass and backward pass for biencoder
    if mode in ["retriever", "both"]:
        biencoder_input = BiEncoderBatch(
            **move_to_device(biencoder_input._asdict(), trainer.cfg.device))

        if trainer.model.training:
            biencoder_preds: BiEncoderPredictionBatch = trainer.model(
                mode="retriever",
                biencoder_batch=biencoder_input,
                biencoder_config=biencoder_config,
                reader_batch=None,
                reader_config=None,
            )

        else:
            with torch.no_grad():
                biencoder_preds: BiEncoderPredictionBatch = trainer.model(
                    mode="retriever",
                    biencoder_batch=biencoder_input,
                    biencoder_config=biencoder_config,
                    reader_batch=None,
                    reader_config=None,
                )

        if not inference_only:
            # Calculate biencoder loss
            biencoder_loss, biencoder_is_correct = calc_loss_biencoder(
                cfg=trainer.cfg,
                loss_function=trainer.biencoder_loss_function,
                local_q_vector=biencoder_preds.question_vector,
                local_ctx_vectors=biencoder_preds.context_vector,
                local_positive_idxs=biencoder_input.is_positive,
                local_hard_negatives_idxs=biencoder_input.hard_negatives,
                loss_scale=None,
            )
            biencoder_is_correct = biencoder_is_correct.sum().item()
            biencoder_input = BiEncoderBatch(
                **move_to_device(biencoder_input._asdict(), "cpu"))

            # Re-calibrate loss
            if trainer.cfg.n_gpu > 1:
                biencoder_loss = biencoder_loss.mean()
            if trainer.cfg.train.gradient_accumulation_steps > 1:
                biencoder_loss = biencoder_loss / trainer.cfg.train.gradient_accumulation_steps

            if backward:
                assert trainer.model.training, "Model is not in training mode!"
                trainer.backward(
                    loss=biencoder_loss,
                    optimizer=trainer.biencoder_optimizer,
                    scheduler=trainer.biencoder_scheduler,
                    step=step,
                )
        else:
            biencoder_input = BiEncoderBatch(
                **move_to_device(biencoder_input._asdict(), "cpu"))

        biencoder_preds = BiEncoderPredictionBatch(
            **move_to_device(biencoder_preds._asdict(), "cpu"))

    # Forward and backward pass for reader
    if mode in ["reader", "both"]:
        reader_total_loss = 0
        reader_input_tot: List[ReaderBatch] = []
        reader_preds_tot: List[ReaderPredictionBatch] = []

        for reader_input in reader_inputs:
            reader_input = ReaderBatch(
                **move_to_device(reader_input._asdict(), trainer.cfg.device))

            if trainer.model.training:
                reader_preds: ReaderPredictionBatch = trainer.model(
                    mode="reader",
                    biencoder_batch=None,
                    biencoder_config=None,
                    reader_batch=reader_input,
                    reader_config=reader_config,
                )

                reader_loss = reader_preds.total_loss / len(
                    reader_inputs)  # scale by number of sub batches
                reader_total_loss += reader_loss

                # Re-calibrate loss
                if trainer.cfg.n_gpu > 1:
                    reader_loss = reader_loss.mean()
                if trainer.cfg.train.gradient_accumulation_steps > 1:
                    reader_loss = reader_loss / trainer.cfg.gradient_accumulation_steps

                if backward:
                    assert trainer.model.training, "Model is not in training mode!"
                    trainer.backward(
                        loss=reader_loss,
                        optimizer=trainer.reader_optimizer,
                        scheduler=trainer.reader_scheduler,
                        step=step,
                    )

            else:
                with torch.no_grad():
                    reader_preds: ReaderPredictionBatch = trainer.model(
                        mode="reader",
                        biencoder_batch=None,
                        biencoder_config=None,
                        reader_batch=reader_input,
                        reader_config=reader_config,
                    )

                if not inference_only:
                    questions_num, passages_per_question, _ = reader_input.input_ids.size()
                    reader_total_loss = calc_loss_reader(
                        start_positions=reader_input.start_positions,
                        end_positions=reader_input.end_positions,
                        answers_mask=reader_input.answers_mask,
                        start_logits=reader_preds.start_logits,
                        end_logits=reader_preds.end_logits,
                        relevance_logits=reader_preds.relevance_logits,
                        N=questions_num,
                        M=passages_per_question,
                        use_simple_loss=reader_config.use_simple_loss,
                        average=reader_config.average_loss,
                    )

            reader_input = ReaderBatch(
                **move_to_device(reader_input._asdict(), "cpu"))
            reader_input_tot.append(reader_input)

            reader_preds = ReaderPredictionBatch(
                **move_to_device(reader_preds._asdict(), "cpu"))
            reader_preds_tot.append(reader_preds)

    if inference_only:
        if mode == "retriever":
            return biencoder_preds
        elif mode == "reader":
            return reader_preds_tot
        else:
            return biencoder_preds, reader_preds_tot

    else:
        # Total loss; for now use 1:1 weights
        if mode == "retriever":
            loss = biencoder_loss
        elif mode == "reader":
            loss = reader_total_loss
        else:
            loss = biencoder_loss + reader_total_loss

        outputs = ForwardPassOutputsTrain(
            loss=loss,
            biencoder_is_correct=biencoder_is_correct,
            biencoder_input=biencoder_input,
            biencoder_preds=biencoder_preds,
            reader_input=reader_input_tot,
            reader_preds=reader_preds_tot,
        )
        return outputs
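
The backward/step split that trainer.backward hides (backward may run without step, so gradients accumulate across sub-batches) follows the standard PyTorch gradient-accumulation recipe. A minimal sketch with a plain optimizer; the trainer object itself is not reproduced here:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
grad_accum_steps = 2

for micro_step in range(grad_accum_steps):
    loss = model(torch.randn(8, 4)).pow(2).mean() / grad_accum_steps
    loss.backward()  # gradients accumulate in each parameter's .grad
    if micro_step == grad_accum_steps - 1:  # step only after the last sub-batch
        opt.step()
        opt.zero_grad()
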
Example #14
    def validate(self):
        logger.info('Validation ...')
        args = self.args
        self.reader.eval()
        data_iterator = self.get_data_iterator(args.dev_file,
                                               args.dev_batch_size,
                                               False,
                                               shuffle=False)

        log_result_step = args.log_batch_step
        all_results = []

        eval_top_docs = args.eval_top_docs
        for i, samples_batch in enumerate(data_iterator.iterate_data()):
            input = create_reader_input(self.tensorizer.get_pad_id(),
                                        samples_batch,
                                        args.passages_per_question_predict,
                                        args.sequence_length,
                                        args.max_n_answers,
                                        is_train=False,
                                        shuffle=False)

            input = ReaderBatch(**move_to_device(input._asdict(), args.device))
            attn_mask = self.tensorizer.get_attn_mask(input.input_ids)

            with torch.no_grad():
                start_logits, end_logits, relevance_logits = self.reader(
                    input.input_ids, attn_mask)

            batch_predictions = self._get_best_prediction(
                start_logits,
                end_logits,
                relevance_logits,
                samples_batch,
                passage_thresholds=eval_top_docs)

            all_results.extend(batch_predictions)

            if (i + 1) % log_result_step == 0:
                logger.info('Eval step: %d ', i)

        if args.prediction_results_file:
            self._save_predictions(args.prediction_results_file, all_results)

        em = 0  # exact match
        cm = 0  # char match
        rouge_scorer = Rouge()
        bleu_scorer = Bleu()
        if not args.test_only:
            ems = defaultdict(list)
            cms = defaultdict(list)
            gts = defaultdict(list)
            preds = defaultdict(list)
            top1 = defaultdict(list)

            for q_predictions in all_results:
                gold_answers = q_predictions.gold_answers
                span_predictions = q_predictions.predictions  # {top docs threshold -> SpanPrediction()}
                for (n, span_prediction) in span_predictions.items():
                    em_hit = max([
                        exact_match_score(span_prediction.prediction_text, ga)
                        for ga in gold_answers
                    ])
                    cm_hit = max([
                        char_match_score(span_prediction.prediction_text, ga)
                        for ga in gold_answers
                    ])
                    ems[n].append(em_hit)
                    cms[n].append(cm_hit)
                    # for bleu/rouge later
                    gts[n].append(gold_answers)
                    preds[n].append(span_prediction.prediction_text)
                    # for qa_classify top1
                    has_answer = q_predictions.passages_has_answer[
                        span_prediction.passage_index]
                    top1[n].append(float(has_answer))

            for n in sorted(ems.keys()):
                em = np.mean(ems[n])
                cm = np.mean(cms[n])
                bleu = bleu_scorer.compute_score(gts[n], preds[n])
                rouge = rouge_scorer.compute_score(gts[n], preds[n])
                t1 = np.mean(top1[n])
                mean_score = (em + cm) / 2
                logger.info(
                    "n=%d\tEM %.2f\tCM %.2f\tScore %.2f\tTop-1 %.2f\n" %
                    (n, em * 100, cm * 100, mean_score * 100, t1 * 100))
                # logger.info("n=%d\tEM %.2f\tCM %.2f\tRouge-L %.2f\tBLEU-4 %.2f\tTop-1 %.2f\n" % (n, em * 100, cm * 100, rouge * 100, bleu * 100, t1 * 100))

        return em