示例#1
0
 def test_log_batch_nonexistent_run(self):
     # Logging a batch against a run id that was never created must raise an
     # MlflowException carrying RESOURCE_DOES_NOT_EXIST and naming the missing run.
     missing_run_id = "bad-run-uuid"
     with self.assertRaises(MlflowException) as ctx:
         self.store.log_batch(missing_run_id, [], [], [])
     raised = ctx.exception
     assert raised.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
     assert "Run with id=bad-run-uuid not found" in raised.message
    def test_search_registered_models(self):
        # Exercise name-based filtering of registered-model search: equality, case-sensitive
        # LIKE, case-insensitive ILIKE, rejection of unsupported filters, and results after
        # a registered model has been deleted.

        # Create registered models sharing a common prefix. names[4] ("RM4A") and
        # names[5] ("RM4a") differ only in case, which the LIKE / ILIKE checks rely on.
        prefix = "test_for_search_"
        names = [
            prefix + name
            for name in ["RM1", "RM2", "RM3", "RM4", "RM4A", "RM4a"]
        ]
        for rm_name in names:
            self._rm_maker(rm_name)

        # search with no filter should return all registered models
        rms, _ = self._search_registered_models(None)
        self.assertEqual(rms, names)

        # equality search using name should return exactly the 1 name
        rms, _ = self._search_registered_models(f"name='{names[0]}'")
        self.assertEqual(rms, [names[0]])

        # equality search using a name that does not exist should return nothing
        rms, _ = self._search_registered_models(f"name='{names[0] + 'cats'}'")
        self.assertEqual(rms, [])

        # case-sensitive prefix search using LIKE should return all the RMs
        rms, _ = self._search_registered_models(f"name LIKE '{prefix}%'")
        self.assertEqual(rms, names)

        # case-sensitive search using LIKE with surrounding % should return all the RMs
        rms, _ = self._search_registered_models("name LIKE '%RM%'")
        self.assertEqual(rms, names)

        # case-sensitive search using LIKE with the single-character wildcard "_":
        # '_e%' matches "test_for_search_" (any character, then "e"), so all RMs match
        rms, _ = self._search_registered_models("name LIKE '_e%'")
        self.assertEqual(rms, names)

        # case-sensitive prefix search using LIKE should return just RM4A (not RM4a)
        rms, _ = self._search_registered_models(
            f"name LIKE '{prefix + 'RM4A'}%'")
        self.assertEqual(rms, [names[4]])

        # case-sensitive prefix search using LIKE should return no models if no match
        rms, _ = self._search_registered_models(
            f"name LIKE '{prefix + 'cats'}%'")
        self.assertEqual(rms, [])

        # confirm that the LIKE keyword itself is accepted in any letter case
        rms, _ = self._search_registered_models("name lIkE '%blah%'")
        self.assertEqual(rms, [])

        rms, _ = self._search_registered_models(
            f"name like '{prefix + 'RM4A'}%'")
        self.assertEqual(rms, [names[4]])

        # case-insensitive prefix search using ILIKE should return both RM4A and RM4a
        rms, _ = self._search_registered_models(
            f"name ILIKE '{prefix + 'RM4A'}%'")
        self.assertEqual(rms, names[4:])

        # case-insensitive suffix search with ILIKE should return both RM4A and RM4a
        rms, _ = self._search_registered_models("name ILIKE '%RM4a'")
        self.assertEqual(rms, names[4:])

        # case-insensitive prefix search using ILIKE should return no models if no match
        rms, _ = self._search_registered_models(
            f"name ILIKE '{prefix + 'cats'}%'")
        self.assertEqual(rms, [])

        # confirm that the ILIKE keyword itself is accepted in any letter case
        rms, _ = self._search_registered_models("name iLike '%blah%'")
        self.assertEqual(rms, [])

        rms, _ = self._search_registered_models("name ilike '%RM4a'")
        self.assertEqual(rms, names[4:])

        # cannot search by invalid comparator types
        with self.assertRaises(MlflowException) as exception_context:
            self._search_registered_models("name!=something")
        assert exception_context.exception.error_code == ErrorCode.Name(
            INVALID_PARAMETER_VALUE)

        # cannot search by run_id
        with self.assertRaises(MlflowException) as exception_context:
            self._search_registered_models("run_id='%s'" % "somerunID")
        assert exception_context.exception.error_code == ErrorCode.Name(
            INVALID_PARAMETER_VALUE)

        # cannot search by source_path
        with self.assertRaises(MlflowException) as exception_context:
            self._search_registered_models("source_path = 'A/D'")
        assert exception_context.exception.error_code == ErrorCode.Name(
            INVALID_PARAMETER_VALUE)

        # cannot search by other params
        with self.assertRaises(MlflowException) as exception_context:
            self._search_registered_models("evilhax = true")
        assert exception_context.exception.error_code == ErrorCode.Name(
            INVALID_PARAMETER_VALUE)

        # delete the last registered model (RM4a); search should now return only the first 5
        self.store.delete_registered_model(name=names[-1])
        self.assertEqual(
            self._search_registered_models(None, max_results=1000),
            (names[:-1], None))

        # equality search using the deleted name should return no names
        self.assertEqual(self._search_registered_models(f"name='{names[-1]}'"),
                         ([], None))

        # case-sensitive prefix search using LIKE should return all the remaining RMs
        self.assertEqual(
            self._search_registered_models(f"name LIKE '{prefix}%'"),
            (names[0:5], None))

        # case-insensitive prefix search using ILIKE now returns only RM4A, because
        # RM4a has been deleted
        self.assertEqual(
            self._search_registered_models(f"name ILIKE '{prefix + 'RM4A'}%'"),
            ([names[4]], None))
示例#3
0
import json

from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, TEMPORARILY_UNAVAILABLE, \
    ENDPOINT_NOT_FOUND, PERMISSION_DENIED, REQUEST_LIMIT_EXCEEDED, BAD_REQUEST, \
    INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST, INVALID_STATE, RESOURCE_ALREADY_EXISTS, \
    ErrorCode

# Lookup from ``ErrorCode`` name strings (e.g. "INTERNAL_ERROR") to HTTP status codes.
# Built from an ordered (code, status) table so insertion order matches the listing.
ERROR_CODE_TO_HTTP_STATUS = {
    ErrorCode.Name(code): http_status
    for code, http_status in (
        (INTERNAL_ERROR, 500),
        (INVALID_STATE, 500),
        (TEMPORARILY_UNAVAILABLE, 503),
        (REQUEST_LIMIT_EXCEEDED, 429),
        (ENDPOINT_NOT_FOUND, 404),
        (RESOURCE_DOES_NOT_EXIST, 404),
        (PERMISSION_DENIED, 403),
        (BAD_REQUEST, 400),
        (RESOURCE_ALREADY_EXISTS, 400),
        (INVALID_PARAMETER_VALUE, 400),
    )
}


class MlflowException(Exception):
    """
    Generic exception thrown to surface failure information about external-facing operations.
    The error message associated with this exception may be exposed to clients in HTTP responses
    for debugging purposes. If the error text is sensitive, raise a generic `Exception` object
    instead.
    """
    def __init__(self, message, error_code=INTERNAL_ERROR, **kwargs):
        """
        :param message: The message describing the error that occured. This will be included in the
示例#4
0
    def test_get_experiment_by_name(self, store_class):
        # Exercise get_experiment_by_name against a mocked http_request:
        #   1. a successful GetExperimentByName lookup,
        #   2. a 404 for a nonexistent experiment, surfaced as None with a single request,
        #   3. fallback to ListExperiments when GetExperimentByName is not implemented,
        #   4. no fallback when the server responds with REQUEST_LIMIT_EXCEEDED.
        creds = MlflowHostCreds("https://hello")
        store = store_class(lambda: creds)
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            # Use MagicMock *instances*. Assigning attributes on the bare ``mock.MagicMock``
            # class would set class attributes visible to every MagicMock (leaking into
            # other tests) and would make every "response" below the same object.
            response = mock.MagicMock()
            response.status_code = 200
            experiment = Experiment(
                experiment_id="123",
                name="abc",
                artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )
            response.text = json.dumps(
                {"experiment": json.loads(message_to_json(experiment.to_proto()))}
            )
            mock_http.return_value = response
            result = store.get_experiment_by_name("abc")
            expected_message0 = GetExperimentByName(experiment_name="abc")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_message0),
            )
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage

            # Test GetExperimentByName against nonexistent experiment: the 404 should
            # yield None and must not trigger any additional (fallback) request.
            mock_http.reset_mock()
            nonexistent_exp_response = mock.MagicMock()
            nonexistent_exp_response.status_code = 404
            nonexistent_exp_response.text = MlflowException(
                "Exp doesn't exist!", RESOURCE_DOES_NOT_EXIST
            ).serialize_as_json()
            mock_http.return_value = nonexistent_exp_response
            assert store.get_experiment_by_name("nonexistent-experiment") is None
            expected_message1 = GetExperimentByName(experiment_name="nonexistent-experiment")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_message1),
            )
            assert mock_http.call_count == 1

            # Test REST client behavior against a mocked old server, which has handler for
            # ListExperiments but not GetExperimentByName
            mock_http.reset_mock()
            list_exp_response = mock.MagicMock()
            list_exp_response.text = json.dumps(
                {"experiments": [json.loads(message_to_json(experiment.to_proto()))]}
            )
            list_exp_response.status_code = 200

            def response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                # Simulate an old server: only the get-by-name endpoint is missing.
                if kwargs.get("endpoint") == "/api/2.0/mlflow/experiments/get-by-name":
                    raise MlflowException(
                        "GetExperimentByName is not implemented", ENDPOINT_NOT_FOUND
                    )
                else:
                    return list_exp_response

            mock_http.side_effect = response_fn
            result = store.get_experiment_by_name("abc")
            expected_message2 = ListExperiments(view_type=ViewType.ALL)
            self._verify_requests(
                mock_http,
                creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_message0),
            )
            self._verify_requests(
                mock_http, creds, "experiments/list", "GET", message_to_json(expected_message2)
            )
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage

            # Verify that REST client won't fall back to ListExperiments for 429 errors (hitting
            # rate limits)
            mock_http.reset_mock()

            def rate_limit_response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                raise MlflowException(
                    "Hit rate limit on GetExperimentByName", REQUEST_LIMIT_EXCEEDED
                )

            mock_http.side_effect = rate_limit_response_fn
            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("imspamming")
            assert exc_info.value.error_code == ErrorCode.Name(REQUEST_LIMIT_EXCEEDED)
            assert mock_http.call_count == 1
示例#5
0
def register_model(model_uri, name, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS):
    """
    Create a new model version in model registry for the model files specified by ``model_uri``.
    Note that this method assumes the model registry backend URI is the same as that of the
    tracking backend.

    :param model_uri: URI referring to the MLmodel directory. Use a ``runs:/`` URI if you want to
                      record the run ID with the model in model registry. ``models:/`` URIs are
                      currently not supported.
    :param name: Name of the registered model under which to create a new model version. If a
                 registered model with the given name does not exist, it will be created
                 automatically.
    :param await_registration_for: Number of seconds to wait for the model version to finish
                            being created and is in ``READY`` status. By default, the function
                            waits for five minutes. Specify 0 or None to skip waiting. Applies
                            to ``runs:/`` URIs as well as to any other model URI.
    :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
             backend.
    :raises MlflowException: Re-raised from registered-model creation for any error other than
                             ``RESOURCE_ALREADY_EXISTS``.

    .. code-block:: python
        :caption: Example

        import mlflow.sklearn
        from sklearn.ensemble import RandomForestRegressor

        mlflow.set_tracking_uri("sqlite:////tmp/mlruns.db")
        params = {"n_estimators": 3, "random_state": 42}

        # Log MLflow entities
        with mlflow.start_run() as run:
            rfr = RandomForestRegressor(**params)
            mlflow.log_params(params)
            mlflow.sklearn.log_model(rfr, artifact_path="sklearn-model")

        model_uri = "runs:/{}/sklearn-model".format(run.info.run_id)
        mv = mlflow.register_model(model_uri, "RandomForestRegressionModel")
        print("Name: {}".format(mv.name))
        print("Version: {}".format(mv.version))

    .. code-block:: text
        :caption: Output

        Name: RandomForestRegressionModel
        Version: 1
    """
    client = MlflowClient()
    try:
        create_model_response = client.create_registered_model(name)
        eprint("Successfully registered model '%s'." % create_model_response.name)
    except MlflowException as e:
        # An existing registered model is fine: we just add a new version to it.
        if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
            eprint(
                "Registered model '%s' already exists. Creating a new version of this model..."
                % name
            )
        else:
            raise

    if RunsArtifactRepository.is_runs_uri(model_uri):
        # Resolve the runs:/ URI to its underlying artifact location and record the run ID.
        source = RunsArtifactRepository.get_underlying_uri(model_uri)
        (run_id, _) = RunsArtifactRepository.parse_runs_uri(model_uri)
    else:
        source = model_uri
        run_id = None

    # Honor await_registration_for regardless of the URI scheme.
    create_version_response = client.create_model_version(
        name, source=source, run_id=run_id, await_creation_for=await_registration_for
    )
    eprint(
        "Created version '{version}' of model '{model_name}'.".format(
            version=create_version_response.version, model_name=create_version_response.name
        )
    )
    return create_version_response