Пример #1
0
    def _initial_performance(self, session):
        """
        Evaluate the session's model on the validation set before training starts.

        Prints a flat summary of the validation results together with the wall-clock
        time the evaluation took, and returns the baseline metric plus a snapshot of
        the current weights.

        :param session: Training session whose ``test_model`` method is run on
                        ``self.val_loader``.
        :type session: TrainingSession

        :return: target_performance: Mean ``'distance'`` metric over all validation
                                     subsets, as computed by
                                     ``Evaluator.means_over_subsets``.
        :return: weights: Deep copy of ``session.model.state_dict()``, so subsequent
                          changes to the model's parameters do not affect it.
        :rtype: tuple
        """
        val_start = dt.now()
        validation_performance = session.test_model(self.val_loader)
        target_performance = Evaluator.means_over_subsets(
            validation_performance)['distance']
        val_time = dt.now() - val_start

        Evaluator.print_result_summary_flat(validation_performance, '\t')
        # Use the full duration: ``timedelta.microseconds`` is only the sub-second
        # component (0-999999) and would silently drop whole seconds.
        val_ms = int(val_time.total_seconds() * 1000)
        print('\t\tThat took {} milliseconds.'.format(val_ms))

        return target_performance, copy.deepcopy(session.model.state_dict())
Пример #2
0
def test_means_over_subsets():
    """Averaging ``example_results`` over subsets yields the expected per-metric means."""
    expected = dict(
        coord_diff=torch.tensor(1.5),
        distance=torch.tensor(1.5),
        bone_length=torch.tensor(2.0),
        proportion=torch.tensor(2.0),
    )

    actual = Evaluator.means_over_subsets(example_results)
    for name in Evaluator.metric_names:
        got = actual[name]
        want = expected[name]
        assert torch.allclose(got, want)
Пример #3
0
    def train(self, model, hyperparams):
        """
        Train ``model`` on the training set with the specified hyper-parameters.

        The model is first evaluated once on the validation set to obtain a baseline.
        After every epoch it is validated again; whenever the mean ``'distance'`` over
        all validation subsets decreases, the current weights are snapshotted. Loss,
        gradients and intermediate results are logged according to ``self.config``.
        When training finishes, the best weights are loaded back into ``model`` and a
        full evaluation is run and printed.

        :param model: Model to train. On return it has the best weights loaded.
        :type model: torch.nn.Module

        :param hyperparams: Dictionary with all the hyperparameters required for
                            training. ``hyperparams['augmenters']`` is read here; the
                            whole dict is also handed to ``TrainingSession``.
        :type hyperparams: dict

        :return: A 4-tuple ``(log, final_val_results, best_weights, example_predictions)``:

            - ``log`` (dict): all logs collected during training, including
              ``'best_epoch'`` if any epoch improved on the initial performance.
            - ``final_val_results`` (dict): validation results (all metrics) of the
              model using the best weights after training finished.
            - ``best_weights`` (dict): state dict that performed best on the
              validation set (the initial weights if no epoch improved).
            - ``example_predictions`` (torch.Tensor): example predictions from the
              validation set collected during training, stacked and moved to the
              CPU; an empty tensor if none were collected.
        :rtype: tuple
        """
        start_time = dt.now()
        print('Time: ', start_time.strftime('%H:%M:%S'))
        print()

        print('Setting things up...')

        session = TrainingSession(model, hyperparams, self.normalizer)
        self.train_loader.set_augmenters(hyperparams['augmenters'])

        helper.print_hyperparameters(hyperparams,
                                     self.config['interest_keys'],
                                     indent=1)
        log, example_predictions, log_iterations = self._initialize_logs()

        print('\n\tChecking initial validation performance:')
        best_val_performance, best_weights = self._initial_performance(session)

        print()
        print('All set, let\'s get started!')
        for epoch in range(self.config['num_epochs']):
            running_loss = 0.0
            session.schedule_learning_rate()
            for i, batch in enumerate(self.train_loader):
                loss, output_batch = session.train_batch(batch)

                if self.config['log_loss']:
                    log['train']['loss'].append(loss)

                if self.config['log_grad']:
                    log['train']['grad'].append(self._sum_gradients(model))

                if i in log_iterations:
                    train_performance, val_performance = self._intermediate_eval(
                        session, output_batch)
                    self._logging(log, loss, train_performance,
                                  val_performance, i)
                    if len(self.config['val_example_indices']) > 0:
                        example_predictions.append(
                            self._example_predictions(session))

                running_loss += loss

            # Training for this epoch finished.
            session.scheduler_metric = running_loss / self.iters_per_epoch

            val_performance = session.test_model(self.val_loader)
            target_performance = Evaluator.means_over_subsets(
                val_performance)['distance']
            # Lower mean distance is better; keep an independent copy of the weights.
            if target_performance < best_val_performance:
                best_val_performance = target_performance
                best_weights = copy.deepcopy(model.state_dict())
                log['best_epoch'] = epoch

            self._print_epoch_end_info(session, val_performance, start_time,
                                       epoch, best_val_performance)

        model.load_state_dict(best_weights)
        final_val_results = self._full_evaluation(model,
                                                  session.params['eval_space'])
        print()
        print('-' * 30)
        print('FINISH')
        print('Final validation errors:')
        Evaluator.print_results(final_val_results)
        print()

        # torch.stack rejects an empty sequence. The list can stay empty when nothing
        # was appended above (e.g. 'val_example_indices' is empty), presumably because
        # _initialize_logs starts it empty; return an empty tensor instead of crashing.
        if len(example_predictions) > 0:
            stacked_predictions = torch.stack(example_predictions).cpu()
        else:
            stacked_predictions = torch.empty(0)

        return log, final_val_results, best_weights, stacked_predictions