Exemplo n.º 1
0
def test_hyperparameter_tuning_objective():
    """A tuning objective must survive a to_flyte_idl / from_flyte_idl round trip unchanged."""
    original = hpo_job.HyperparameterTuningObjective(
        objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE,
        metric_name="test_metric",
    )

    # Serialize to the IDL representation, then rebuild a model object from it.
    idl_message = original.to_flyte_idl()
    restored = hpo_job.HyperparameterTuningObjective.from_flyte_idl(idl_message)

    assert original == restored
Exemplo n.º 2
0
def test_hyperparameter_job_config():
    """A fully populated tuning job config (including parameter ranges) must round-trip through IDL.

    Each checked field of the rebuilt config is compared against the config it was
    serialized from.
    """
    # One categorical and one linearly scaled integer range.
    ranges = parameter_ranges.ParameterRanges(
        parameter_range_map={
            "a": parameter_ranges.CategoricalParameterRange(values=["1", "2"]),
            "b": parameter_ranges.IntegerParameterRange(
                min_value=0,
                max_value=10,
                scaling_type=parameter_ranges.HyperparameterScalingType.LINEAR,
            ),
        }
    )
    objective = hpo_job.HyperparameterTuningObjective(
        objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE,
        metric_name="test_metric",
    )
    config = hpo_job.HyperparameterTuningJobConfig(
        hyperparameter_ranges=ranges,
        tuning_strategy=hpo_job.HyperparameterTuningStrategy.BAYESIAN,
        tuning_objective=objective,
        training_job_early_stopping_type=hpo_job.TrainingJobEarlyStoppingType.AUTO,
    )

    restored = hpo_job.HyperparameterTuningJobConfig.from_flyte_idl(config.to_flyte_idl())

    for field_name in (
        "hyperparameter_ranges",
        "tuning_strategy",
        "tuning_objective",
        "training_job_early_stopping_type",
    ):
        assert getattr(restored, field_name) == getattr(config, field_name), field_name
Exemplo n.º 3
0
def test_hyperparameter_job_config():
    """A tuning job config without parameter ranges must round-trip through IDL.

    Only strategy, objective and early-stopping type are set and compared here.
    """
    # NOTE(review): this name duplicates an earlier test_hyperparameter_job_config in this
    # view; if both are in one module, this definition shadows the earlier one and pytest
    # will only collect this one — confirm whether they actually share a file.
    objective = hpo_job.HyperparameterTuningObjective(
        objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE,
        metric_name="test_metric",
    )
    config = hpo_job.HyperparameterTuningJobConfig(
        tuning_strategy=hpo_job.HyperparameterTuningStrategy.BAYESIAN,
        tuning_objective=objective,
        training_job_early_stopping_type=hpo_job.TrainingJobEarlyStoppingType.AUTO,
    )

    idl_message = config.to_flyte_idl()
    restored = hpo_job.HyperparameterTuningJobConfig.from_flyte_idl(idl_message)

    for field_name in ("tuning_strategy", "tuning_objective", "training_job_early_stopping_type"):
        assert getattr(restored, field_name) == getattr(config, field_name), field_name