コード例 #1
0
    def test_execute_training_with_early_stopping(self):
        """ Should execute training without errors and save results with correct shapes.
        It should also return a valid checkpoint and early-stopping index.
        Use a simple net with one linear layer and fake-data_loaders with samples of shape (1,4). """
        # create net and setup trainer with fake-DataLoader (use the same loader for training, validation and test)
        net = generate_single_layer_net()
        fake_loader = generate_fake_data_loader()
        trainer = TrainerAdam(0., fake_loader, fake_loader, fake_loader, save_early_stop=True)

        # perform training with mocked validation-loss; each early-stop evaluation consumes one side_effect value
        with mock.patch('training.trainer.TrainerAdam.compute_val_loss',
                        side_effect=[2.0, 1.0, 0.5, 1.0]) as mocked_val_loss:
            net, train_loss_hist, val_loss_hist, val_acc_hist, test_acc_hist, early_stop_index, early_stop_cp = \
                trainer.train_net(net, epoch_count=2, plot_step=2)  # 8 batches (iterations), 4 early-stop evaluations
            self.assertEqual(4, mocked_val_loss.call_count)

        # the mocked loss decreases for the first three evaluations (2.0, 1.0, 0.5) and rises at the fourth,
        # so the early-stopping criterion holds for the first three calls and the last of them counts:
        # thus 5 is the expected 'early_stop_index'
        self.assertEqual(5, early_stop_index)
        net.load_state_dict(early_stop_cp)  # check if the checkpoint can be loaded without errors

        self.assertIsNotNone(net)
        # assert_array_less also checks that each history has exactly 4 entries (one per evaluation)
        np.testing.assert_array_less(np.zeros(4, dtype=float), train_loss_hist)  # check for positive entries
        np.testing.assert_array_less(np.zeros(4, dtype=float), val_loss_hist)
        np.testing.assert_array_less(np.zeros(4, dtype=float), val_acc_hist)
        np.testing.assert_array_less(np.zeros(4, dtype=float), test_acc_hist)
コード例 #2
0
    def test_compute_val_loss(self):
        """ Should calculate a non-negative validation loss. """
        # create net and setup trainer with fake-DataLoader (use the same loader for training, validation and test)
        net = generate_single_layer_net()
        fake_loader = generate_fake_data_loader()
        trainer = TrainerAdam(0., fake_loader, fake_loader, fake_loader)

        # the loss may be zero in principle, so only non-negativity is asserted
        self.assertLessEqual(0.0, trainer.compute_val_loss(net))
コード例 #3
0
    def test_compute_val_acc(self):
        """ Should return a validation-accuracy of 1.0.
        The single-layer fake-net labels every fake-sample correctly, which is checked
        on a validation-loader holding exactly one batch. """
        net = generate_single_layer_net()

        # one batch: three samples of shape (4,) together with their expected labels
        samples = torch.tensor([[2., 2., 2., 2.], [2., 2., 0., 0.], [0., 0., 2., 2.]])
        labels = torch.tensor([0, 0, 1])
        single_batch_loader = [[samples, labels]]

        # training- and test-loaders are not needed for this check and stay empty
        trainer = TrainerAdam(0., [], single_batch_loader, test_loader=[])

        accuracy = trainer.compute_acc(net, test=False)
        self.assertEqual(1., accuracy)
コード例 #4
0
    def test_compute_test_acc(self):
        """ Should calculate the correct test-accuracy.
        The fake-net with one linear layer classifies half of the fake-samples correctly.
        Use a fake-test_loader with two batches to validate the result: both batches contain the same samples,
        the first one with the labels the net predicts and the second one with inverted labels. """
        # create net and setup trainer and fake-DataLoader with two batches (use the same samples for both batches)
        net = generate_single_layer_net()
        samples = torch.tensor([[2., 2., 2., 2.], [2., 2., 0., 0.], [0., 0., 2., 2.]])
        test_loader = [[samples, torch.tensor([0, 0, 1])], [samples, torch.tensor([1, 1, 0])]]
        trainer = TrainerAdam(0., [], [], test_loader=test_loader)

        # 3 of 6 samples are classified correctly
        self.assertEqual(0.5, trainer.compute_acc(net, test=True))
コード例 #5
0
    def test_execute_training(self):
        """ The training should be executed without errors and results should have correct shapes.
        Use a simple net with one linear layer and fake-data_loaders with samples of shape (1,4). """
        # create net and setup trainer with fake-DataLoader (use the same loader for training, validation and test)
        net = generate_single_layer_net()
        fake_loader = generate_fake_data_loader()
        trainer = TrainerAdam(0., fake_loader, fake_loader, fake_loader)

        net, train_loss_hist, val_loss_hist, val_acc_hist, test_acc_hist, _, _ = \
            trainer.train_net(net, epoch_count=2, plot_step=4)

        self.assertIsNotNone(net)
        # assert_array_less also checks that each history has exactly 2 entries
        np.testing.assert_array_less(np.zeros(2, dtype=float), train_loss_hist)  # check for positive entries
        np.testing.assert_array_less(np.zeros(2, dtype=float), val_loss_hist)
        np.testing.assert_array_less(np.zeros(2, dtype=float), val_acc_hist)
        np.testing.assert_array_less(np.zeros(2, dtype=float), test_acc_hist)
コード例 #6
0
    def load_data_and_setup_trainer(self):
        """ Load the dataset selected in 'self.specs.dataset' and build the Adam-trainer on top of it.
        The number of batches in the training-loader is kept in 'self.epoch_length' so histories can be sized.
        Raises an AssertionError for an unknown dataset name. """
        # pick the loader factory matching the configured dataset
        dataset = self.specs.dataset
        if dataset == DatasetNames.MNIST:
            loader_factory = get_mnist_data_loaders
        elif dataset == DatasetNames.CIFAR10:
            loader_factory = get_cifar10_data_loaders
        else:
            raise AssertionError(f"Could not load datasets, because the given name {self.specs.dataset} is invalid.")

        train_ld, val_ld, test_ld = loader_factory(device=self.specs.device, verb=self.specs.verbosity)

        self.epoch_length = len(train_ld)
        self.trainer = TrainerAdam(self.specs.learning_rate, train_ld, val_ld, test_ld, self.device,
                                   self.specs.save_early_stop, self.specs.verbosity)
コード例 #7
0
 def test_should_save_early_stop_checkpoint_new_checkpoint(self):
     """ Should return True, because the validation-loss (0.1) reached a new minimum (below 0.2). """
     trainer = TrainerAdam(0., [], [], [], save_early_stop=True)
     self.assertIs(trainer.should_save_early_stop_checkpoint(0.1, 0.2), True)
コード例 #8
0
 def test_should_save_early_stop_checkpoint_no_new_minimum_equal(self):
     """ Should return False, because the current validation-loss is equal to the minimum (no strict improvement). """
     trainer = TrainerAdam(0., [], [], [], save_early_stop=True)
     self.assertIs(trainer.should_save_early_stop_checkpoint(0.2, 0.2), False)
コード例 #9
0
 def test_should_save_early_stop_checkpoint_no_evaluation(self):
     """ Should return False, because the early-stopping criterion should not be evaluated.
     The validation-loss (0.1) is a new minimum (below 0.2), so only the disabled 'save_early_stop' flag
     can make the result False. """
     trainer = TrainerAdam(0., [], [], [], save_early_stop=False)
     self.assertIs(trainer.should_save_early_stop_checkpoint(0.1, 0.2), False)