def test_save_load_buffer():
    """A MetaLearnController round-trips through an in-memory bytes buffer.

    The controller is written into an ``io.BytesIO`` object, the buffer is
    rewound, and the controller read back from it must compare equal to the
    one that was written.
    """
    saved = _metalearn_controller(_a_space())
    with io.BytesIO() as buffer:
        saved.save(buffer)
        # Rewind so load() starts reading from the beginning of the data.
        buffer.seek(0)
        loaded = MetaLearnController.load(buffer)
    assert utils.models_are_equal(loaded, saved)
def test_save_load_tempfile():
    """A MetaLearnController round-trips through a real temporary file object.

    The controller is saved into an open, anonymous OS-level temporary file
    (``tempfile.TemporaryFile``), i.e. a file object rather than a filesystem
    path; the file is rewound and the controller loaded back from it must
    compare equal to the one that was saved.
    """
    w_controller = _metalearn_controller(_a_space())
    with tempfile.TemporaryFile() as f:
        w_controller.save(f)
        # Rewind so load() reads what save() just wrote.
        f.seek(0)
        r_controller = MetaLearnController.load(f)
    assert utils.models_are_equal(r_controller, w_controller)
from pathlib import Path

from metalearn.metalearn_controller import MetaLearnController
from metalearn.inference.inference_engine import CASHInference
from metalearn.task_environment import TaskEnvironment
from metalearn.data_environments import sklearn_classification

# Number of best-scoring MLFs to keep for each data environment.
TOP_K_MLFS = 10

# Directory holding the trained controller, its experiment log and the
# pickled best MLF of each episode.
build_path = Path(__file__).parent / ".." / "floyd_outputs" / "225"
output_path = (
    Path(__file__).parent
    / "results"
    / "autosklearn_benchmark_pretrained_sklearn_agent"
)
results_path = output_path / "inference_results"
# parents=True: the intermediate "results/..." directories may not exist yet,
# and mkdir without it raises FileNotFoundError in that case.
results_path.mkdir(parents=True, exist_ok=True)

controller = MetaLearnController.load(build_path / "controller_trial_0.pt")
experiment_results = pd.read_csv(
    build_path / "rnn_metalearn_controller_experiment.csv")
base_mlf_path = build_path / "metalearn_controller_mlfs_trial_0"

# Episodes that produced the TOP_K_MLFS best validation scores for each data
# env across all episodes, as a Series indexed by data env name. A stable
# sort followed by groupby().head() keeps rows best-first within each group
# without relying on how groupby().apply() inserts group keys.
best_mlf_episodes = (
    experiment_results
    .sort_values("best_validation_scores", ascending=False, kind="mergesort")
    .groupby("data_env_names")
    .head(TOP_K_MLFS)
    .set_index("data_env_names")["episode"]
)

# A dict mapping each data env name to the list of its top MLFs (best first),
# each loaded from its per-episode pickle file.
best_mlfs = (
    best_mlf_episodes
    .map(lambda episode: joblib.load(
        base_mlf_path / ("best_mlf_episode_%d.pkl" % episode)))
    .groupby(level="data_env_names")
    .apply(list)
    .to_dict()
)

sklearn_data_envs = sklearn_classification.envs()
def load_controller():
    """Load the MetaLearnController stored as ``controller_trial_0.pt``.

    The file is looked up under the module-level ``MODEL_PATH`` directory.
    """
    controller_file = MODEL_PATH / "controller_trial_0.pt"
    return MetaLearnController.load(controller_file)