Example 1
0
def test_gbt_estimator() -> None:
    """Run a shortened gbt_titanic_estimator example as a basic test.

    The example's ``const.yaml`` is loaded, its max length is overridden to
    200 batches, and the result is handed to
    ``exp.run_basic_test_with_temp_config`` together with the example
    directory. The trailing ``1`` is forwarded unchanged (presumably the
    expected number of trials -- confirm against ``exp``).
    """
    const_yaml = conf.decision_trees_examples_path("gbt_titanic_estimator/const.yaml")
    loaded = conf.load_config(const_yaml)
    shortened = conf.set_max_length(loaded, {"batches": 200})

    example_dir = conf.decision_trees_examples_path("gbt_titanic_estimator")
    exp.run_basic_test_with_temp_config(shortened, example_dir, 1)
Example 2
0
def test_gbt_titanic_estimator_accuracy() -> None:
    """Check that the gbt_titanic_estimator example reaches a target accuracy.

    Runs the example with its unmodified ``const.yaml``, gathers every
    validation accuracy reported by the experiment's first trial, and asserts
    that the best one exceeds ``target_accuracy``.
    """
    config = conf.load_config(
        conf.decision_trees_examples_path("gbt_titanic_estimator/const.yaml"))
    experiment_id = exp.run_basic_test_with_temp_config(
        config, conf.decision_trees_examples_path("gbt_titanic_estimator"), 1)

    trials = exp.experiment_trials(experiment_id)
    # Fail with a clear message instead of a bare IndexError on trials[0].
    assert trials, "experiment {} produced no trials.".format(experiment_id)

    # Only the first trial's metrics are inspected.
    trial_metrics = exp.trial_metrics(trials[0]["id"])

    # Steps without a (truthy) validation entry are skipped.
    validation_accuracies = [
        step["validation"]["metrics"]["validation_metrics"]["accuracy"]
        for step in trial_metrics["steps"] if step.get("validation")
    ]

    # max() on an empty list raises ValueError, which would hide why the test
    # failed; assert explicitly so the failure message is informative.
    assert validation_accuracies, (
        "gbt_titanic_estimator reported no validation metrics in {} steps.".format(
            len(trial_metrics["steps"])))

    target_accuracy = 0.74
    assert max(validation_accuracies) > target_accuracy, (
        "gbt_titanic_estimator did not reach minimum target accuracy {} in {} steps."
        " full validation accuracy history: {}".format(
            target_accuracy, len(trial_metrics["steps"]),
            validation_accuracies))