def test_multitrial_experiment(pytestconfig):
    """Run a one-trial random-search Retiarii experiment locally and check it.

    Builds an experiment from ``Net()`` and ``get_mnist_evaluator()`` with a
    ``strategy.Random()`` search strategy, runs a single trial on the
    ``'local'`` training service, checks that it succeeded, and checks that
    the first exported top model is a ``dict``.

    The experiment is always stopped in ``finally`` so a failing run or
    assertion does not leave it running for subsequent tests.
    """
    base_model = Net()
    evaluator = get_mnist_evaluator()
    search_strategy = strategy.Random()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 1
    # Trial command parameters are derived from the pytest root directory.
    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)

    try:
        exp.run(exp_config)
        ensure_success(exp)
        assert isinstance(exp.export_top_models()[0], dict)
    finally:
        # Previously skipped whenever run/ensure_success/assert raised.
        exp.stop()
def test_multi_trial(model, pytestconfig):
    """Run a one-trial, one-epoch random-search Retiarii experiment on ``model``.

    ``_mnist_net(model, evaluator_kwargs)`` supplies the base model and the
    evaluator. ``model`` is presumably a pytest fixture/parametrization that
    selects the network variant (TODO confirm). The experiment runs a single
    trial on the ``'local'`` training service under the name
    ``'mnist_unittest'``. The test then checks that the run succeeded and
    that the first exported top model is a ``dict``.

    The experiment is always stopped in ``finally`` so a failing run or
    assertion does not leave it running for subsequent tests.
    """
    # Keep the trial short: a single training epoch.
    evaluator_kwargs = {
        'max_epochs': 1
    }
    base_model, evaluator = _mnist_net(model, evaluator_kwargs)

    search_strategy = strategy.Random()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_unittest'
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 1
    # Trial command parameters are derived from the pytest root directory.
    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)

    try:
        exp.run(exp_config)
        ensure_success(exp)
        assert isinstance(exp.export_top_models()[0], dict)
    finally:
        # Previously skipped whenever run/ensure_success/assert raised.
        exp.stop()