def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step,
    train_data_iterator: MultiSetDataIterator,
):
    cfg = self.cfg
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions = 0

    log_result_step = cfg.train.log_batch_step
    rolling_loss_step = cfg.train.train_rolling_loss_step
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    num_related = cfg.train.related
    num_highly_related = cfg.train.highly_related
    seed = cfg.seed
    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    dataset = 0
    for i, samples_batch in enumerate(
        train_data_iterator.iterate_ds_data(epoch=epoch)
    ):
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        ds_cfg = self.ds_cfg.train_datasets[dataset]
        special_token = ds_cfg.special_token
        encoder_type = ds_cfg.encoder_type
        shuffle_positives = ds_cfg.shuffle_positives

        # re-seed per iteration to be able to resume shuffled ctx pools
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        if self.trainer_type == 'graded':
            biencoder_batch = BiEncoder.create_graded_biencoder_input2(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                num_related,
                num_highly_related,
                shuffle=True,
                shuffle_positives=shuffle_positives,
                query_token=special_token,
                relation_grades=self.relations,
            )
        else:
            biencoder_batch = BiEncoder.create_biencoder_input2(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=True,
                shuffle_positives=shuffle_positives,
                query_token=special_token,
            )

        # get the token to be used for representation selection
        from dpr.data.biencoder_data import DEFAULT_SELECTOR

        selector = ds_cfg.selector if ds_cfg else DEFAULT_SELECTOR
        rep_positions = selector.get_positions(
            biencoder_batch.question_ids, self.tensorizer
        )

        loss_scale = (
            cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None
        )
        if self.trainer_type == 'graded':
            loss, correct_cnt = _do_biencoder_fwd_pass_graded(
                self.biencoder,
                biencoder_batch,
                self.tensorizer,
                cfg,
                encoder_type=encoder_type,
                rep_positions=rep_positions,
                loss_scale=loss_scale,
            )
        else:
            loss, correct_cnt = _do_biencoder_fwd_pass(
                self.biencoder,
                biencoder_batch,
                self.tensorizer,
                cfg,
                encoder_type=encoder_type,
                rep_positions=rep_positions,
                loss_scale=loss_scale,
            )

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if cfg.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if cfg.train.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), cfg.train.max_grad_norm
                )
        else:
            loss.backward()
            if cfg.train.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.biencoder.parameters(), cfg.train.max_grad_norm
                )

        # gradient accumulation: update weights only every N batches
        if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        # step-level validation is disabled in favor of the epoch-level
        # check below
        # if data_iteration % eval_step == 0:
        #     logger.info(
        #         "rank=%d, Validation: Epoch: %d Step: %d/%d",
        #         cfg.local_rank,
        #         epoch,
        #         data_iteration,
        #         epoch_batches,
        #     )
        #     self.validate_and_save(
        #         epoch, train_data_iterator.get_iteration(), scheduler
        #     )
        #     self.biencoder.train()

    # validate at epoch granularity: eval_step is either an int interval or
    # a list (omegaconf.listconfig.ListConfig) of explicit epoch indices
    if isinstance(eval_step, int):
        if (epoch + 1) % eval_step == 0:
            self.validate_and_save(epoch, data_iteration, scheduler)
    else:  # omegaconf.listconfig.ListConfig
        if epoch in eval_step:
            self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
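# ---------------------------------------------------------------------------
# Note on the update pattern in _train_epoch: it is plain gradient
# accumulation. loss.backward() runs on every batch (with optional grad-norm
# clipping), but optimizer.step(), scheduler.step() and zero_grad() fire only
# every cfg.train.gradient_accumulation_steps batches, so the effective batch
# size is batch_size * gradient_accumulation_steps. A minimal self-contained
# sketch of the same control flow, with a toy model and random data that are
# illustrative only and not part of this trainer:
#
#   import torch
#
#   model = torch.nn.Linear(4, 1)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   accum_steps, max_grad_norm = 2, 1.0
#   for i in range(8):
#       x, y = torch.randn(3, 4), torch.randn(3, 1)
#       loss = torch.nn.functional.mse_loss(model(x), y)
#       loss.backward()  # gradients accumulate across iterations
#       if max_grad_norm > 0:  # clipped every batch, as above
#           torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#       if (i + 1) % accum_steps == 0:
#           optimizer.step()   # one weight update per accum_steps batches
#           model.zero_grad()  # reset the accumulated gradient
# ---------------------------------------------------------------------------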
def validate_nll(self) -> float:
    logger.info("NLL validation ...")
    cfg = self.cfg
    self.biencoder.eval()

    if not self.dev_iterator:
        self.dev_iterator = self.get_data_iterator(
            cfg.train.dev_batch_size, False, shuffle=False, rank=cfg.local_rank
        )
    data_iterator = self.dev_iterator

    total_loss = 0.0
    start_time = time.time()
    total_correct_predictions = 0
    num_hard_negatives = cfg.train.hard_negatives
    num_other_negatives = cfg.train.other_negatives
    num_related = cfg.train.related
    num_highly_related = cfg.train.highly_related
    log_result_step = cfg.train.log_batch_step
    batches = 0
    dataset = 0

    for i, samples_batch in enumerate(data_iterator.iterate_ds_data()):
        if isinstance(samples_batch, Tuple):
            samples_batch, dataset = samples_batch

        logger.info("Eval step: %d , rank=%s", i, cfg.local_rank)
        if self.trainer_type == 'graded':
            biencoder_input = BiEncoder.create_graded_biencoder_input2(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                num_related,
                num_highly_related,
                shuffle=False,
                relation_grades=self.relations,
            )
        else:
            biencoder_input = BiEncoder.create_biencoder_input2(
                samples_batch,
                self.tensorizer,
                True,
                num_hard_negatives,
                num_other_negatives,
                shuffle=False,
            )

        # get the token to be used for representation selection
        ds_cfg = self.ds_cfg.dev_datasets[dataset]
        rep_positions = ds_cfg.selector.get_positions(
            biencoder_input.question_ids, self.tensorizer
        )
        encoder_type = ds_cfg.encoder_type

        if self.trainer_type == 'graded':
            loss, correct_cnt = _do_biencoder_fwd_pass_graded(
                self.biencoder,
                biencoder_input,
                self.tensorizer,
                cfg,
                encoder_type=encoder_type,
                rep_positions=rep_positions,
            )
        else:
            loss, correct_cnt = _do_biencoder_fwd_pass(
                self.biencoder,
                biencoder_input,
                self.tensorizer,
                cfg,
                encoder_type=encoder_type,
                rep_positions=rep_positions,
            )
        total_loss += loss.item()
        total_correct_predictions += correct_cnt
        batches += 1
        if (i + 1) % log_result_step == 0:
            logger.info(
                "Eval step: %d , used_time=%f sec., loss=%f ",
                i,
                time.time() - start_time,
                loss.item(),
            )

    total_loss = total_loss / batches
    total_samples = batches * cfg.train.dev_batch_size * self.distributed_factor
    correct_ratio = float(total_correct_predictions / total_samples)
    logger.info(
        "NLL Validation: loss = %f. correct prediction ratio %d/%d ~ %f",
        total_loss,
        total_correct_predictions,
        total_samples,
        correct_ratio,
    )
    return total_loss
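# ---------------------------------------------------------------------------
# Note on eval_step: _train_epoch accepts it either as an int (validate every
# eval_step epochs) or as an omegaconf ListConfig of explicit epoch indices.
# A hedged sketch of that dispatch in isolation; the helper name and the
# values below are made up for illustration:
#
#   from omegaconf import ListConfig
#
#   def should_validate(epoch: int, eval_step) -> bool:
#       if isinstance(eval_step, int):
#           # fires on 0-indexed epochs eval_step - 1, 2 * eval_step - 1, ...
#           return (epoch + 1) % eval_step == 0
#       return epoch in eval_step  # explicit epoch list
#
#   assert should_validate(3, 2)                      # 0-indexed epoch 3, step 2
#   assert should_validate(2, ListConfig([0, 2, 4]))  # listed explicitly
# ---------------------------------------------------------------------------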