Example 1
    def validate_nll(self) -> float:
        logger.info('NLL validation ...')
        args = self.args
        self.biencoder.eval()
        data_iterator = self.get_data_iterator(args.dev_file, args.dev_batch_size, shuffle=False)

        total_loss = 0.0
        start_time = time.time()
        total_correct_predictions = 0
        num_hard_negatives = args.hard_negatives
        num_other_negatives = args.other_negatives
        log_result_step = args.log_batch_step
        batches = 0
        for i, samples_batch in enumerate(data_iterator.iterate_data()):
            biencoder_input = BiEncoder.create_biencoder_input(samples_batch, self.tensorizer,
                                                               True,
                                                               num_hard_negatives, num_other_negatives, shuffle=False)

            loss, correct_cnt = _do_biencoder_fwd_pass(self.biencoder, biencoder_input, self.tensorizer, args)
            total_loss += loss.item()
            total_correct_predictions += correct_cnt
            batches += 1
            if (i + 1) % log_result_step == 0:
                logger.info('Eval step: %d , used_time=%f sec., loss=%f ', i, time.time() - start_time, loss.item())

        total_loss = total_loss / batches
        total_samples = batches * args.dev_batch_size * self.distributed_factor
        correct_ratio = float(total_correct_predictions / total_samples)
        logger.info('NLL Validation: loss = %f. correct prediction ratio  %d/%d ~  %f', total_loss,
                    total_correct_predictions,
                    total_samples,
                    correct_ratio
                    )
        return total_loss
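
For reference, _do_biencoder_fwd_pass above is only expected to return a per-batch loss and a correct-prediction count. The following is a minimal, self-contained sketch of the standard in-batch-negative NLL objective such a forward pass typically computes; the function name, the dot-product similarity, and the tensor layout are assumptions for illustration, not the project's actual implementation.

import torch
import torch.nn.functional as F

def nll_loss_and_correct(q_vectors, ctx_vectors, positive_idx_per_question):
    # q_vectors: (num_questions, dim); ctx_vectors: (num_contexts, dim)
    scores = torch.matmul(q_vectors, ctx_vectors.t())          # similarity matrix
    log_probs = F.log_softmax(scores, dim=1)
    positions = torch.tensor(positive_idx_per_question, device=scores.device)
    loss = F.nll_loss(log_probs, positions, reduction="mean")  # NLL of the gold passage
    correct_cnt = (scores.argmax(dim=1) == positions).sum().item()
    return loss, correct_cnt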
Example 2
def create_ofa_input(
    mode: str,
    wiki_data: TokenizedWikipediaPassages,
    tensorizer: Tensorizer,
    samples: List[Tuple[BiEncoderSampleTokenized, ReaderSample]],
    biencoder_config: BiEncoderDataConfig,
    reader_config: ReaderDataConfig,
) -> Union[
    BiEncoderBatch, List[ReaderBatch], Tuple[BiEncoderBatch, List[ReaderBatch]]
]:

    assert mode in ["retriever", "reader", "both"], f"Invalid mode: {mode}"
    retriever_samples, reader_samples = zip(*samples)

    # Retriever (bi-encoder)
    if mode in ["retriever", "both"]:
        biencoder_batch = BiEncoder.create_biencoder_input(
            samples=retriever_samples,
            tensorizer=tensorizer,
            insert_title=biencoder_config.insert_title,
            num_hard_negatives=biencoder_config.num_hard_negatives,
            num_other_negatives=biencoder_config.num_other_negatives,
            shuffle=biencoder_config.shuffle,
            shuffle_positives=biencoder_config.shuffle_positives,
            hard_neg_fallback=biencoder_config.hard_neg_fallback,
            query_token=biencoder_config.query_token,
        )

    # Reader
    if mode in ["reader", "both"]:
        num_samples = len(samples)
        num_sub_batches = reader_config.num_sub_batches
        assert num_sub_batches > 0

        sub_batch_size = math.ceil(num_samples / num_sub_batches)
        reader_batches: List[ReaderBatch] = []

        for batch_i in range(num_sub_batches):
            start = batch_i * sub_batch_size
            end = min(start + sub_batch_size, num_samples)
            if start >= end:
                break

            reader_batch = create_reader_input(
                wiki_data=wiki_data,
                tensorizer=tensorizer,
                samples=reader_samples[start:end],
                passages_per_question=reader_config.passages_per_question,
                max_length=reader_config.max_length,
                max_n_answers=reader_config.max_n_answers,
                is_train=reader_config.is_train,
                shuffle=reader_config.shuffle,
            )
            reader_batches.append(reader_batch)

    if mode == "retriever":
        return biencoder_batch
    elif mode == "reader":
        return reader_batches
    else:
        return biencoder_batch, reader_batches
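
A hypothetical call in "both" mode is sketched below. The keyword arguments mirror the config fields read by the function; constructing BiEncoderDataConfig and ReaderDataConfig by keyword and the specific values chosen are assumptions, and wiki_data, tensorizer and samples are placeholders that must already exist.

biencoder_cfg = BiEncoderDataConfig(
    insert_title=True,
    num_hard_negatives=1,
    num_other_negatives=0,
    shuffle=True,
    shuffle_positives=False,
    hard_neg_fallback=True,
    query_token=None,
)
reader_cfg = ReaderDataConfig(
    passages_per_question=24,
    max_length=350,
    max_n_answers=10,
    is_train=True,
    shuffle=True,
    num_sub_batches=4,
)
# placeholders: wiki_data, tensorizer and samples are assumed to be prepared elsewhere
biencoder_batch, reader_batches = create_ofa_input(
    mode="both",
    wiki_data=wiki_data,
    tensorizer=tensorizer,
    samples=samples,
    biencoder_config=biencoder_cfg,
    reader_config=reader_cfg,
)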
Example 3
    def _train_epoch(
            self,
            scheduler,
            epoch: int,
            eval_step: int,
            train_data_iterator: ShardedDataIterator,
    ):

        args = self.args
        rolling_train_loss = 0.0
        epoch_loss = 0
        epoch_correct_predictions = 0

        log_result_step = args.log_batch_step
        rolling_loss_step = args.train_rolling_loss_step
        num_hard_negatives = args.hard_negatives
        num_other_negatives = args.other_negatives
        seed = args.seed
        self.biencoder.train()
        epoch_batches = train_data_iterator.max_iterations
        data_iteration = 0
        for i, samples_batch in enumerate(
                train_data_iterator.iterate_data(epoch=epoch)
        ):

            # to be able to resume shuffled ctx pools
            data_iteration = train_data_iterator.get_iteration()
            random.seed(seed + epoch + data_iteration)
            biencoder_batch = BiEncoder.create_biencoder_input(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=True,
                shuffle_positives=args.shuffle_positive_ctx,
            )

            loss, correct_cnt = _do_biencoder_fwd_pass(
                self.biencoder, biencoder_batch, self.tensorizer, args
            )

            epoch_correct_predictions += correct_cnt
            epoch_loss += loss.item()
            rolling_train_loss += loss.item()

            if args.fp16:
                from apex import amp

                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), args.max_grad_norm
                    )
            else:
                loss.backward()
                if args.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.biencoder.parameters(), args.max_grad_norm
                    )

            if (i + 1) % args.gradient_accumulation_steps == 0:
                self.optimizer.step()
                scheduler.step()
                self.biencoder.zero_grad()

            if i % log_result_step == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                    epoch,
                    data_iteration,
                    epoch_batches,
                    loss.item(),
                    lr,
                )

            if (i + 1) % rolling_loss_step == 0:
                logger.info("Train batch %d", data_iteration)
                latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
                logger.info(
                    "Avg. loss per last %d batches: %f",
                    rolling_loss_step,
                    latest_rolling_train_av_loss,
                )
                rolling_train_loss = 0.0

            if data_iteration % eval_step == 0:
                logger.info(
                    "Validation: Epoch: %d Step: %d/%d",
                    epoch,
                    data_iteration,
                    epoch_batches,
                )
                self.validate_and_save(
                    epoch, train_data_iterator.get_iteration(), scheduler
                )
                self.biencoder.train()

        self.validate_and_save(epoch, data_iteration, scheduler)

        epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
        logger.info("Av Loss per epoch=%f", epoch_loss)
        logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
Example 4
    def validate_average_rank(self) -> float:
        """
        Validates the bi-encoder model using the rank of each question's gold passage across the set of passages from the dataset.
        It generates vectors for a specified number of negative passages for each question (see --val_av_rank_xxx params)
        and stores them in RAM, along with the question vectors.
        Then similarity scores are calculated for the entire
        num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
        Each question's gold passage rank in that sorted list of scores is averaged across all the questions.
        :return: averaged rank number
        """
        logger.info("Average rank validation ...")

        args = self.args
        self.biencoder.eval()
        distributed_factor = self.distributed_factor

        data_iterator = self.get_data_iterator(
            args.dev_file, args.dev_batch_size, shuffle=False
        )

        sub_batch_size = args.val_av_rank_bsz
        sim_score_f = BiEncoderNllLoss.get_similarity_function()
        q_represenations = []
        ctx_represenations = []
        positive_idx_per_question = []

        num_hard_negatives = args.val_av_rank_hard_neg
        num_other_negatives = args.val_av_rank_other_neg

        log_result_step = args.log_batch_step

        for i, samples_batch in enumerate(data_iterator.iterate_data()):
            # samples += 1
            if len(q_represenations) > args.val_av_rank_max_qs / distributed_factor:
                break

            biencoder_input = BiEncoder.create_biencoder_input(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=False,
            )
            total_ctxs = len(ctx_represenations)
            ctxs_ids = biencoder_input.context_ids
            ctxs_segments = biencoder_input.ctx_segments
            bsz = ctxs_ids.size(0)

            # split the contexts batch into sub-batches since it may be too large to process in one pass
            for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):

                q_ids, q_segments = (
                    (biencoder_input.question_ids, biencoder_input.question_segments)
                    if j == 0
                    else (None, None)
                )

                if j == 0 and args.n_gpu > 1 and q_ids.size(0) == 1:
                    # in DP (but not DDP) mode, all model input tensors must have batch size > 1 or 0;
                    # otherwise the other input tensors get split across GPUs but only the first split would be used
                    continue

                ctx_ids_batch = ctxs_ids[batch_start: batch_start + sub_batch_size]
                ctx_seg_batch = ctxs_segments[
                                batch_start: batch_start + sub_batch_size
                                ]

                q_attn_mask = self.tensorizer.get_attn_mask(q_ids)
                ctx_attn_mask = self.tensorizer.get_attn_mask(ctx_ids_batch)
                with torch.no_grad():
                    q_dense, ctx_dense = self.biencoder(
                        q_ids,
                        q_segments,
                        q_attn_mask,
                        ctx_ids_batch,
                        ctx_seg_batch,
                        ctx_attn_mask,
                    )

                if q_dense is not None:
                    q_represenations.extend(q_dense.cpu().split(1, dim=0))

                ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))

            batch_positive_idxs = biencoder_input.is_positive
            positive_idx_per_question.extend(
                [total_ctxs + v for v in batch_positive_idxs]
            )

            if (i + 1) % log_result_step == 0:
                logger.info(
                    "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                    i,
                    len(ctx_represenations),
                    len(q_represenations),
                )

        ctx_represenations = torch.cat(ctx_represenations, dim=0)
        q_represenations = torch.cat(q_represenations, dim=0)

        logger.info(
            "Av.rank validation: total q_vectors size=%s", q_represenations.size()
        )
        logger.info(
            "Av.rank validation: total ctx_vectors size=%s", ctx_represenations.size()
        )

        q_num = q_represenations.size(0)
        assert q_num == len(positive_idx_per_question)

        scores = sim_score_f(q_represenations, ctx_represenations)
        values, indices = torch.sort(scores, dim=1, descending=True)

        rank = 0
        for i, idx in enumerate(positive_idx_per_question):
            # aggregate the rank of the known gold passage in the sorted results for each question
            gold_idx = (indices[i] == idx).nonzero()
            rank += gold_idx.item()

        if distributed_factor > 1:
            # each node calculated its own rank; exchange this information between nodes and compute the "global" average rank
            # NOTE: the set of passages is still unique for every node
            eval_stats = all_gather_list([rank, q_num], max_size=100)
            for i, item in enumerate(eval_stats):
                remote_rank, remote_q_num = item
                if i != args.local_rank:
                    rank += remote_rank
                    q_num += remote_q_num

        av_rank = float(rank / q_num)
        logger.info(
            "Av.rank validation: average rank %s, total questions=%d", av_rank, q_num
        )
        return av_rank
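
The rank aggregation at the end of this method reduces to one sort per question plus a lookup of the gold column. A tiny worked example with made-up scores:

import torch

scores = torch.tensor([[0.9, 0.1, 0.3],    # question 0, gold passage in column 0
                       [0.2, 0.4, 0.8]])   # question 1, gold passage in column 2
positive_idx_per_question = [0, 2]

_, indices = torch.sort(scores, dim=1, descending=True)
rank_sum = 0
for i, idx in enumerate(positive_idx_per_question):
    rank_sum += (indices[i] == idx).nonzero().item()   # 0 means the gold passage ranked first

av_rank = rank_sum / len(positive_idx_per_question)     # 0.0 here: both gold passages rank first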
Example 5
    def _train_epoch(
        self,
        scheduler,
        epoch: int,
        eval_step: int,
        train_data_iterator: MultiSetDataIterator,
    ):

        cfg = self.cfg
        rolling_train_loss = 0.0
        epoch_loss = 0
        epoch_correct_predictions, epoch_correct_predictions_matching = 0, 0

        log_result_step = cfg.train.log_batch_step
        rolling_loss_step = cfg.train.train_rolling_loss_step
        num_hard_negatives = cfg.train.hard_negatives
        num_other_negatives = cfg.train.other_negatives
        seed = cfg.seed
        self.biencoder.train()
        epoch_batches = train_data_iterator.max_iterations
        data_iteration = 0

        dataset = 0
        for i, samples_batch in enumerate(
                train_data_iterator.iterate_ds_data(epoch=epoch)):
            if isinstance(samples_batch, Tuple):
                samples_batch, dataset = samples_batch

            ds_cfg = self.ds_cfg.train_datasets[dataset]
            special_token = ds_cfg.special_token
            encoder_type = ds_cfg.encoder_type
            shuffle_positives = ds_cfg.shuffle_positives

            # to be able to resume shuffled ctx pools
            data_iteration = train_data_iterator.get_iteration()
            random.seed(seed + epoch + data_iteration)

            biencoder_batch = BiEncoder.create_biencoder_input(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=True,
                shuffle_positives=shuffle_positives,
                query_token=special_token,
            )

            # get the token to be used for representation selection
            from dpr.data.biencoder_data import DEFAULT_SELECTOR

            selector = ds_cfg.selector if ds_cfg else DEFAULT_SELECTOR

            rep_positions_q = selector.get_positions(
                biencoder_batch.question_ids, self.tensorizer, self.biencoder)
            rep_positions_c = selector.get_positions(
                biencoder_batch.context_ids, self.tensorizer, self.biencoder)

            loss_scale = (cfg.loss_scale_factors[dataset]
                          if cfg.loss_scale_factors else None)
            outp = _do_biencoder_fwd_pass(
                self.biencoder,
                biencoder_batch,
                self.tensorizer,
                self.loss_function,
                cfg,
                encoder_type=encoder_type,
                rep_positions_q=rep_positions_q,
                rep_positions_c=rep_positions_c,
                loss_scale=loss_scale,
                clustering=self.clustering,
            )
            if self.clustering:
                loss, correct_cnt, (question_vector, context_vector) = outp
                question_vector = question_vector.clone().detach().cpu().numpy()
                context_vector = context_vector.clone().detach().cpu().numpy()
                model_outs = ForwardPassOutputsTrain(
                    loss=None,
                    biencoder_is_correct=None,
                    biencoder_input=biencoder_batch,
                    biencoder_preds=BiEncoderPredictionBatch(
                        question_vector=question_vector,
                        context_vector=context_vector,
                    ),
                    reader_input=None,
                    reader_preds=None,
                )
                iterator: ShardedDataIteratorClustering = train_data_iterator.iterables[dataset]
                iterator.record_predictions(epoch=epoch, model_outs=model_outs)

            elif cfg.others.is_matching:
                loss, correct_cnt, correct_cnt_matching = outp
                epoch_correct_predictions_matching += correct_cnt_matching
            else:
                loss, correct_cnt = outp

            epoch_correct_predictions += correct_cnt
            epoch_loss += loss.item()
            rolling_train_loss += loss.item()

            if cfg.fp16:
                from apex import amp

                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                if cfg.train.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer),
                        cfg.train.max_grad_norm)
            else:
                loss.backward()
                if cfg.train.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.biencoder.parameters(),
                                                   cfg.train.max_grad_norm)

            if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
                self.optimizer.step()
                scheduler.step()
                self.biencoder.zero_grad()

            if i % log_result_step == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                    epoch,
                    data_iteration,
                    epoch_batches,
                    loss.item(),
                    lr,
                )

            if (i + 1) % rolling_loss_step == 0:
                logger.info("Train batch %d", data_iteration)
                latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
                logger.info(
                    "Avg. loss per last %d batches: %f",
                    rolling_loss_step,
                    latest_rolling_train_av_loss,
                )
                rolling_train_loss = 0.0

            if data_iteration % eval_step == 0:
                logger.info(
                    "rank=%d, Validation: Epoch: %d Step: %d/%d",
                    cfg.local_rank,
                    epoch,
                    data_iteration,
                    epoch_batches,
                )
                self.validate_and_save(epoch,
                                       train_data_iterator.get_iteration(),
                                       scheduler)
                self.biencoder.train()

        logger.info("Epoch finished on %d", cfg.local_rank)

        # if validation already ran at the last iteration, there is no need to run it again
        if data_iteration % eval_step != 0:
            self.validate_and_save(epoch, data_iteration, scheduler)

        epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
        logger.info("Av Loss per epoch=%f", epoch_loss)
        logger.info("epoch total correct predictions=%d",
                    epoch_correct_predictions)
        if cfg.others.is_matching:
            logger.info("epoch total correct matching predictions=%d",
                        epoch_correct_predictions_matching)
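
The selector objects used above only need to expose get_positions(input_ids, tensorizer, biencoder) and return the token position(s) whose hidden state is taken as the sequence representation. A minimal sketch of such a selector follows; returning a fixed position 0 (the first/[CLS] token) is an assumption about what a DEFAULT_SELECTOR-style implementation does.

class StaticPositionSelector:
    def __init__(self, position: int = 0):
        self.position = position

    def get_positions(self, input_ids, tensorizer, biencoder):
        # a single static position: the encoder uses this token's vector as the representation
        return self.position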
Example 6
    def validate_average_rank(self) -> float:
        """
        Validates the bi-encoder model using the rank of each question's gold passage across the set of passages from the dataset.
        It generates vectors for a specified number of negative passages for each question (see --val_av_rank_xxx params)
        and stores them in RAM, along with the question vectors.
        Then similarity scores are calculated for the entire
        num_questions x (num_questions x num_passages_per_question) matrix and sorted per question.
        Each question's gold passage rank in that sorted list of scores is averaged across all the questions.
        :return: averaged rank number
        """
        logger.info("Average rank validation ...")

        cfg = self.cfg
        self.biencoder.eval()
        distributed_factor = self.distributed_factor

        if not self.dev_iterator:
            self.dev_iterator = self.get_data_iterator(
                cfg.train.dev_batch_size,
                False,
                shuffle=False,
                rank=cfg.local_rank)
        data_iterator = self.dev_iterator

        sub_batch_size = cfg.train.val_av_rank_bsz
        sim_score_f = self.loss_function.get_similarity_function()
        q_represenations = []
        ctx_represenations = []
        positive_idx_per_question = []

        num_hard_negatives = cfg.train.val_av_rank_hard_neg
        num_other_negatives = cfg.train.val_av_rank_other_neg

        dataset = 0
        for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
            # samples += 1
            if (len(q_represenations) >
                    cfg.train.val_av_rank_max_qs / distributed_factor):
                break

            if isinstance(samples_batch, Tuple):
                samples_batch, dataset = samples_batch

            biencoder_input = BiEncoder.create_biencoder_input(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=False,
            )
            total_ctxs = len(ctx_represenations)
            ctxs_ids = biencoder_input.context_ids.to(cfg.device)
            ctxs_segments = biencoder_input.ctx_segments.to(cfg.device)
            bsz = ctxs_ids.size(0)

            # get the token to be used for representation selection
            ds_cfg = self.ds_cfg.dev_datasets[dataset]
            encoder_type = ds_cfg.encoder_type
            rep_positions_q = ds_cfg.selector.get_positions(
                biencoder_input.question_ids, self.tensorizer, self.biencoder)
            rep_positions_c = ds_cfg.selector.get_positions(
                biencoder_input.context_ids, self.tensorizer, self.biencoder)

            # split the contexts batch into sub-batches since it may be too large to process in one pass
            for j, batch_start in enumerate(range(0, bsz, sub_batch_size)):

                q_ids, q_segments = (
                    (biencoder_input.question_ids.to(cfg.device),
                     biencoder_input.question_segments.to(cfg.device))
                    if j == 0
                    else (None, None)
                )

                if j == 0 and cfg.n_gpu > 1 and q_ids.size(0) == 1:
                    # in DP (but not DDP) mode, all model input tensors must have batch size > 1 or 0;
                    # otherwise the other input tensors get split across GPUs but only the first split would be used
                    continue

                ctx_ids_batch = ctxs_ids[batch_start:batch_start +
                                         sub_batch_size]
                ctx_seg_batch = ctxs_segments[batch_start:batch_start +
                                              sub_batch_size]

                q_attn_mask = (
                    self.tensorizer.get_attn_mask(q_ids).to(cfg.device)
                    if q_ids is not None
                    else None
                )
                ctx_attn_mask = self.tensorizer.get_attn_mask(
                    ctx_ids_batch).to(cfg.device)
                with torch.no_grad():
                    q_dense, ctx_dense = self.biencoder(
                        q_ids,
                        q_segments,
                        q_attn_mask,
                        ctx_ids_batch,
                        ctx_seg_batch,
                        ctx_attn_mask,
                        encoder_type=encoder_type,
                        representation_token_pos_q=rep_positions_q,
                        representation_token_pos_c=rep_positions_c,
                    )

                if q_dense is not None:
                    q_represenations.extend(q_dense.cpu().split(1, dim=0))

                ctx_represenations.extend(ctx_dense.cpu().split(1, dim=0))

            batch_positive_idxs = biencoder_input.is_positive
            positive_idx_per_question.extend(
                [total_ctxs + v for v in batch_positive_idxs])

            logger.info(
                "Av.rank validation: step %d, computed ctx_vectors %d, q_vectors %d",
                i,
                len(ctx_represenations),
                len(q_represenations),
            )

        ctx_represenations = torch.cat(ctx_represenations, dim=0)
        q_represenations = torch.cat(q_represenations, dim=0)

        if cfg.others.is_matching:
            # Need to compute by batch of contexts
            logger.info("Average rank validation for interaction layers...")
            num_batches = math.ceil(len(ctx_represenations) / sub_batch_size)
            interaction_scores = []
            for i in range(num_batches):
                start = i * sub_batch_size
                end = min(start + sub_batch_size, len(ctx_represenations))
                with torch.no_grad():
                    interaction_score = self.biencoder(
                        q_pooled_out=q_represenations.to(cfg.device),
                        ctx_pooled_out=ctx_represenations[start:end].to(
                            cfg.device),
                        is_matching=True).cpu()
                    interaction_scores.append(interaction_score)
                logger.info("Av.rank validation (interaction): step %d/%d", i,
                            num_batches)

            interaction_scores = torch.cat(
                interaction_scores, dim=1)  # concatenate along context dim
            logger.info(
                "Av.rank validation (interaction): total interaction matrix size=%s",
                interaction_scores.size())

        logger.info("Av.rank validation: total q_vectors size=%s",
                    q_represenations.size())
        logger.info("Av.rank validation: total ctx_vectors size=%s",
                    ctx_represenations.size())

        # Calculate cosine similarity scores
        q_num = q_represenations.size(0)
        assert q_num == len(positive_idx_per_question)

        scores = sim_score_f(q_represenations, ctx_represenations)
        values, indices = torch.sort(scores, dim=1, descending=True)

        rank = 0
        for i, idx in enumerate(positive_idx_per_question):
            # aggregate the rank of the known gold passage in the sorted results for each question
            gold_idx = (indices[i] == idx).nonzero()
            rank += gold_idx.item()

        if distributed_factor > 1:
            # each node calculated its own rank; exchange this information between nodes and compute the "global" average rank
            # NOTE: the set of passages is still unique for every node
            eval_stats = all_gather_list([rank, q_num], max_size=100)
            for i, item in enumerate(eval_stats):
                remote_rank, remote_q_num = item
                if i != cfg.local_rank:
                    rank += remote_rank
                    q_num += remote_q_num

        av_rank = float(rank / q_num)
        logger.info("Av.rank validation: average rank %s, total questions=%d",
                    av_rank, q_num)

        # Calculate interaction scores
        if cfg.others.is_matching:
            interaction_q_num = q_represenations.size(0)
            assert interaction_q_num == len(positive_idx_per_question)

            interaction_rank = 0
            values, indices = torch.sort(interaction_scores,
                                         dim=1,
                                         descending=True)
            for i, idx in enumerate(positive_idx_per_question):
                # aggregate the rank of the known gold passage in the sorted results for each question
                gold_idx = (indices[i] == idx).nonzero()
                interaction_rank += gold_idx.item()

            if distributed_factor > 1:
                # each node calculated its own rank; exchange this information between nodes and compute the "global" average rank
                # NOTE: the set of passages is still unique for every node
                eval_stats = all_gather_list(
                    [interaction_rank, interaction_q_num], max_size=100)
                for i, item in enumerate(eval_stats):
                    remote_rank, remote_q_num = item
                    if i != cfg.local_rank:
                        interaction_rank += remote_rank
                        interaction_q_num += remote_q_num

            interaction_av_rank = float(interaction_rank / interaction_q_num)
            logger.info(
                "Av.rank validation (interaction): average rank %s, total questions=%d",
                interaction_av_rank, interaction_q_num)

        return interaction_av_rank if cfg.others.is_matching else av_rank
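
The distributed branches above merge each worker's local (rank_sum, q_num) pair gathered via all_gather_list. A minimal sketch of that merge, with made-up per-worker statistics:

local_rank = 0
eval_stats = [[1200, 500], [1350, 500], [1100, 500]]   # one [rank_sum, q_num] per worker

rank_sum, q_num = eval_stats[local_rank]
for i, (remote_rank_sum, remote_q_num) in enumerate(eval_stats):
    if i != local_rank:
        rank_sum += remote_rank_sum
        q_num += remote_q_num

av_rank = rank_sum / q_num    # (1200 + 1350 + 1100) / 1500 ≈ 2.43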
Example 7
    def validate_nll(self) -> float:
        logger.info("NLL validation ...")
        cfg = self.cfg
        self.biencoder.eval()

        if not self.dev_iterator:
            self.dev_iterator = self.get_data_iterator(
                cfg.train.dev_batch_size,
                False,
                shuffle=False,
                rank=cfg.local_rank)
        data_iterator = self.dev_iterator

        total_loss = 0.0
        start_time = time.time()
        total_correct_predictions, total_correct_predictions_matching = 0, 0
        num_hard_negatives = cfg.train.hard_negatives
        num_other_negatives = cfg.train.other_negatives
        log_result_step = cfg.train.log_batch_step
        batches = 0
        dataset = 0

        for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
            if isinstance(samples_batch, Tuple):
                samples_batch, dataset = samples_batch
            logger.info("Eval step: %d ,rnk=%s", i, cfg.local_rank)
            biencoder_input = BiEncoder.create_biencoder_input(
                samples_batch,
                self.tensorizer,
                insert_title=True,
                num_hard_negatives=num_hard_negatives,
                num_other_negatives=num_other_negatives,
                shuffle=False,
            )

            # get the token to be used for representation selection
            ds_cfg = self.ds_cfg.dev_datasets[dataset]
            rep_positions_q = ds_cfg.selector.get_positions(
                biencoder_input.question_ids, self.tensorizer, self.biencoder)
            rep_positions_c = ds_cfg.selector.get_positions(
                biencoder_input.context_ids, self.tensorizer, self.biencoder)
            encoder_type = ds_cfg.encoder_type

            outp = _do_biencoder_fwd_pass(
                self.biencoder,
                biencoder_input,
                self.tensorizer,
                self.loss_function,
                cfg,
                encoder_type=encoder_type,
                rep_positions_q=rep_positions_q,
                rep_positions_c=rep_positions_c,
            )
            if cfg.others.is_matching:
                loss, correct_cnt, correct_cnt_matching = outp
                total_correct_predictions_matching += correct_cnt_matching
            else:
                loss, correct_cnt = outp

            total_loss += loss.item()
            total_correct_predictions += correct_cnt
            batches += 1
            if (i + 1) % log_result_step == 0:
                logger.info(
                    "Eval step: %d , used_time=%f sec., loss=%f ",
                    i,
                    time.time() - start_time,
                    loss.item(),
                )

        total_loss = total_loss / batches
        total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor
        correct_ratio = float(total_correct_predictions / total_samples)
        to_log = (
            f"NLL Validation: loss = {total_loss:.4f} correct prediction ratio  "
            f"{total_correct_predictions}/{total_samples} ~ {correct_ratio:.4f}"
        )

        if cfg.others.is_matching:
            correct_ratio_matching = float(total_correct_predictions_matching /
                                           total_samples)
            to_log += (
                f", matching correct prediction ratio {total_correct_predictions_matching}/{total_samples}"
                f" ~ {correct_ratio_matching:.4f}")
        logger.info(to_log)

        return total_loss
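
For a sense of how the summary numbers at the end of validate_nll combine, here is a small worked example; all values are made up, and total_samples assumes every worker processes the same number of equally sized batches.

batches = 50
dev_batch_size = 32
distributed_factor = 4                    # e.g. 4 DDP workers
total_correct_predictions = 5120

total_samples = batches * dev_batch_size * distributed_factor    # 6400
correct_ratio = total_correct_predictions / total_samples        # 0.8
print(f"correct prediction ratio {total_correct_predictions}/{total_samples} ~ {correct_ratio:.4f}")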