def test_get_specs_for_conv6_cifar10(self):
    """ Loading the CONV6/CIFAR-10 preset must succeed and yield the expected net and dataset. """
    specs = get_specs(ExperimentPresetNames.CONV6_CIFAR10)
    self.assertIsInstance(specs, ExperimentSpecs)
    # Check the identifying attributes in a fixed order: net first, then dataset.
    for attribute, expected in (('net', NetNames.CONV), ('dataset', DatasetNames.CIFAR10)):
        self.assertIs(getattr(specs, attribute), expected)
def test_get_specs_for_lenet_mnist(self):
    """ Loading the LeNet/MNIST preset must succeed and yield the expected net and dataset. """
    expected_net = NetNames.LENET
    expected_dataset = DatasetNames.MNIST

    preset_specs = get_specs(ExperimentPresetNames.LENET_MNIST)

    self.assertIsInstance(preset_specs, ExperimentSpecs)
    self.assertIs(preset_specs.net, expected_net)
    self.assertIs(preset_specs.dataset, expected_dataset)
def test_get_specs_for_tiny_cifar10(self):
    """
    Loading the tiny CIFAR-10 preset must succeed and yield the expected
    net, dataset, convolutional plan and fully connected plan.
    """
    specs = get_specs(ExperimentPresetNames.TINY_CIFAR10)
    self.assertIsInstance(specs, ExperimentSpecs)
    self.assertIs(specs.net, NetNames.CONV)
    self.assertIs(specs.dataset, DatasetNames.CIFAR10)

    # Layer plans this preset is expected to define.
    expected_conv_plan = [8, 8, 'M', 8, 'M']
    expected_fc_plan = [300, 100]
    self.assertEqual(expected_conv_plan, specs.plan_conv)
    self.assertEqual(expected_fc_plan, specs.plan_fc)
def _apply_arg_overrides(specs, args):
    """
    Overwrite preset values in ``specs`` with command-line values that pass validation.

    Each ``should_override_arg_*`` helper decides whether the given argument is
    used; the second argument passed to it is a human-readable label. The calls
    are made in a fixed order so that any output the helpers produce (if any)
    keeps its original sequence.

    :param specs: Experiment specs returned by ``get_specs``; mutated in place.
    :param args: Parsed command-line arguments.
    """
    if should_override_arg_positive_int(args.epochs, 'Epoch count'):
        specs.epoch_count = args.epochs
    if should_override_arg_positive_int(args.nets, 'Net count'):
        specs.net_count = args.nets
    if should_override_arg_positive_int(args.prunes, 'Prune count'):
        specs.prune_count = args.prunes
    if should_override_arg_rate(args.learn_rate, 'Learning-rate'):
        specs.learning_rate = args.learn_rate
    if should_override_arg_rate(args.prune_rate_conv, 'Pruning-rate Conv'):
        specs.prune_rate_conv = args.prune_rate_conv
    if should_override_arg_rate(args.prune_rate_fc, 'Pruning-rate FC'):
        specs.prune_rate_fc = args.prune_rate_fc
    if should_override_arg_plan(args.plan_conv, 'Convolutional plan'):
        specs.plan_conv = args.plan_conv
    if should_override_arg_plan(args.plan_fc, 'Fully connected plan'):
        specs.plan_fc = args.plan_fc
    if should_override_arg_positive_int(args.plot_step, 'Plot-step'):
        specs.plot_step = args.plot_step


def setup_pruning(args):
    """
    Setup IMP- or OSP-experiment and print the specs or execute it.

    Loads the preset named by ``args.experiment_preset``, sets up the device
    via ``setup_cuda(args.cuda)``, applies every accepted command-line override
    and then either prints the specs (``args.listing``) or runs the experiment
    selected by ``args.experiment_name``. If the name matches neither
    ``ExperimentNames.IMP`` nor ``ExperimentNames.OSP``, nothing is run.

    :param args: Parsed command-line arguments (presumably an argparse
        ``Namespace`` -- TODO confirm) providing ``verbose``,
        ``experiment_preset``, ``cuda``, ``epochs``, ``nets``, ``prunes``,
        ``learn_rate``, ``prune_rate_conv``, ``prune_rate_fc``, ``plan_conv``,
        ``plan_fc``, ``plot_step``, ``early_stop``, ``experiment_name`` and
        ``listing``.
    :raises AssertionError: If ``args.verbose`` is not a ``VerbosityLevel`` value.
    """
    # Explicit check instead of a bare ``assert``: validation must survive
    # ``python -O``. AssertionError is kept so existing callers/tests still match.
    if args.verbose not in VerbosityLevel.__members__.values():
        raise AssertionError(f"Invalid verbosity level: {args.verbose!r}")

    specs = get_specs(args.experiment_preset)
    specs.device, specs.device_name = setup_cuda(args.cuda)
    log_from_medium(args.verbose, specs.device_name)

    _apply_arg_overrides(specs, args)

    specs.verbosity = VerbosityLevel(args.verbose)
    specs.save_early_stop = args.early_stop
    specs.experiment_name = args.experiment_name

    # Compare with ``==`` (not a dict lookup): an Enum member and a plain value
    # may compare equal while hashing differently.
    if args.listing:
        print(specs)
    elif args.experiment_name == ExperimentNames.IMP:
        ExperimentIMP(specs).run_experiment()
    elif args.experiment_name == ExperimentNames.OSP:
        ExperimentOSP(specs).run_experiment()
def test_get_specs_should_raise_assertion_error_on_invalid_name(self):
    """ get_specs must reject an unknown preset name with an AssertionError. """
    invalid_name = "Invalid experiment name"
    self.assertRaises(AssertionError, get_specs, invalid_name)