def train_main(self, cache=False):
    """Run the full training loop and return the final metrics.

    Iterates epochs ``self.start_epoch`` .. ``self.config['max_epoch']``
    (inclusive). For every batch of ``self.config['train_loader']`` it puts
    the model in train mode, moves the batch with ``self.batch_to_device``
    (stored on ``self.batch``) and calls ``self.train_iter_step()``.

    Every ``self.config['log_every']`` global steps it writes loss-only
    tensorboard logs plus the mean time per training iteration and the
    current learning rate. At the end of each epoch it calls
    ``self.train_epoch_step()`` and stops early once
    ``self.terminate_training`` is truthy. Finally it calls
    ``self.end_training()``.

    The loop variables are bound to ``self.epoch``, ``self.iters`` and
    ``self.batch`` on purpose: other methods (e.g. ``train_epoch_step``)
    read them from the instance.

    Args:
        cache: Accepted for interface compatibility; not used by this method.

    Returns:
        Tuple ``(self.best_val_metrics, self.test_metrics)`` as they stand
        after ``self.end_training()`` has run.
    """
    print("\n\n" + "=" * 100 + "\n\t\t\t\t\t Training Network\n" + "=" * 100)
    self.start = time.time()
    print("\nBeginning training at: {} \n".format(
        datetime.datetime.now()))

    self.model.to(self.device)

    for self.epoch in range(self.start_epoch, self.config['max_epoch'] + 1):
        # Wall-clock duration of each training iteration since the last log.
        train_times = []
        # NOTE(review): if the loader yields no batches, self.iters is not
        # (re)assigned this epoch, so train_epoch_step sees a stale or missing
        # value — confirm loaders are never empty.
        for self.iters, self.batch in enumerate(self.config['train_loader']):
            # Re-enter train mode every step; presumably evaluation elsewhere
            # switches the model to eval mode.
            self.model.train()
            iter_time = time.time()
            self.batch = self.batch_to_device(self.batch)
            self.train_iter_step()
            train_times.append(time.time() - iter_time)

            # 1-based global step across epochs. train_epoch_step advances
            # self.total_iters by the number of iterations in each epoch.
            step = self.total_iters + self.iters + 1
            if step % self.config['log_every'] == 0:
                # Loss-only logging; full metrics are logged per epoch.
                if self.config['debug']:
                    LOGGER.info("Logging tensorboard at step %i with %i values",
                                step, len(self.short_loss_list))
                log_tensorboard(self.config, self.config['writer'], self.model,
                                self.epoch, self.iters, self.total_iters,
                                self.short_loss_list, loss_only=True,
                                val=False)
                writer = self.config['writer']
                writer.add_scalar('Stats/time_per_train_iter',
                                  mean(train_times), step)
                writer.add_scalar('Stats/learning_rate',
                                  self.scheduler.get_last_lr()[0], step)
                # Start a fresh window for the next logging interval.
                # (short_loss_list is presumably filled by train_iter_step.)
                train_times = []
                self.short_loss_list = []

        self.train_epoch_step()
        if self.terminate_training:
            break

    self.end_training()
    return self.best_val_metrics, self.test_metrics
def train_epoch_step(self):
    """Finish one epoch: train metrics, validation, logging and early stop.

    Steps, in order:

    1. Switch the model to train mode and capture the scheduler's last LR.
    2. Add this epoch's iteration count (``self.iters + 1``) to
       ``self.total_iters``.
    3. Flatten the per-batch ``probs_list`` / ``labels_list`` one level and
       compute ``self.train_metrics`` via ``standard_metrics``.
    4. Log train metrics, then run ``self.eval_model()``.
    5. Log the validation wall-clock time, print stats and log validation
       metrics.
    6. Call ``self.check_early_stopping()``.
    7. Clear the per-epoch buffers.

    Side effects on the instance:

    - ``self.train_loss`` ends as the mean of this epoch's ``loss_list``.
      This raises ``ZeroDivisionError`` if that list was empty.
    - ``self.val_metrics`` and ``self.val_loss`` are deleted at the end.
      They exist only for the duration of this call, so
      ``check_early_stopping`` presumably reads them before removal.
    """
    def _flatten(nested):
        # One level of flattening: list of per-batch sequences -> flat list.
        return [item for chunk in nested for item in chunk]

    self.model.train()
    lrs = self.scheduler.get_last_lr()
    self.total_iters += self.iters + 1

    self.probs_list = _flatten(self.probs_list)
    self.labels_list = _flatten(self.labels_list)

    # Metrics on the training set for this epoch.
    self.train_metrics = standard_metrics(
        torch.tensor(self.probs_list),
        torch.tensor(self.labels_list),
        add_optimal_acc=True,
    )
    log_tensorboard(
        self.config, self.config['writer'], self.model, self.epoch,
        self.iters, self.total_iters, self.loss_list, self.train_metrics,
        lrs[0], loss_only=False, val=False,
    )
    # Snapshot of the per-iteration losses; averaged only after printing.
    self.train_loss = self.loss_list[:]

    # Validation pass, timed.
    val_started = time.time()
    self.val_metrics, self.val_loss = self.eval_model()
    elapsed = time.time() - val_started
    self.config['writer'].add_scalar(
        "Stats/time_validation", elapsed, self.total_iters)

    print_stats(
        self.config, self.epoch, self.train_metrics, self.train_loss,
        self.val_metrics, self.val_loss, self.start, lrs[0],
    )
    log_tensorboard(
        self.config, self.config['writer'], self.model, self.epoch,
        self.iters, self.total_iters, self.val_loss, self.val_metrics,
        lrs[0], loss_only=False, val=True,
    )

    self.check_early_stopping()

    # Reset the per-epoch accumulators.
    for buffer_name in ('probs_list', 'preds_list', 'labels_list',
                        'loss_list', 'id_list'):
        setattr(self, buffer_name, [])

    epoch_losses = self.train_loss
    self.train_loss = sum(epoch_losses) / len(epoch_losses)
    del self.val_metrics, self.val_loss