Example #1
0
def test_hyperparameter_job_config():
    """Round-trip a HyperparameterTuningJobConfig through its Flyte IDL form.

    Builds a config holding one categorical and one integer parameter range,
    converts it with ``to_flyte_idl``, rebuilds it with ``from_flyte_idl``,
    and asserts that every field of the rebuilt config equals the original's.
    """
    categorical_range = parameter_ranges.CategoricalParameterRange(values=["1", "2"])
    integer_range = parameter_ranges.IntegerParameterRange(
        min_value=0,
        max_value=10,
        scaling_type=parameter_ranges.HyperparameterScalingType.LINEAR,
    )
    search_space = parameter_ranges.ParameterRanges(
        parameter_range_map={"a": categorical_range, "b": integer_range},
    )
    objective = hpo_job.HyperparameterTuningObjective(
        objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE,
        metric_name="test_metric",
    )
    original = hpo_job.HyperparameterTuningJobConfig(
        hyperparameter_ranges=search_space,
        tuning_strategy=hpo_job.HyperparameterTuningStrategy.BAYESIAN,
        tuning_objective=objective,
        training_job_early_stopping_type=hpo_job.TrainingJobEarlyStoppingType.AUTO,
    )

    serialized = original.to_flyte_idl()
    restored = hpo_job.HyperparameterTuningJobConfig.from_flyte_idl(serialized)

    # Compare field by field so a failure points at the attribute that did
    # not survive the round trip.
    assert restored.hyperparameter_ranges == original.hyperparameter_ranges
    assert restored.tuning_strategy == original.tuning_strategy
    assert restored.tuning_objective == original.tuning_objective
    assert restored.training_job_early_stopping_type == original.training_job_early_stopping_type
def test_parameter_ranges():
    """Round-trip a ParameterRanges with three range kinds through Flyte IDL.

    The map holds a categorical, an integer (linear scaling) and a continuous
    (logarithmic scaling) range; the object rebuilt by ``from_flyte_idl`` must
    compare equal to the one that was converted with ``to_flyte_idl``.
    """
    scaling = parameter_ranges.HyperparameterScalingType
    range_map = {
        "a": parameter_ranges.CategoricalParameterRange(values=["a-1", "a-2"]),
        "b": parameter_ranges.IntegerParameterRange(
            min_value=1,
            max_value=5,
            scaling_type=scaling.LINEAR,
        ),
        "c": parameter_ranges.ContinuousParameterRange(
            min_value=0.1,
            max_value=1.0,
            scaling_type=scaling.LOGARITHMIC,
        ),
    }
    original = parameter_ranges.ParameterRanges(range_map)

    serialized = original.to_flyte_idl()
    restored = parameter_ranges.ParameterRanges.from_flyte_idl(serialized)

    assert original == restored
Example #3
0
        min_value=0,
        scaling_type=parameter_ranges.HyperparameterScalingType.LOGARITHMIC)

    pr2 = parameter_ranges.IntegerParameterRange.from_flyte_idl(
        pr.to_flyte_idl())
    assert pr == pr2
    assert type(pr2.max_value) == int
    assert type(pr2.min_value) == int
    assert pr2.max_value == 1
    assert pr2.min_value == 0
    assert pr2.scaling_type == parameter_ranges.HyperparameterScalingType.LOGARITHMIC


def test_categorical_parameter_range():
    """Round-trip a CategoricalParameterRange through its Flyte IDL form.

    Asserts that the rebuilt range equals the original, that its ``values``
    attribute is a ``list``, and that it holds the same elements as the
    original's ``values``.
    """
    # A bare TestCase instance is used only to borrow assertCountEqual.
    checker = unittest.TestCase()
    original = parameter_ranges.CategoricalParameterRange(values=["abc", "cat"])

    serialized = original.to_flyte_idl()
    restored = parameter_ranges.CategoricalParameterRange.from_flyte_idl(serialized)

    assert original == restored
    assert isinstance(restored.values, list)
    # assertCountEqual compares element counts and ignores ordering.
    checker.assertCountEqual(restored.values, original.values)


def test_parameter_ranges():
    pr = parameter_ranges.ParameterRanges(
        {
            "a":
            parameter_ranges.CategoricalParameterRange(values=["a-1", "a-2"]),
            "b":
            parameter_ranges.IntegerParameterRange(