def _initial_performance(self, session):
    """
    Evaluates the session's model on the validation set and reports the result.

    The validation results are printed as a flat summary, followed by the time the evaluation
    took in milliseconds.

    :param session: Session object providing ``test_model`` and the ``model`` to snapshot.
    :return: target_performance: The 'distance' entry of the validation results after
        averaging them with ``Evaluator.means_over_subsets``.
    :return: weights: Deep copy of ``session.model.state_dict()`` taken after the evaluation,
        so later updates to the model do not alter the returned snapshot.
    """
    val_start = dt.now()
    validation_performance = session.test_model(self.val_loader)
    target_performance = Evaluator.means_over_subsets(validation_performance)['distance']
    val_time = dt.now() - val_start

    Evaluator.print_result_summary_flat(validation_performance, '\t')
    # ``timedelta.microseconds`` only holds the sub-second component of the duration, so it
    # under-reports every evaluation lasting one second or longer. Use the total duration.
    elapsed_ms = int(val_time.total_seconds() * 1000)
    print('\t\tThat took {} milliseconds.'.format(elapsed_ms))

    return target_performance, copy.deepcopy(session.model.state_dict())
def test_means_over_subsets():
    """
    Checks ``Evaluator.means_over_subsets`` against hand-computed expected means.

    The evaluator is run on the module-level ``example_results`` and every metric listed in
    ``Evaluator.metric_names`` is compared to its expected value with ``torch.allclose``.
    """
    expected_means = {
        'coord_diff': torch.tensor(1.5),
        'distance': torch.tensor(1.5),
        'bone_length': torch.tensor(2.0),
        'proportion': torch.tensor(2.0),
    }

    computed_means = Evaluator.means_over_subsets(example_results)

    for name in Evaluator.metric_names:
        actual = computed_means[name]
        expected = expected_means[name]
        assert torch.allclose(actual, expected)
def train(self, model, hyperparams):
    """
    Trains the passed model on the training set with the specified hyper-parameters.

    Loss, validation errors or other intermediate results are logged (or printed/plotted during
    the training) and returned at the end, together with the best weights according to the
    validation performance. The model is left with the best weights loaded.

    :param model: Model to train.
    :type model: torch.nn.Module

    :param hyperparams: Dictionary with all the hyperparameters required for training.
    :type hyperparams: dict

    :return: log: Dictionary containing all logs collected during training.
    :type: log: dict

    :return: final_val_results: The validation results (all metrics) of the model using the
        best weights after training finished.
    :type: final_val_results: dict

    :return: best_weights: Model weights that performed best on the validation set during the
        whole training.
    :type: best_weights: dict

    :return: example_predictions: Corrected example poses from the validation set collected
        during training, stacked along a new first dimension and moved to the CPU. An empty
        tensor if no example predictions were collected.
    :type: example_predictions: torch.FloatTensor
    """
    start_time = dt.now()
    print('Time: ', start_time.strftime('%H:%M:%S'))
    print()

    print('Setting things up...')
    session = TrainingSession(model, hyperparams, self.normalizer)
    self.train_loader.set_augmenters(hyperparams['augmenters'])
    helper.print_hyperparameters(hyperparams, self.config['interest_keys'], indent=1)
    log, example_predictions, log_iterations = self._initialize_logs()

    print('\n\tChecking initial validation performance:')
    best_val_performance, best_weights = self._initial_performance(session)

    print()
    print('All set, let\'s get started!')
    for epoch in range(self.config['num_epochs']):
        running_loss = 0.0
        session.schedule_learning_rate()

        for i, batch in enumerate(self.train_loader):
            loss, output_batch = session.train_batch(batch)

            if self.config['log_loss']:
                log['train']['loss'].append(loss)
            if self.config['log_grad']:
                log['train']['grad'].append(self._sum_gradients(model))

            # Intermediate evaluation only happens at the selected iterations of an epoch.
            if i in log_iterations:
                train_performance, val_performance = self._intermediate_eval(
                    session, output_batch)
                self._logging(log, loss, train_performance, val_performance, i)
                if len(self.config['val_example_indices']) > 0:
                    example_predictions.append(self._example_predictions(session))

            running_loss += loss

        # Training for this epoch finished.
        # Average loss per iteration; presumably consumed by the learning rate scheduler on the
        # next call to ``schedule_learning_rate`` (verify against TrainingSession).
        session.scheduler_metric = running_loss / self.iters_per_epoch

        val_performance = session.test_model(self.val_loader)
        target_performance = Evaluator.means_over_subsets(val_performance)['distance']

        # Lower 'distance' is better: keep a snapshot of the weights whenever it improves.
        if target_performance < best_val_performance:
            best_val_performance = target_performance
            best_weights = copy.deepcopy(model.state_dict())
            log['best_epoch'] = epoch

        self._print_epoch_end_info(session, val_performance, start_time, epoch,
                                   best_val_performance)

    model.load_state_dict(best_weights)
    final_val_results = self._full_evaluation(model, session.params['eval_space'])

    print()
    print('-' * 30)
    print('FINISH')
    print('Final validation errors:')
    Evaluator.print_results(final_val_results)
    print()

    # ``torch.stack`` raises on an empty sequence, which would discard the whole training run
    # when no example indices are configured. Return an empty float tensor in that case.
    if example_predictions:
        stacked_examples = torch.stack(example_predictions).cpu()
    else:
        stacked_examples = torch.empty(0, dtype=torch.float32)

    return log, final_val_results, best_weights, stacked_examples