Beispiel #1
0
    def test_get_specs_for_conv6_cifar10(self):
        """ The CONV6_CIFAR10 preset should resolve to ExperimentSpecs describing a Conv-net on CIFAR-10. """
        preset = ExperimentPresetNames.CONV6_CIFAR10
        specs = get_specs(preset)

        # Type first, then the two attributes that identify the preset.
        self.assertIsInstance(specs, ExperimentSpecs)
        self.assertIs(specs.net, NetNames.CONV)
        self.assertIs(specs.dataset, DatasetNames.CIFAR10)
Beispiel #2
0
    def test_get_specs_for_lenet_mnist(self):
        """ The LENET_MNIST preset should resolve to ExperimentSpecs describing a Lenet on MNIST. """
        preset = ExperimentPresetNames.LENET_MNIST
        specs = get_specs(preset)

        # Type first, then the two attributes that identify the preset.
        self.assertIsInstance(specs, ExperimentSpecs)
        self.assertIs(specs.net, NetNames.LENET)
        self.assertIs(specs.dataset, DatasetNames.MNIST)
Beispiel #3
0
    def test_get_specs_for_tiny_cifar10(self):
        """ The TINY_CIFAR10 preset should resolve to ExperimentSpecs with the expected net, dataset and plans. """
        preset = ExperimentPresetNames.TINY_CIFAR10
        specs = get_specs(preset)

        # Type first, then the attributes that identify the preset.
        self.assertIsInstance(specs, ExperimentSpecs)
        self.assertIs(specs.net, NetNames.CONV)
        self.assertIs(specs.dataset, DatasetNames.CIFAR10)

        # Architecture plans are compared by value against the expected lists.
        expected_plan_conv = [8, 8, 'M', 8, 'M']
        expected_plan_fc = [300, 100]
        self.assertEqual(expected_plan_conv, specs.plan_conv)
        self.assertEqual(expected_plan_fc, specs.plan_fc)
Beispiel #4
0
def setup_pruning(args):
    """
    Build the specs for an IMP- or OSP-experiment from `args`, then either print them or run the experiment.

    The specs are loaded from the preset named by `args.experiment_preset`.
    Individual fields are replaced by the corresponding values from `args` whenever the matching
    `should_override_arg_*` check accepts them.
    If `args.listing` is truthy the specs are printed and nothing is executed.
    Otherwise the experiment selected by `args.experiment_name` is run; any name other than
    `ExperimentNames.IMP` or `ExperimentNames.OSP` results in no action.
    Asserts that `args.verbose` is one of the `VerbosityLevel` member values.
    """
    assert args.verbose in VerbosityLevel.__members__.values()

    specs = get_specs(args.experiment_preset)

    device, device_name = setup_cuda(args.cuda)
    specs.device = device
    specs.device_name = device_name
    log_from_medium(args.verbose, specs.device_name)

    # Each entry: (check, attribute on args, label handed to the check, attribute on specs).
    # The order is significant, as the checks are invoked one after another in this sequence.
    overrides = (
        (should_override_arg_positive_int, 'epochs', 'Epoch count', 'epoch_count'),
        (should_override_arg_positive_int, 'nets', 'Net count', 'net_count'),
        (should_override_arg_positive_int, 'prunes', 'Prune count', 'prune_count'),
        (should_override_arg_rate, 'learn_rate', 'Learning-rate', 'learning_rate'),
        (should_override_arg_rate, 'prune_rate_conv', 'Pruning-rate Conv', 'prune_rate_conv'),
        (should_override_arg_rate, 'prune_rate_fc', 'Pruning-rate FC', 'prune_rate_fc'),
        (should_override_arg_plan, 'plan_conv', 'Convolutional plan', 'plan_conv'),
        (should_override_arg_plan, 'plan_fc', 'Fully connected plan', 'plan_fc'),
        (should_override_arg_positive_int, 'plot_step', 'Plot-step', 'plot_step'),
    )
    for should_override, arg_attr, label, spec_attr in overrides:
        candidate = getattr(args, arg_attr)
        if should_override(candidate, label):
            setattr(specs, spec_attr, candidate)

    # These fields are always taken from the arguments, without a check.
    specs.verbosity = VerbosityLevel(args.verbose)
    specs.save_early_stop = args.early_stop
    specs.experiment_name = args.experiment_name

    if args.listing:
        print(specs)
        return

    if args.experiment_name == ExperimentNames.IMP:
        ExperimentIMP(specs).run_experiment()
    elif args.experiment_name == ExperimentNames.OSP:
        ExperimentOSP(specs).run_experiment()
Beispiel #5
0
 def test_get_specs_should_raise_assertion_error_on_invalid_name(self):
     """ Passing a string that is no valid preset name to get_specs should raise an AssertionError. """
     invalid_name = "Invalid experiment name"
     self.assertRaises(AssertionError, get_specs, invalid_name)