def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
    global_step: int,
):
    cfg = self.cfg
    rolling_train_loss = 0.0
    epoch_loss = 0
    log_result_step = cfg.train.log_batch_step
    rolling_loss_step = cfg.train.train_rolling_loss_step

    self.reader.train()
    epoch_batches = train_data_iterator.max_iterations

    for i, samples_batch in enumerate(train_data_iterator.iterate_ds_data(epoch=epoch)):
        data_iteration = train_data_iterator.get_iteration()

        # enables resuming to exactly the same train state
        if cfg.fully_resumable:
            np.random.seed(cfg.seed + global_step)
            torch.manual_seed(cfg.seed + global_step)
            if cfg.n_gpu > 0:
                torch.cuda.manual_seed_all(cfg.seed + global_step)

        input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            cfg.passages_per_question,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=True,
            shuffle=True,
        )

        loss = self._calc_loss(input)

        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        max_grad_norm = cfg.train.max_grad_norm
        if cfg.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), max_grad_norm
                )
        else:
            loss.backward()
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.reader.parameters(), max_grad_norm)

        if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.reader.zero_grad()
            global_step += 1

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, global_step=%d, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                global_step,
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if global_step % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(epoch, train_data_iterator.get_iteration(), scheduler)
            self.reader.train()

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    return global_step
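
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): the gradient-
# accumulation pattern used by the reader's _train_epoch above, reduced to a
# self-contained toy. The model, data, and LambdaLR schedule here are
# placeholders chosen only so the snippet runs on its own; the point is the
# (i + 1) % accumulation_steps check and the global_step bookkeeping.
# ---------------------------------------------------------------------------
import torch


def toy_accumulation_loop(accumulation_steps: int = 2, max_grad_norm: float = 1.0) -> int:
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

    global_step = 0
    for i in range(10):
        x = torch.randn(4, 8)
        y = torch.randint(0, 2, (4,))
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()  # gradients accumulate across iterations until the step below

        # Clipping happens every batch, mirroring the loop above.
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Optimizer/scheduler step and gradient reset only every
        # `accumulation_steps` batches; global_step counts optimizer updates,
        # not batches.
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

    return global_step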
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
):
    args = self.args
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions = 0

    log_result_step = args.log_batch_step
    rolling_loss_step = args.train_rolling_loss_step
    num_hard_negatives = args.hard_negatives
    num_other_negatives = args.other_negatives
    seed = args.seed

    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    for i, samples_batch in enumerate(train_data_iterator.iterate_data(epoch=epoch)):
        # to be able to resume shuffled ctx pools
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        biencoder_batch = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=True,
            shuffle_positives=args.shuffle_positive_ctx,
        )

        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder, biencoder_batch, self.tensorizer, args
        )

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if args.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), args.max_grad_norm
                )
        else:
            loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.biencoder.parameters(), args.max_grad_norm
                )

        if (i + 1) % args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if data_iteration % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(epoch, train_data_iterator.get_iteration(), scheduler)
            self.biencoder.train()

    self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
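
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): why the bi-encoder
# _train_epoch re-seeds with seed + epoch + data_iteration before building each
# batch. Deriving the seed from the iteration counter alone means a job resumed
# mid-epoch shuffles the ctx pools exactly as an uninterrupted run would.
# The shuffled_pool helper below is hypothetical and only demonstrates that
# determinism property with the standard-library RNG.
# ---------------------------------------------------------------------------
import random


def shuffled_pool(seed: int, epoch: int, data_iteration: int, pool: list) -> list:
    # Reset the RNG from (seed, epoch, iteration) so the shuffle does not
    # depend on any RNG state accumulated earlier in the run.
    random.seed(seed + epoch + data_iteration)
    shuffled = list(pool)
    random.shuffle(shuffled)
    return shuffled


if __name__ == "__main__":
    pool = list(range(10))
    first = shuffled_pool(42, 0, 7, pool)
    random.random()  # unrelated RNG use, e.g. earlier batches in a fresh run
    resumed = shuffled_pool(42, 0, 7, pool)
    # Identical shuffles: the order depends only on (seed, epoch, iteration),
    # not on prior RNG history, so resuming at iteration 7 is reproducible.
    assert first == resumed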