def test_meta_schedule_xgb_model():
    """Smoke-test XGBModel: train on a few dummy samples, then predict on more.

    Uses a ``RandomFeatureExtractor`` and only two warm-up samples. The test
    passes as long as ``update`` and ``predict`` run without raising.
    """
    n_update = 10
    n_predict = 100
    model = XGBModel(extractor=RandomFeatureExtractor(), num_warmup_samples=2)

    # Build the training batch: one dummy candidate paired with one dummy result.
    update_ctx = TuneContext()
    train_candidates = [_dummy_candidate() for _ in range(n_update)]
    train_results = [_dummy_result() for _ in range(n_update)]
    model.update(update_ctx, train_candidates, train_results)

    predict_ctx = TuneContext()
    query_candidates = [_dummy_candidate() for _ in range(n_predict)]
    model.predict(predict_ctx, query_candidates)
def test_meta_schedule_xgb_model_reload():
    """Check that saving and reloading an XGBModel round-trips its state.

    The model is trained on random features and saved to a temporary file,
    then reloaded. Predictions made after saving and after loading must match.
    ``model.data_size`` and the per-group training data in ``model.data``
    (keys, group hashes, costs and features) must also be unchanged.
    """
    extractor = RandomFeatureExtractor()
    model = XGBModel(extractor=extractor, num_warmup_samples=10)
    update_sample_count = 20
    predict_sample_count = 30
    model.update(
        TuneContext(),
        [_dummy_candidate() for i in range(update_sample_count)],
        [_dummy_result() for i in range(update_sample_count)],
    )
    model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])
    # Save into a file inside a temporary directory instead of reusing the name
    # of an open NamedTemporaryFile: re-opening a still-open NamedTemporaryFile
    # by name is not supported on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test_output_meta_schedule_xgb_model.bin")
        # Backup
        random_state = model.extractor.random_state  # save feature extractor's random state
        old_data = model.data
        old_data_size = model.data_size
        model.save(path)
        res1 = model.predict(
            TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]
        )
        # Load
        # Restore the saved random state so the second predict() can be
        # compared element-wise with res1.
        model.extractor.random_state = random_state
        model.load(path)
        new_data = model.data
        new_data_size = model.data_size
        res2 = model.predict(
            TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]
        )

    assert (res1 == res2).all()
    assert old_data_size == new_data_size
    assert len(old_data) == len(new_data)
    for (old_key, old_group), (new_key, new_group) in zip(old_data.items(), new_data.items()):
        assert old_key == new_key
        # Each dict key is expected to equal its group's hash.
        assert old_key == old_group.group_hash
        assert new_key == new_group.group_hash
        assert (old_group.costs == new_group.costs).all()
        assert len(old_group.features) == len(new_group.features)
        for old_feature, new_feature in zip(old_group.features, new_group.features):
            assert (old_feature == new_feature).all()
def test_meta_schedule_xgb_model_reload():
    """Check that saving and reloading an XGBModel preserves its cached data.

    The model is trained and saved. Its ``cached_features`` and
    ``cached_mean_costs`` are then cleared and the model is reloaded from disk.
    Predictions and both cached arrays must match their values before the reload.
    """
    # NOTE(review): an earlier function in this file has the same name. Python
    # keeps only this later definition, so pytest never collects the earlier
    # one. Rename or remove one copy. TODO confirm which XGBModel API
    # (``data``/``data_size`` vs ``cached_features``/``cached_mean_costs``)
    # is current.
    extractor = RandomFeatureExtractor()
    model = XGBModel(extractor=extractor, num_warmup_samples=10)
    update_sample_count = 20
    predict_sample_count = 30
    model.update(
        TuneContext(),
        [_dummy_candidate() for i in range(update_sample_count)],
        [_dummy_result() for i in range(update_sample_count)],
    )
    model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)])
    random_state = model.extractor.random_state  # save feature extractor's random state
    # TemporaryDirectory removes the directory even if save/predict/load raise;
    # the previous mkdtemp + rmtree only cleaned up on the success path.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test_output_meta_schedule_xgb_model.bin")
        cached = (model.cached_features.copy(), model.cached_mean_costs.copy())
        model.save(path)
        res1 = model.predict(
            TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]
        )
        # Restore the saved random state so the second predict() can be
        # compared element-wise with res1.
        model.extractor.random_state = random_state
        # Clear the caches so the values checked below must come from load().
        model.cached_features = None
        model.cached_mean_costs = None
        model.load(path)
        new_cached = (model.cached_features.copy(), model.cached_mean_costs.copy())
        res2 = model.predict(
            TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]
        )

    assert (res1 == res2).all()
    old_features, old_mean_costs = cached
    new_features, new_mean_costs = new_cached
    # cached features do not change
    assert len(old_features) == len(new_features)
    for old_feature, new_feature in zip(old_features, new_features):
        assert (old_feature == new_feature).all()
    # cached mean costs do not change
    assert (old_mean_costs == new_mean_costs).all()