Example #1
    def _train_epoch(
        self,
        scheduler,
        epoch: int,
        eval_step,
        train_data_iterator: MultiSetDataIterator,
    ):

        cfg = self.cfg
        rolling_train_loss = 0.0
        epoch_loss = 0
        epoch_correct_predictions = 0

        log_result_step = cfg.train.log_batch_step
        rolling_loss_step = cfg.train.train_rolling_loss_step
        num_hard_negatives = cfg.train.hard_negatives
        num_other_negatives = cfg.train.other_negatives
        num_related = cfg.train.related
        num_highly_related = cfg.train.highly_related
        seed = cfg.seed
        self.biencoder.train()
        epoch_batches = train_data_iterator.max_iterations
        data_iteration = 0

        dataset = 0
        for i, samples_batch in enumerate(
            train_data_iterator.iterate_ds_data(epoch=epoch)
        ):
            if isinstance(samples_batch, tuple):
                samples_batch, dataset = samples_batch

            ds_cfg = self.ds_cfg.train_datasets[dataset]
            special_token = ds_cfg.special_token
            encoder_type = ds_cfg.encoder_type
            shuffle_positives = ds_cfg.shuffle_positives

            # re-seed the RNG so shuffled ctx pools can be reproduced when resuming
            data_iteration = train_data_iterator.get_iteration()
            random.seed(seed + epoch + data_iteration)

            if self.trainer_type == 'graded':
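                # graded batches additionally sample "related" and "highly related"
                # passages, weighted via self.relations, on top of hard/other negatives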
                biencoder_batch = BiEncoder.create_graded_biencoder_input2(
                    samples_batch,
                    self.tensorizer,
                    True,
                    num_hard_negatives,
                    num_other_negatives,
                    num_related,
                    num_highly_related,
                    shuffle=True,
                    shuffle_positives=shuffle_positives,
                    query_token=special_token,
                    relation_grades=self.relations,
                )
            else:
                biencoder_batch = BiEncoder.create_biencoder_input2(
                    samples_batch,
                    self.tensorizer,
                    True,
                    num_hard_negatives,
                    num_other_negatives,
                    shuffle=True,
                    shuffle_positives=shuffle_positives,
                    query_token=special_token,
                )

            # get the token to be used for representation selection
            from dpr.data.biencoder_data import DEFAULT_SELECTOR

            selector = ds_cfg.selector if ds_cfg else DEFAULT_SELECTOR

            rep_positions = selector.get_positions(
                biencoder_batch.question_ids, self.tensorizer
            )

            loss_scale = (
                cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None
            )
            if self.trainer_type == 'graded':
                loss, correct_cnt = _do_biencoder_fwd_pass_graded(
                    self.biencoder,
                    biencoder_batch,
                    self.tensorizer,
                    cfg,
                    encoder_type=encoder_type,
                    rep_positions=rep_positions,
                    loss_scale=loss_scale,
                )
            else:
                loss, correct_cnt = _do_biencoder_fwd_pass(
                    self.biencoder,
                    biencoder_batch,
                    self.tensorizer,
                    cfg,
                    encoder_type=encoder_type,
                    rep_positions=rep_positions,
                    loss_scale=loss_scale,
                )

            epoch_correct_predictions += correct_cnt
            epoch_loss += loss.item()
            rolling_train_loss += loss.item()

            if cfg.fp16:
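                # mixed-precision path: apex amp scales the loss before backward so fp16
                # gradients do not underflow; clipping is applied to the fp32 master params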
                from apex import amp

                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                if cfg.train.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), cfg.train.max_grad_norm
                    )
            else:
                loss.backward()
                if cfg.train.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.biencoder.parameters(), cfg.train.max_grad_norm
                    )

            if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
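                # apply accumulated gradients: step optimizer and LR scheduler, then clear grads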
                self.optimizer.step()
                scheduler.step()
                self.biencoder.zero_grad()

            if i % log_result_step == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                logger.info(
                    "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                    epoch,
                    data_iteration,
                    epoch_batches,
                    loss.item(),
                    lr,
                )

            if (i + 1) % rolling_loss_step == 0:
                logger.info("Train batch %d", data_iteration)
                latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
                logger.info(
                    "Avg. loss per last %d batches: %f",
                    rolling_loss_step,
                    latest_rolling_train_av_loss,
                )
                rolling_train_loss = 0.0

            # if data_iteration % eval_step == 0:
            #     logger.info(
            #         "rank=%d, Validation: Epoch: %d Step: %d/%d",
            #         cfg.local_rank,
            #         epoch,
            #         data_iteration,
            #         epoch_batches,
            #     )
            #     self.validate_and_save(
            #         epoch, train_data_iterator.get_iteration(), scheduler
            #     )
            #     self.biencoder.train()

        if isinstance(eval_step, int):
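            # an integer eval_step means: run validation every eval_step-th epoch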
            if (epoch + 1) % eval_step == 0:
                self.validate_and_save(epoch, data_iteration, scheduler)
        else:  # omegaconf.listconfig.ListConfig
            if epoch in eval_step:
                self.validate_and_save(epoch, data_iteration, scheduler)

        epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
        logger.info("Av Loss per epoch=%f", epoch_loss)
        logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
Example #2
    def validate_nll(self) -> float:
        logger.info("NLL validation ...")
        cfg = self.cfg
        self.biencoder.eval()
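        # eval mode disables dropout so the validation loss is deterministic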

        if not self.dev_iterator:
            self.dev_iterator = self.get_data_iterator(
                cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
            )
        data_iterator = self.dev_iterator

        total_loss = 0.0
        start_time = time.time()
        total_correct_predictions = 0
        num_hard_negatives = cfg.train.hard_negatives
        num_other_negatives = cfg.train.other_negatives
        num_related = cfg.train.related
        num_highly_related = cfg.train.highly_related
        log_result_step = cfg.train.log_batch_step
        batches = 0
        dataset = 0

        for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
            if isinstance(samples_batch, tuple):
                samples_batch, dataset = samples_batch
            logger.info("Eval step: %d ,rnk=%s", i, cfg.local_rank)
            if self.trainer_type == 'graded':
                biencoder_input = BiEncoder.create_graded_biencoder_input2(
                    samples_batch,
                    self.tensorizer,
                    True,
                    num_hard_negatives,
                    num_other_negatives,
                    num_related,
                    num_highly_related,
                    shuffle=False,
                    relation_grades=self.relations,
                )
            else:
                biencoder_input = BiEncoder.create_biencoder_input2(
                    samples_batch,
                    self.tensorizer,
                    True,
                    num_hard_negatives,
                    num_other_negatives,
                    shuffle=False,
                )
            # get the token to be used for representation selection
            ds_cfg = self.ds_cfg.dev_datasets[dataset]
            rep_positions = ds_cfg.selector.get_positions(
                biencoder_input.question_ids, self.tensorizer
            )
            encoder_type = ds_cfg.encoder_type
            if self.trainer_type == 'graded':
                loss, correct_cnt = _do_biencoder_fwd_pass_graded(
                    self.biencoder,
                    biencoder_input,
                    self.tensorizer,
                    cfg,
                    encoder_type=encoder_type,
                    rep_positions=rep_positions,
                )
            else:
                loss, correct_cnt = _do_biencoder_fwd_pass(
                    self.biencoder,
                    biencoder_input,
                    self.tensorizer,
                    cfg,
                    encoder_type=encoder_type,
                    rep_positions=rep_positions,
                )
            total_loss += loss.item()
            total_correct_predictions += correct_cnt
            batches += 1
            if (i + 1) % log_result_step == 0:
                logger.info(
                    "Eval step: %d , used_time=%f sec., loss=%f ",
                    i,
                    time.time() - start_time,
                    loss.item(),
                )

        total_loss = total_loss / batches
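        # approximate sample count: assumes full-sized dev batches on every rank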
        total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor
        correct_ratio = float(total_correct_predictions / total_samples)
        logger.info(
            "NLL Validation: loss = %f. correct prediction ratio  %d/%d ~  %f",
            total_loss,
            total_correct_predictions,
            total_samples,
            correct_ratio,
        )
        return total_loss
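
validate_and_save, referenced from _train_epoch above but not shown here, presumably compares this NLL against the best value seen so far before checkpointing; a minimal sketch under that assumption (the checkpoint helper and attribute names are hypothetical):

    def validate_and_save(self, epoch: int, iteration: int, scheduler):
        # hypothetical wrapper: track the best (lowest) dev NLL and checkpoint on improvement
        validation_loss = self.validate_nll()
        if validation_loss < getattr(self, "best_validation_loss", float("inf")):
            self.best_validation_loss = validation_loss
            self._save_checkpoint(scheduler, epoch, iteration)  # hypothetical helper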