def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
):
    """Run one training epoch of the bi-encoder.

    Iterates over the sharded training data, performs forward/backward
    passes, steps the optimizer and LR scheduler on gradient-accumulation
    boundaries, logs rolling loss, and periodically validates/saves a
    checkpoint. A final validate_and_save is issued at epoch end.

    :param scheduler: LR scheduler stepped together with the optimizer
    :param epoch: current epoch index (also folded into the shuffle seed)
    :param eval_step: run validation/checkpointing every this many data
        iterations (must be > 0 — used as a modulus below)
    :param train_data_iterator: sharded iterator over training samples
    """
    args = self.args
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions = 0

    log_result_step = args.log_batch_step
    rolling_loss_step = args.train_rolling_loss_step
    num_hard_negatives = args.hard_negatives
    num_other_negatives = args.other_negatives
    seed = args.seed

    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    for i, samples_batch in enumerate(
        train_data_iterator.iterate_data(epoch=epoch)
    ):
        # Re-seed per step so the shuffled negative-context pools are
        # reproducible when resuming training mid-epoch.
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        biencoder_batch = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=True,
            shuffle_positives=args.shuffle_positive_ctx,
        )

        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder, biencoder_batch, self.tensorizer, args
        )

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if args.fp16:
            # apex amp mixed-precision path: backward on the scaled loss,
            # then clip on the fp32 master params held by the optimizer.
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), args.max_grad_norm
                )
        else:
            loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.biencoder.parameters(), args.max_grad_norm
                )

        # Step the optimizer/scheduler only on accumulation boundaries;
        # gradients from intermediate batches accumulate in between.
        if (i + 1) % args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if data_iteration % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(
                epoch, train_data_iterator.get_iteration(), scheduler
            )
            # validate_and_save may switch the model to eval mode;
            # restore train mode before continuing the epoch.
            self.biencoder.train()

    # Final validation/checkpoint at the end of the epoch.
    self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
    global_step: int,
) -> int:
    """Run one training epoch of the reader model.

    Iterates over the sharded training data, computes the reader loss per
    batch, steps the optimizer and LR scheduler on gradient-accumulation
    boundaries, logs rolling loss, and periodically validates/saves a
    checkpoint.

    :param scheduler: LR scheduler stepped together with the optimizer
    :param epoch: current epoch index
    :param eval_step: run validation/checkpointing every this many global
        steps (must be > 0 — used as a modulus below)
    :param train_data_iterator: sharded iterator over training samples
    :param global_step: running step counter carried across epochs
    :return: the updated global step counter
    """
    args = self.args
    rolling_train_loss = 0.0
    epoch_loss = 0
    log_result_step = args.log_batch_step
    rolling_loss_step = args.train_rolling_loss_step

    self.reader.train()
    epoch_batches = train_data_iterator.max_iterations

    for i, samples_batch in enumerate(
            train_data_iterator.iterate_data(epoch=epoch)):

        data_iteration = train_data_iterator.get_iteration()

        # Re-seed all RNGs from the global step so a resumed run replays
        # exactly the same train state (shuffling, dropout, etc.).
        if args.fully_resumable:
            np.random.seed(args.seed + global_step)
            torch.manual_seed(args.seed + global_step)
            if args.n_gpu > 0:
                torch.cuda.manual_seed_all(args.seed + global_step)

        input = create_reader_input(self.tensorizer.get_pad_id(),
                                    samples_batch,
                                    args.passages_per_question,
                                    args.sequence_length,
                                    args.max_n_answers,
                                    is_train=True,
                                    shuffle=True)

        loss = self._calc_loss(input)

        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if args.fp16:
            # apex amp mixed-precision path: backward on the scaled loss,
            # then clip on the fp32 master params held by the optimizer.
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), args.max_grad_norm)
        else:
            loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.reader.parameters(),
                                               args.max_grad_norm)

        global_step += 1

        # Step the optimizer/scheduler only on accumulation boundaries;
        # gradients from intermediate batches accumulate in between.
        if (i + 1) % args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.reader.zero_grad()

        if global_step % log_result_step == 0:
            lr = self.optimizer.param_groups[0]['lr']
            logger.info('Epoch: %d: Step: %d/%d, global_step=%d, lr=%f',
                        epoch, data_iteration, epoch_batches, global_step, lr)

        if (i + 1) % rolling_loss_step == 0:
            logger.info('Train batch %d', data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info('Avg. loss per last %d batches: %f',
                        rolling_loss_step, latest_rolling_train_av_loss)
            rolling_train_loss = 0.0

        if global_step % eval_step == 0:
            logger.info('Validation: Epoch: %d Step: %d/%d',
                        epoch, data_iteration, epoch_batches)
            self.validate_and_save(epoch,
                                   train_data_iterator.get_iteration(),
                                   scheduler)
            # validate_and_save may switch the model to eval mode;
            # restore train mode before continuing the epoch.
            self.reader.train()

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info('Av Loss per epoch=%f', epoch_loss)
    return global_step