def train_val(self):
    """Run the main training loop with periodic validation and checkpointing.

    For each epoch in ``[self.start_epoch, self.config.n_epochs)``:
      1. optionally validate (every ``config.val_freq`` epochs and on the
         last epoch) via ``self.eval``;
      2. save a checkpoint every ``config.snapshot`` epochs;
      3. adjust the learning rate on the active optimizer(s);
      4. iterate ``self.train_loader`` and optimize, either jointly
         (``task='both'``) or in two passes (``task='odom'`` then
         ``task='pose'``) when ``self.sep_train`` is set.

    A final checkpoint is saved after the loop completes.

    :return: None
    """
    for epoch in range(self.start_epoch, self.config.n_epochs):
        # Reset the cached previous pose before validation uses the model.
        self.pose_p = None

        # VALIDATION — run on the configured frequency and always on the
        # final epoch so the last model state gets evaluated.
        if self.config.do_val and ((epoch % self.config.val_freq == 0) or
                                   (epoch == self.config.n_epochs - 1)):
            errs = self.eval(epoch)
            print('Val {:s}: Epoch {:d}'.format(self.experiment, epoch))
            print(errs)

        # SAVE CHECKPOINT
        if epoch % self.config.snapshot == 0:
            self.save_checkpoint(epoch)
            print('Epoch {:d} checkpoint saved for {:s}'.format(
                epoch, self.experiment))

        # ADJUST LR
        # NOTE(review): in the separate-training branch only the pose
        # optimizer's LR is adjusted here; odom_optimizer keeps its LR.
        # Looks intentional but worth confirming — TODO verify.
        if not self.sep_train:
            lr = self.optimizer.adjust_lr(epoch)
        else:
            lr = self.pose_optimizer.adjust_lr(epoch)

        # TRAIN
        self.model.train()
        train_data_time = Logger.AverageMeter()
        train_batch_time = Logger.AverageMeter()
        end = time.time()
        # Reset again so the first training batch starts without a stale
        # previous pose (self.eval above may have set it).
        self.pose_p = None
        for batch_idx, (data, target) in enumerate(self.train_loader):
            train_data_time.update(time.time() - end)
            if not self.sep_train:
                # Single optimizer updates both sub-tasks in one pass.
                kwargs = dict(target=target, criterion=self.train_criterion,
                              optim=self.optimizer, train=True,
                              max_grad_norm=self.config.max_grad_norm,
                              task='both')
                loss, _ = self.step_feedfwd(data, self.model,
                                            self.config.GPUs > 0, **kwargs)
            else:
                # Optimize the odometry net first ...
                kwargs = dict(target=target, criterion=self.train_criterion,
                              optim=self.odom_optimizer, train=True,
                              max_grad_norm=self.config.max_grad_norm,
                              task='odom')
                loss, _ = self.step_feedfwd(data, self.model,
                                            self.config.GPUs > 0, **kwargs)
                # ... then the global pose net on the same batch.
                # NOTE(review): this overwrites the odom loss, so the value
                # printed/summarized below reflects the pose step only.
                kwargs['optim'] = self.pose_optimizer
                kwargs['task'] = 'pose'
                loss, _ = self.step_feedfwd(data, self.model,
                                            self.config.GPUs > 0, **kwargs)
            train_batch_time.update(time.time() - end)

            if batch_idx % self.config.print_freq == 0:
                print('Train {:s}: Epoch {:d}\t'
                      'Batch {:d}/{:d}\t'
                      'Data Time {:.4f} ({:.4f})\t'
                      'Batch Time {:.4f} ({:.4f})\t'
                      'Loss {:f}\t'
                      'lr: {:f}'.format(
                          self.experiment, epoch, batch_idx,
                          len(self.train_loader) - 1,
                          train_data_time.val, train_data_time.avg,
                          train_batch_time.val, train_batch_time.avg,
                          loss, lr))

            if batch_idx % self.config.summary_freq == 0:
                # Log the loss plus the learned homoscedastic uncertainty
                # weights of the criterion, and the first input image.
                scalar_names_vars = [
                    ('loss', loss),
                    ('sx_abs', self.train_criterion.sx_abs),
                    ('sq_abs', self.train_criterion.sq_abs),
                    ('sx_rel', self.train_criterion.sx_rel),
                    ('sq_rel', self.train_criterion.sq_rel),
                    ('sx_vo', self.train_criterion.sx_vo),
                    ('sq_vo', self.train_criterion.sq_vo),
                ]
                image_names_vars = [('input', data[0])]
                histogram_names_vars = []
                summarize(self.summary_writer, scalar_names_vars,
                          image_names_vars, histogram_names_vars,
                          batch_idx + epoch * len(self.train_loader))

            end = time.time()

    # Save final checkpoint after the last epoch finishes.
    final_epoch = self.config.n_epochs
    self.save_checkpoint(final_epoch)
    print('Epoch {:d} checkpoint saved'.format(final_epoch))