コード例 #1
0
def test_params():
    """Exercise parameter validation and merging on an ``EntryPoint``.

    ``alpha`` is declared without a default, so inputs that omit it are
    expected to fail ``_validate_parameters`` with ``ExecutionException``.
    For valid inputs, ``compute_parameters`` is expected to return a pair
    ``(final, extra)``: ``final`` holds every declared parameter (the user's
    value or the declared default), ``extra`` holds the undeclared keys, and
    all values come back as strings.
    """
    param_spec = {
        "alpha": "float",
        "l1_ratio": {"type": "float", "default": 0.1},
        "l2_ratio": {"type": "float", "default": 0.0003},
        "random_str": {"type": "string", "default": "hello"},
    }
    ep = EntryPoint("entry_point_name", param_spec, "command_name script.py")

    # Neither input supplies "alpha"; an undeclared key ("beta") does not
    # stand in for it.
    for lacking_alpha in ({}, {"beta": 0.004}):
        with pytest.raises(ExecutionException):
            ep._validate_parameters(lacking_alpha)

    # Declared parameters the user did not override fall back to these.
    fallback = {"l1_ratio": '0.1', "l2_ratio": '0.0003', "random_str": "hello"}

    # Each case: (user input, expected final params, expected extra params).
    cases = (
        # An undeclared key is routed to the extras.
        ({"alpha": 0.004, "gamma": 0.89},
         dict(fallback, alpha='0.004'),
         {"gamma": "0.89"}),
        # A user value replaces the declared default.
        ({"alpha": 0.004, "l1_ratio": 0.0008, "random_str_2": "hello"},
         dict(fallback, alpha='0.004', l1_ratio='0.0008'),
         {"random_str_2": "hello"}),
        # Negative floats and string overrides; nothing left over.
        ({"alpha": -0.99, "random_str": "hi"},
         dict(fallback, alpha='-0.99', random_str="hi"),
         {}),
        # Names are matched case-sensitively: "ALPHA" is not "alpha".
        ({"alpha": 0.77, "ALPHA": 0.89},
         dict(fallback, alpha='0.77'),
         {"ALPHA": "0.89"}),
    )
    for user_params, want_final, want_extra in cases:
        got_final, got_extra = ep.compute_parameters(user_params, None)
        assert want_extra == got_extra
        assert want_final == got_final
コード例 #2
0
def test_path_params():
    """Exercise resolution of ``uri`` and ``path`` parameters.

    ``uri`` values are expected to come back unchanged. A ``path`` value is
    expected to be left untouched when no storage directory is passed. When
    a storage directory is passed, the test expects exactly one call to
    ``mlflow.data.download_uri`` (mocked here) and the value to be rewritten
    to ``<storage dir>/<file name>``.
    """
    data_file = "s3://path.test/resources/data_file.csv"
    param_spec = {
        "constants": {"type": "uri", "default": "s3://path.test/b1"},
        "data": {"type": "path", "default": data_file},
    }
    with TempDir() as tmp:
        storage_dir = tmp.path()
        ep = EntryPoint("entry_point_name", param_spec,
                        "command_name script.py")

        # Each case: (user input, storage dir, expected final params,
        # expected extra params, expected number of download_uri calls).
        cases = (
            # Defaults only, no storage dir: nothing is downloaded.
            ({}, None,
             {"constants": "s3://path.test/b1", "data": data_file},
             {}, 0),
            # Overriding the uri parameter still triggers no download.
            ({"alpha": 0.001, "constants": "s3://path.test/b_two"}, None,
             {"constants": "s3://path.test/b_two", "data": data_file},
             {"alpha": "0.001"}, 0),
            # With a storage dir, the default path is fetched into it.
            ({"alpha": 0.001}, storage_dir,
             {"constants": "s3://path.test/b1",
              "data": "%s/data_file.csv" % storage_dir},
             {"alpha": "0.001"}, 1),
            # A user-supplied path is fetched into the storage dir as well.
            ({"data": "s3://another.example.test/data_stash/images.tgz"},
             storage_dir,
             {"constants": "s3://path.test/b1",
              "data": "%s/images.tgz" % storage_dir},
             {}, 1),
        )
        for user_params, dest, want_final, want_extra, want_calls in cases:
            # A fresh mock per case so call counts never accumulate.
            with mock.patch("mlflow.data.download_uri") as fake_download:
                got_final, got_extra = ep.compute_parameters(user_params, dest)
                assert got_final == want_final
                assert got_extra == want_extra
                assert fake_download.call_count == want_calls