def test_cache_files(self): auto_var = AutoVar(logging_level=logging.INFO) auto_var.add_variable_class(DatasetVarClass()) auto_var.add_variable_class(OrdVarClass()) auto_var.set_variable_value("dataset", "no4_halfmoon_5") auto_var.set_variable_value("ord", "1") X, y = auto_var.get_var("dataset") cacheX, cachey = auto_var.get_var("dataset") assert_array_equal(X, cacheX) assert_array_equal(y, cachey) temp_dir = auto_var.variables['dataset']['cache_dirs'][ 'no4_halfmoon_(?P<n_samples>\\d+)'] cacheX, cachey = joblib.load( os.path.join(temp_dir, "no4_halfmoon_5-1.pkl")) assert_array_equal(X, cacheX) assert_array_equal(y, cachey)
def test_json_file(self): settings = {'file_format': 'json', 'result_file_dir': 'test'} mkdir_p(settings['result_file_dir']) auto_var = AutoVar( settings=settings, after_experiment_hooks=[ partial(save_result_to_file, get_name_fn=default_get_file_name) ], ) auto_var.add_variable_class(OrdVarClass()) auto_var.set_variable_value_by_dict({'ord': '1'}) def experiment(auto_var): return {'test': auto_var.get_var('ord')} _ = auto_var.run_single_experiment(experiment, with_hook=True) with open("test/1.json", 'r') as f: ret = json.load(f) self.assertEqual(ret['test'], auto_var.get_var('ord')) shutil.rmtree(settings['result_file_dir'])
def test_val(self): auto_var = AutoVar(logging_level=logging.INFO) with self.assertRaises(VariableNotRegisteredError): auto_var.get_var('ord') auto_var.add_variable_class(OrdVarClass()) auto_var.add_variable_class(DatasetVarClass()) auto_var.add_variable('random_seed', int) with self.assertRaises(VariableValueNotSetError): auto_var.get_var('ord') auto_var.set_variable_value_by_dict({ 'ord': '1', 'dataset': 'halfmoon_200', 'random_seed': 1126 }) self.assertEqual(auto_var.get_var('ord'), 1) self.assertEqual(auto_var.get_var('random_seed'), 1126) self.assertEqual(len(auto_var.get_var('dataset')[0]), 200) self.assertEqual( len(auto_var.get_var_with_argument('dataset', 'halfmoon_300')[0]), 300) with self.assertRaises(ValueError): auto_var.set_variable_value_by_dict({'ord': 'l2'}) with self.assertRaises(TypeError): auto_var.set_variable_value_by_dict({'random_seed': '1126.0'}) self.assertEqual(auto_var.get_var_shown_name(var_name="dataset"), 'shown_halfmoon') assert_array_equal( auto_var.get_var_with_argument('dataset', 'halfmoon_300')[0], auto_var.get_var_with_argument('dataset', 'moon_300')[0], ) argparse_help = auto_var.get_argparser().format_help() self.assertTrue('halfmoon dataset' in argparse_help) self.assertTrue('Dataset variable class' in argparse_help)
def test_cmd_args_with_vars_2(self): auto_var = AutoVar(logging_level=logging.INFO) auto_var.add_variable_class(OrdVarClass()) auto_var.parse_argparse(args=[]) self.assertEqual(auto_var.get_var("ord"), 2)