Пример #1
0
 def test_should_start_experiment_with_early_stop(self):
     """ Passing '-es' should hand the IMP experiment specs with save_early_stop enabled. """
     cli_args = [ExperimentNames.IMP, ExperimentPresetNames.LENET_MNIST, '-es']
     specs_expected = get_specs_lenet_mnist()
     specs_expected.save_early_stop = True

     with mock.patch('playground.ExperimentIMP') as experiment_cls:
         playground_main(cli_args)
         experiment_cls.assert_called_once_with(specs_expected)
Пример #2
0
 def test_should_start_experiment_with_modified_plot_step_parameter(self):
     """ Passing '-ps 42' should hand the IMP experiment specs whose plot_step is 42. """
     cli_args = [ExperimentNames.IMP, ExperimentPresetNames.LENET_MNIST, '-ps', '42']
     specs_expected = get_specs_lenet_mnist()
     specs_expected.plot_step = 42

     with mock.patch('playground.ExperimentIMP') as experiment_cls:
         playground_main(cli_args)
         experiment_cls.assert_called_once_with(specs_expected)
Пример #3
0
 def test_should_start_experiment_with_modified_prune_rate_fc_parameter(self):
     """ Passing '-prf 0.5' should hand the IMP experiment specs whose prune_rate_fc is 0.5. """
     cli_args = [ExperimentNames.IMP, ExperimentPresetNames.LENET_MNIST, '-prf', '0.5']
     specs_expected = get_specs_lenet_mnist()
     specs_expected.prune_rate_fc = 0.5

     with mock.patch('playground.ExperimentIMP') as experiment_cls:
         playground_main(cli_args)
         experiment_cls.assert_called_once_with(specs_expected)
Пример #4
0
 def test_should_start_experiment_osp(self):
     """ Selecting the OSP experiment should construct it once with the LeNet/MNIST preset renamed to OSP. """
     cli_args = [ExperimentNames.OSP, ExperimentPresetNames.LENET_MNIST]
     specs_expected = get_specs_lenet_mnist()
     specs_expected.experiment_name = ExperimentNames.OSP

     with mock.patch('playground.ExperimentOSP') as experiment_cls:
         playground_main(cli_args)
         experiment_cls.assert_called_once_with(specs_expected)
Пример #5
0
    def test_should_start_experiment_with_detailed_logging(self):
        """ Playground should start the experiment with detailed logging.

        With '-vv' the IMP experiment must receive specs whose verbosity is
        VerbosityLevel.DETAILED, and the only stdout output is expected to be
        the line "cpu".
        """
        expected_specs = get_specs_lenet_mnist()
        expected_specs.verbosity = VerbosityLevel.DETAILED
        with mock.patch('playground.ExperimentIMP') as mocked_experiment:
            # mock.patch restores sys.stdout even if playground_main raises;
            # a manual save/restore would leave stdout bound to a closed
            # StringIO and break output for every later test.
            with mock.patch('sys.stdout', new_callable=StringIO) as interception:
                playground_main([ExperimentNames.IMP, ExperimentPresetNames.LENET_MNIST, '-vv'])

            # NOTE(review): "cpu" looks like the selected torch device; this
            # presumably assumes a CPU-only test machine — confirm.
            self.assertEqual("cpu\n", interception.getvalue())
            mocked_experiment.assert_called_once_with(expected_specs)
Пример #6
0
    def test_should_print_experiment_specs_osp(self):
        """ Playground should not start the OSP-experiment and print the specs.

        With '-l' the specs for the OSP experiment are printed to stdout
        (followed by a newline) and ExperimentOSP is never constructed.
        """
        expected_specs = get_specs_lenet_mnist()
        expected_specs.experiment_name = ExperimentNames.OSP
        with mock.patch('playground.ExperimentOSP') as mocked_experiment:
            # mock.patch restores sys.stdout even if playground_main raises;
            # a manual save/restore would leave stdout bound to a closed
            # StringIO and break output for every later test.
            with mock.patch('sys.stdout', new_callable=StringIO) as interception:
                playground_main([ExperimentNames.OSP, ExperimentPresetNames.LENET_MNIST, '-l'])

            self.assertEqual(f"{expected_specs}\n", interception.getvalue())
            mocked_experiment.assert_not_called()
Пример #7
0
    def test_generate_randomly_reinitialized_net(self):
        """ A re-initialized net keeps the source net's masks but gets different weights.

        Checked on the first fully-connected layer and on the output layer.
        """
        specs = experiment_specs.get_specs_lenet_mnist()
        specs.save_early_stop = True

        # Different seeds for the source and the re-initialized net, so their
        # weights are expected to differ.
        torch.manual_seed(0)
        source_net = Net(specs.net, specs.dataset, specs.plan_conv, specs.plan_fc)
        torch.manual_seed(1)
        fresh_net = ExperimentRandomRetrain.generate_randomly_reinitialized_net(
            specs, source_net.state_dict())

        def all_equal(lhs, rhs):
            # Element-wise equality collapsed into a single Python bool.
            return lhs.eq(rhs).all().item()

        layer_pairs = [
            (source_net.fc[0], fresh_net.fc[0]),
            (source_net.out, fresh_net.out),
        ]
        for before, after in layer_pairs:
            self.assertIs(all_equal(before.weight, after.weight), False)
            self.assertIs(all_equal(before.weight_mask, after.weight_mask), True)