def test_meta_schedule_xgb_model_reupdate():
    """Repeatedly updating an XGBModel and then predicting must not crash.

    Feeds three consecutive training batches (each past the warm-up
    threshold) so the underlying booster is re-trained, then runs predict.
    """
    extractor = RandomFeatureExtractor()
    model = XGBModel(extractor=extractor, num_warmup_samples=2)
    update_sample_count = 60
    predict_sample_count = 100
    # Three identical update rounds exercise the re-training path.
    for _ in range(3):
        candidates = [_dummy_candidate() for _ in range(update_sample_count)]
        results = [_dummy_result() for _ in range(update_sample_count)]
        model.update(TuneContext(), candidates, results)
    model.predict(TuneContext(), [_dummy_candidate() for _ in range(predict_sample_count)])
def test_meta_schedule_xgb_model_reload():
    """A save/load round-trip must preserve the model's data and predictions."""
    extractor = RandomFeatureExtractor()
    model = XGBModel(extractor=extractor, num_warmup_samples=10)
    update_sample_count = 20
    predict_sample_count = 30
    candidates = [_dummy_candidate() for _ in range(update_sample_count)]
    results = [_dummy_result() for _ in range(update_sample_count)]
    model.update(TuneContext(), candidates, results)
    model.predict(TuneContext(), [_dummy_candidate() for _ in range(predict_sample_count)])
    with tempfile.NamedTemporaryFile() as path:
        # Snapshot state before the round-trip.
        random_state = model.extractor.random_state  # save feature extractor's random state
        old_data = model.data
        old_data_size = model.data_size
        model.save(path.name)
        res1 = model.predict(
            TuneContext(), [_dummy_candidate() for _ in range(predict_sample_count)]
        )
        # Restore the extractor's RNG so both predictions see identical features,
        # then reload the model from disk.
        model.extractor.random_state = random_state  # load feature extractor's random state
        model.load(path.name)
        new_data = model.data
        new_data_size = model.data_size
        res2 = model.predict(
            TuneContext(), [_dummy_candidate() for _ in range(predict_sample_count)]
        )
    # Predictions and the stored training data must be unchanged by the reload.
    assert (res1 == res2).all()
    assert old_data_size == new_data_size
    assert len(old_data) == len(new_data)
    for (key_a, group_a), (key_b, group_b) in zip(old_data.items(), new_data.items()):
        assert key_a == key_b == group_a.group_hash == group_b.group_hash
        assert (group_a.costs == group_b.costs).all()
        assert len(group_a.features) == len(group_b.features)
        for feat_a, feat_b in zip(group_a.features, group_b.features):
            assert (feat_a == feat_b).all()
# Example no. 3 — an alternate implementation of the reload test follows.
# (The stray "Ejemplo n.º 3" / "0" lines were scraping artifacts, not code.)
# NOTE(review): this function name duplicates the test defined earlier in this
# file; Python keeps only the later definition, so the earlier one never runs.
# Confirm which version is intended and rename one of them.
def test_meta_schedule_xgb_model_reload():
    """A save/load round-trip must preserve cached features, cached mean
    costs, and predictions of an XGBModel.

    Saves the model, clears its caches, reloads it, and asserts that both
    the prediction results and the caches are byte-for-byte restored.
    """
    extractor = RandomFeatureExtractor()
    model = XGBModel(extractor=extractor, num_warmup_samples=10)
    update_sample_count = 20
    predict_sample_count = 30
    model.update(
        TuneContext(),
        [_dummy_candidate() for i in range(update_sample_count)],
        [_dummy_result() for i in range(update_sample_count)],
    )
    model.predict(TuneContext(),
                  [_dummy_candidate() for i in range(predict_sample_count)])
    random_state = model.extractor.random_state  # save feature extractor's random state
    # TemporaryDirectory guarantees cleanup even if save/load/predict raises;
    # the original mkdtemp + rmtree pair leaked the directory on any exception.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "test_output_meta_schedule_xgb_model.bin")
        cached = (model.cached_features.copy(), model.cached_mean_costs.copy())
        model.save(path)
        res1 = model.predict(
            TuneContext(),
            [_dummy_candidate() for i in range(predict_sample_count)])
        model.extractor.random_state = random_state  # load feature extractor's random state
        # Wipe the caches so the assertions below prove load() restores them.
        model.cached_features = None
        model.cached_mean_costs = None
        model.load(path)
        new_cached = (model.cached_features.copy(), model.cached_mean_costs.copy())
        res2 = model.predict(
            TuneContext(),
            [_dummy_candidate() for i in range(predict_sample_count)])
    assert (res1 == res2).all()
    # cached feature does not change
    assert len(cached[0]) == len(new_cached[0])
    for i in range(len(cached[0])):
        assert (cached[0][i] == new_cached[0][i]).all()
    # cached mean cost does not change
    assert (cached[1] == new_cached[1]).all()