예제 #1
0
    def __init__(self, service_context, host_creds=None, **kwargs):
        """
        Construct an AzureMLflowModelRegistry object.

        The two base-class initializers are called explicitly, one after the
        other, so each receives its own arguments. The AzureML base is set up
        first. The MLflow ``RestStore`` base then gets this object's
        ``get_host_creds`` bound method as its credentials provider.

        :param service_context: Service context for the AzureML workspace
        :type service_context: azureml._restclient.service_context.ServiceContext
        :param host_creds: Optional credentials, forwarded unchanged to
            ``AzureMLAbstractRestStore.__init__``.
        :param kwargs: Extra keyword arguments, forwarded unchanged to
            ``RestStore.__init__``.
        """
        logger.debug("Initializing the AzureMLflowModelRegistry")
        AzureMLAbstractRestStore.__init__(self, service_context, host_creds)
        creds_provider = self.get_host_creds
        RestStore.__init__(self, creds_provider, **kwargs)
예제 #2
0
def _get_rest_store(store_uri, **_):
    """Build a ``RestStore`` for ``store_uri`` with environment-based credentials.

    The returned store is given a provider function rather than fixed
    credentials. Each time the provider is called, it re-reads the tracking
    environment variables, so later changes to them are reflected.

    :param store_uri: Tracking server URI, used as the credentials' host.
    :return: A ``RestStore`` wired to the lazy credentials provider.
    """

    def get_default_host_creds():
        env = os.environ
        # Only the exact lowercase string "true" turns TLS verification off.
        skip_tls_check = env.get(_TRACKING_INSECURE_TLS_ENV_VAR) == "true"
        return rest_utils.MlflowHostCreds(
            host=store_uri,
            username=env.get(_TRACKING_USERNAME_ENV_VAR),
            password=env.get(_TRACKING_PASSWORD_ENV_VAR),
            token=env.get(_TRACKING_TOKEN_ENV_VAR),
            ignore_tls_verification=skip_tls_check,
        )

    return RestStore(get_default_host_creds)
예제 #3
0
 def setUp(self):
     """Create dummy credentials and a RestStore whose provider returns them."""
     self.creds = MlflowHostCreds("https://hello")

     def _creds_provider():
         # Looked up at call time, so reassigning self.creds later is seen.
         return self.creds

     self.store = RestStore(_creds_provider)
예제 #4
0
class TestRestStore(unittest.TestCase):
    """Check the HTTP requests issued by the model-registry ``RestStore``.

    Each test calls one store method. It then asserts that a request with the
    expected endpoint, HTTP method and serialized proto was made. The check
    runs against the mock object passed in by the ``mock_http_request`` /
    ``mock_multiple_http_requests`` decorators, which are defined elsewhere
    and presumably patch the HTTP layer.
    """

    def setUp(self):
        self.creds = MlflowHostCreds("https://hello")
        # The store receives a provider. Returning the shared object lets
        # ``_verify_requests`` compare the captured creds against self.creds.
        self.store = RestStore(lambda: self.creds)

    def tearDown(self):
        pass

    def _args(self, host_creds, endpoint, method, json_body, preview=True):
        """Return the keyword arguments one expected HTTP call should carry.

        :param host_creds: Credentials the request is expected to use.
        :param endpoint: Endpoint path relative to the MLflow API prefix.
        :param method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
        :param json_body: JSON text of the expected request message.
        :param preview: If True, use the ``/api/2.0/preview/mlflow`` prefix;
            otherwise use ``/api/2.0/mlflow``.
        :return: Dict with ``host_creds``, ``endpoint`` and ``method``. The
            decoded body is placed under ``params`` for GET and under
            ``json`` for every other method.
        """
        template = ("/api/2.0/preview/mlflow/%s"
                    if preview else "/api/2.0/mlflow/%s")
        res = {
            "host_creds": host_creds,
            "endpoint": template % endpoint,
            "method": method,
        }
        # GET bodies are expected as query ``params``; other methods are
        # expected to send the message as a ``json`` body.
        if method == "GET":
            res["params"] = json.loads(json_body)
        else:
            res["json"] = json.loads(json_body)
        return res

    def _verify_requests(self, http_request, endpoint, method, proto_message):
        """Assert that at least one call to ``http_request`` matches.

        The expected call always uses the preview API prefix.
        """
        json_body = message_to_json(proto_message)
        http_request.assert_any_call(
            **(self._args(self.creds, endpoint, method, json_body)))

    def _verify_all_requests(self, http_request, endpoints, proto_message):
        """Assert that ``http_request`` saw one call per entry of ``endpoints``.

        :param endpoints: Iterable of ``(endpoint, method, preview)`` tuples,
            all sharing the same ``proto_message`` body.
            ``assert_has_calls`` requires these calls to appear consecutively
            and in the given order. Extra calls before or after are allowed.
        """
        json_body = message_to_json(proto_message)
        http_request.assert_has_calls([
            mock.call(**(self._args(
                self.creds, endpoint, method, json_body, preview=preview)))
            for endpoint, method, preview in endpoints
        ])

    @mock_http_request
    def test_create_registered_model(self, mock_http):
        tags = [
            RegisteredModelTag(key="key", value="value"),
            RegisteredModelTag(key="anotherKey", value="some other value"),
        ]
        description = "best model ever"
        self.store.create_registered_model("model_1", tags, description)
        self._verify_requests(
            mock_http,
            "registered-models/create",
            "POST",
            CreateRegisteredModel(name="model_1",
                                  tags=[tag.to_proto() for tag in tags],
                                  description=description),
        )

    @mock_http_request
    def test_update_registered_model_name(self, mock_http):
        name = "model_1"
        new_name = "model_2"
        self.store.rename_registered_model(name=name, new_name=new_name)
        self._verify_requests(
            mock_http,
            "registered-models/rename",
            "POST",
            RenameRegisteredModel(name=name, new_name=new_name),
        )

    @mock_http_request
    def test_update_registered_model_description(self, mock_http):
        name = "model_1"
        description = "test model"
        self.store.update_registered_model(name=name, description=description)
        self._verify_requests(
            mock_http,
            "registered-models/update",
            "PATCH",
            UpdateRegisteredModel(name=name, description=description),
        )

    @mock_http_request
    def test_delete_registered_model(self, mock_http):
        name = "model_1"
        self.store.delete_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/delete", "DELETE",
                              DeleteRegisteredModel(name=name))

    @mock_http_request
    def test_list_registered_model(self, mock_http):
        self.store.list_registered_models(max_results=50, page_token=None)
        self._verify_requests(
            mock_http,
            "registered-models/list",
            "GET",
            ListRegisteredModels(page_token=None, max_results=50),
        )

    @mock_http_request
    def test_search_registered_model(self, mock_http):
        self.store.search_registered_models()
        self._verify_requests(mock_http, "registered-models/search", "GET",
                              SearchRegisteredModels())
        params_list = [
            {
                "filter_string": "model = 'yo'"
            },
            {
                "max_results": 400
            },
            {
                "page_token": "blah"
            },
            {
                "order_by": ["x", "Y"]
            },
        ]
        # Exercise every subset of the optional parameters, from the empty
        # subset up to all of them. Derive the sizes from the list so that a
        # newly added parameter is still covered.
        for size in range(len(params_list) + 1):
            for combination in combinations(params_list, size):
                params = {k: v for d in combination for k, v in d.items()}
                self.store.search_registered_models(**params)
                # The Python API calls it ``filter_string``; the proto field
                # is named ``filter``.
                if "filter_string" in params:
                    params["filter"] = params.pop("filter_string")
                self._verify_requests(mock_http, "registered-models/search",
                                      "GET", SearchRegisteredModels(**params))

    @mock_http_request
    def test_get_registered_model(self, mock_http):
        name = "model_1"
        self.store.get_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/get", "GET",
                              GetRegisteredModel(name=name))

    @mock_multiple_http_requests
    def test_get_latest_versions(self, mock_multiple_http_requests):
        name = "model_1"
        self.store.get_latest_versions(name=name)
        endpoint = "registered-models/get-latest-versions"
        # Expect a non-preview POST followed by a preview GET, in that order.
        endpoints = [(endpoint, "POST", False), (endpoint, "GET", True)]
        self._verify_all_requests(mock_multiple_http_requests, endpoints,
                                  GetLatestVersions(name=name))

    @mock_multiple_http_requests
    def test_get_latest_versions_with_stages(self,
                                             mock_multiple_http_requests):
        name = "model_1"
        self.store.get_latest_versions(name=name, stages=["blaah"])
        endpoint = "registered-models/get-latest-versions"
        endpoints = [(endpoint, "POST", False), (endpoint, "GET", True)]
        self._verify_all_requests(
            mock_multiple_http_requests, endpoints,
            GetLatestVersions(name=name, stages=["blaah"]))

    @mock_http_request
    def test_set_registered_model_tag(self, mock_http):
        name = "model_1"
        tag = RegisteredModelTag(key="key", value="value")
        self.store.set_registered_model_tag(name=name, tag=tag)
        self._verify_requests(
            mock_http,
            "registered-models/set-tag",
            "POST",
            SetRegisteredModelTag(name=name, key=tag.key, value=tag.value),
        )

    @mock_http_request
    def test_delete_registered_model_tag(self, mock_http):
        name = "model_1"
        self.store.delete_registered_model_tag(name=name, key="key")
        self._verify_requests(
            mock_http,
            "registered-models/delete-tag",
            "DELETE",
            DeleteRegisteredModelTag(name=name, key="key"),
        )

    @mock_http_request
    def test_create_model_version(self, mock_http):
        self.store.create_model_version("model_1", "path/to/source")
        self._verify_requests(
            mock_http,
            "model-versions/create",
            "POST",
            CreateModelVersion(name="model_1", source="path/to/source"),
        )
        # test optional fields
        run_id = uuid.uuid4().hex
        tags = [
            ModelVersionTag(key="key", value="value"),
            ModelVersionTag(key="anotherKey", value="some other value"),
        ]
        run_link = "localhost:5000/path/to/run"
        description = "version description"
        self.store.create_model_version(
            "model_1",
            "path/to/source",
            run_id,
            tags,
            run_link=run_link,
            description=description,
        )
        self._verify_requests(
            mock_http,
            "model-versions/create",
            "POST",
            CreateModelVersion(
                name="model_1",
                source="path/to/source",
                run_id=run_id,
                run_link=run_link,
                tags=[tag.to_proto() for tag in tags],
                description=description,
            ),
        )

    @mock_http_request
    def test_transition_model_version_stage(self, mock_http):
        name = "model_1"
        version = "5"
        self.store.transition_model_version_stage(
            name=name,
            version=version,
            stage="prod",
            archive_existing_versions=True)
        self._verify_requests(
            mock_http,
            "model-versions/transition-stage",
            "POST",
            TransitionModelVersionStage(name=name,
                                        version=version,
                                        stage="prod",
                                        archive_existing_versions=True),
        )

    @mock_http_request
    def test_update_model_version_decription(self, mock_http):
        name = "model_1"
        version = "5"
        description = "test model version"
        self.store.update_model_version(name=name,
                                        version=version,
                                        description=description)
        self._verify_requests(
            mock_http,
            "model-versions/update",
            "PATCH",
            UpdateModelVersion(name=name,
                               version=version,
                               description="test model version"),
        )

    @mock_http_request
    def test_delete_model_version(self, mock_http):
        name = "model_1"
        version = "12"
        self.store.delete_model_version(name=name, version=version)
        self._verify_requests(
            mock_http,
            "model-versions/delete",
            "DELETE",
            DeleteModelVersion(name=name, version=version),
        )

    @mock_http_request
    def test_get_model_version_details(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version(name=name, version=version)
        self._verify_requests(mock_http, "model-versions/get", "GET",
                              GetModelVersion(name=name, version=version))

    @mock_http_request
    def test_get_model_version_download_uri(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version_download_uri(name=name, version=version)
        self._verify_requests(
            mock_http,
            "model-versions/get-download-uri",
            "GET",
            GetModelVersionDownloadUri(name=name, version=version),
        )

    @mock_http_request
    def test_search_model_versions(self, mock_http):
        self.store.search_model_versions(filter_string="name='model_12'")
        self._verify_requests(mock_http, "model-versions/search", "GET",
                              SearchModelVersions(filter="name='model_12'"))

    @mock_http_request
    def test_set_model_version_tag(self, mock_http):
        name = "model_1"
        tag = ModelVersionTag(key="key", value="value")
        self.store.set_model_version_tag(name=name, version="1", tag=tag)
        self._verify_requests(
            mock_http,
            "model-versions/set-tag",
            "POST",
            SetModelVersionTag(name=name,
                               version="1",
                               key=tag.key,
                               value=tag.value),
        )

    @mock_http_request
    def test_delete_model_version_tag(self, mock_http):
        name = "model_1"
        self.store.delete_model_version_tag(name=name, version="1", key="key")
        self._verify_requests(
            mock_http,
            "model-versions/delete-tag",
            "DELETE",
            DeleteModelVersionTag(name=name, version="1", key="key"),
        )
예제 #5
0
class TestRestStore(unittest.TestCase):
    """Check the HTTP requests issued by the model-registry ``RestStore``.

    Each test patches ``mlflow.utils.rest_utils.http_request``, calls one
    store method, and asserts that a matching request was recorded.
    """

    def setUp(self):
        self.creds = MlflowHostCreds('https://hello')
        # The provider returns the shared creds object so the expected call
        # built in ``_verify_requests`` can reference ``self.creds``.
        self.store = RestStore(lambda: self.creds)

    def tearDown(self):
        pass

    def _args(self, host_creds, endpoint, method, json_body):
        """Return the keyword arguments one expected HTTP call should carry.

        :param host_creds: Credentials the request is expected to use.
        :param endpoint: Endpoint path under ``/api/2.0/preview/mlflow/``.
        :param method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
        :param json_body: JSON text of the expected request message.
        :return: Dict with ``host_creds``, ``endpoint`` and ``method``. The
            decoded body is placed under ``params`` for GET and under
            ``json`` for every other method.
        """
        res = {
            'host_creds': host_creds,
            'endpoint': "/api/2.0/preview/mlflow/%s" % endpoint,
            'method': method
        }
        if method == "GET":
            res["params"] = json.loads(json_body)
        else:
            res["json"] = json.loads(json_body)
        return res

    def _verify_requests(self, http_request, endpoint, method, proto_message):
        """Assert that at least one call to ``http_request`` matches.

        If no recorded call matches, the raised ``AssertionError`` lists every
        recorded call. The diagnostics therefore appear only on failure and
        are not printed on every verification.
        """
        json_body = message_to_json(proto_message)
        expected = self._args(self.creds, endpoint, method, json_body)
        try:
            http_request.assert_any_call(**expected)
        except AssertionError as err:
            raise AssertionError("%s\nRecorded calls: %s" %
                                 (err, http_request.call_args_list))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_registered_model(self, mock_http):
        self.store.create_registered_model("model_1")
        self._verify_requests(mock_http, "registered-models/create", "POST",
                              CreateRegisteredModel(name="model_1"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_name(self, mock_http):
        name = "model_1"
        new_name = "model_2"
        self.store.rename_registered_model(name=name, new_name=new_name)
        self._verify_requests(
            mock_http, "registered-models/rename", "POST",
            RenameRegisteredModel(name=name, new_name=new_name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_description(self, mock_http):
        name = "model_1"
        description = "test model"
        self.store.update_registered_model(name=name, description=description)
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(name=name, description=description))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_registered_model(self, mock_http):
        name = "model_1"
        self.store.delete_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/delete", "DELETE",
                              DeleteRegisteredModel(name=name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_list_registered_model(self, mock_http):
        self.store.list_registered_models()
        self._verify_requests(mock_http, "registered-models/list", "GET",
                              ListRegisteredModels())

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_registered_model(self, mock_http):
        name = "model_1"
        self.store.get_registered_model(name=name)
        self._verify_requests(mock_http, "registered-models/get", "GET",
                              GetRegisteredModel(name=name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions(self, mock_http):
        name = "model_1"
        self.store.get_latest_versions(name=name)
        self._verify_requests(mock_http,
                              "registered-models/get-latest-versions", "GET",
                              GetLatestVersions(name=name))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions_with_stages(self, mock_http):
        name = "model_1"
        self.store.get_latest_versions(name=name, stages=["blaah"])
        self._verify_requests(mock_http,
                              "registered-models/get-latest-versions", "GET",
                              GetLatestVersions(name=name, stages=["blaah"]))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_model_version(self, mock_http):
        run_id = uuid.uuid4().hex
        self.store.create_model_version("model_1", "path/to/source", run_id)
        self._verify_requests(
            mock_http, "model-versions/create", "POST",
            CreateModelVersion(name="model_1",
                               source="path/to/source",
                               run_id=run_id))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_transition_model_version_stage(self, mock_http):
        name = "model_1"
        version = "5"
        self.store.transition_model_version_stage(
            name=name,
            version=version,
            stage="prod",
            archive_existing_versions=True)
        self._verify_requests(
            mock_http, "model-versions/transition-stage", "POST",
            TransitionModelVersionStage(name=name,
                                        version=version,
                                        stage="prod",
                                        archive_existing_versions=True))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_decription(self, mock_http):
        name = "model_1"
        version = "5"
        description = "test model version"
        self.store.update_model_version(name=name,
                                        version=version,
                                        description=description)
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(name=name,
                               version=version,
                               description="test model version"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_model_version(self, mock_http):
        name = "model_1"
        version = "12"
        self.store.delete_model_version(name=name, version=version)
        self._verify_requests(mock_http, "model-versions/delete", "DELETE",
                              DeleteModelVersion(name=name, version=version))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_details(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version(name=name, version=version)
        self._verify_requests(mock_http, "model-versions/get", "GET",
                              GetModelVersion(name=name, version=version))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_download_uri(self, mock_http):
        name = "model_11"
        version = "8"
        self.store.get_model_version_download_uri(name=name, version=version)
        self._verify_requests(
            mock_http, "model-versions/get-download-uri", "GET",
            GetModelVersionDownloadUri(name=name, version=version))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_search_model_versions(self, mock_http):
        self.store.search_model_versions(filter_string="name='model_12'")
        self._verify_requests(mock_http, "model-versions/search", "GET",
                              SearchModelVersions(filter="name='model_12'"))
예제 #6
0
def _get_databricks_rest_store(store_uri, **_):
    """Return a ``RestStore`` whose credentials are derived from ``store_uri``.

    :param store_uri: URI passed to ``get_databricks_host_creds`` whenever the
        store requests credentials. Extra keyword arguments are ignored.
    :return: A ``RestStore`` wired to a lazy credentials provider.
    """

    def _creds_provider():
        # Evaluated on every call, not once at store construction.
        return get_databricks_host_creds(store_uri)

    return RestStore(_creds_provider)
예제 #7
0
파일: utils.py 프로젝트: tnixon/mlflow
def _get_databricks_rest_store(store_uri, **_):
    """Return a ``RestStore`` whose credentials are derived from ``store_uri``.

    :param store_uri: URI bound as the first positional argument of
        ``get_databricks_host_creds``. Extra keyword arguments are ignored.
    :return: A ``RestStore`` wired to the bound credentials provider.
    """
    creds_provider = partial(get_databricks_host_creds, store_uri)
    return RestStore(creds_provider)
예제 #8
0
파일: utils.py 프로젝트: tnixon/mlflow
def _get_rest_store(store_uri, **_):
    """Return a ``RestStore`` whose credentials are derived from ``store_uri``.

    :param store_uri: URI bound as the first positional argument of
        ``get_default_host_creds``. Extra keyword arguments are ignored.
    :return: A ``RestStore`` wired to the bound credentials provider.
    """
    creds_provider = partial(get_default_host_creds, store_uri)
    return RestStore(creds_provider)
class TestRestStore(unittest.TestCase):
    """Check the HTTP requests issued by the model-registry ``RestStore``.

    This variant targets the object-based registry API, where methods take
    ``RegisteredModel`` / ``ModelVersion`` entities. Each test patches
    ``mlflow.utils.rest_utils.http_request``, calls one store method, and
    asserts that a matching request was recorded.
    """

    def setUp(self):
        self.creds = MlflowHostCreds('https://hello')
        # The provider returns the shared creds object so the expected call
        # built in ``_verify_requests`` can reference ``self.creds``.
        self.store = RestStore(lambda: self.creds)

    def tearDown(self):
        pass

    def _args(self, host_creds, endpoint, method, json_body):
        """Return the keyword arguments one expected HTTP call should carry.

        :param host_creds: Credentials the request is expected to use.
        :param endpoint: Endpoint path under ``/api/2.0/preview/mlflow/``.
        :param method: HTTP method name, e.g. ``"GET"`` or ``"POST"``.
        :param json_body: JSON text of the expected request message.
        :return: Dict with ``host_creds``, ``endpoint`` and ``method``. The
            decoded body is placed under ``params`` for GET and under
            ``json`` for every other method.
        """
        res = {
            'host_creds': host_creds,
            'endpoint': "/api/2.0/preview/mlflow/%s" % endpoint,
            'method': method
        }
        if method == "GET":
            res["params"] = json.loads(json_body)
        else:
            res["json"] = json.loads(json_body)
        return res

    def _verify_requests(self, http_request, endpoint, method, proto_message):
        """Assert that at least one call to ``http_request`` matches.

        If no recorded call matches, the raised ``AssertionError`` lists every
        recorded call. The diagnostics therefore appear only on failure and
        are not printed on every verification.
        """
        json_body = message_to_json(proto_message)
        expected = self._args(self.creds, endpoint, method, json_body)
        try:
            http_request.assert_any_call(**expected)
        except AssertionError as err:
            raise AssertionError("%s\nRecorded calls: %s" %
                                 (err, http_request.call_args_list))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_registered_model(self, mock_http):
        self.store.create_registered_model("model_1")
        self._verify_requests(mock_http, "registered-models/create", "POST",
                              CreateRegisteredModel(name="model_1"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_name(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.update_registered_model(registered_model=rm,
                                           new_name="model_2")
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(registered_model=rm.to_proto(),
                                  name="model_2"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_description(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.update_registered_model(registered_model=rm,
                                           description="test model")
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(registered_model=rm.to_proto(),
                                  description="test model"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_registered_model_all(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.update_registered_model(registered_model=rm,
                                           new_name="model_3",
                                           description="rename and describe")
        self._verify_requests(
            mock_http, "registered-models/update", "PATCH",
            UpdateRegisteredModel(registered_model=rm.to_proto(),
                                  name="model_3",
                                  description="rename and describe"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_registered_model(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.delete_registered_model(registered_model=rm)
        self._verify_requests(
            mock_http, "registered-models/delete", "DELETE",
            DeleteRegisteredModel(registered_model=rm.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_list_registered_model(self, mock_http):
        self.store.list_registered_models()
        self._verify_requests(mock_http, "registered-models/list", "GET",
                              ListRegisteredModels())

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_registered_model_detailed(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.get_registered_model_details(registered_model=rm)
        self._verify_requests(
            mock_http, "registered-models/get-details", "POST",
            GetRegisteredModelDetails(registered_model=rm.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.get_latest_versions(registered_model=rm)
        self._verify_requests(
            mock_http, "registered-models/get-latest-versions", "POST",
            GetLatestVersions(registered_model=rm.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_latest_versions_with_stages(self, mock_http):
        rm = RegisteredModel("model_1")
        self.store.get_latest_versions(registered_model=rm, stages=["blaah"])
        self._verify_requests(
            mock_http, "registered-models/get-latest-versions", "POST",
            GetLatestVersions(registered_model=rm.to_proto(),
                              stages=["blaah"]))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_create_model_version(self, mock_http):
        run_id = uuid.uuid4().hex
        self.store.create_model_version("model_1", "path/to/source", run_id)
        self._verify_requests(
            mock_http, "model-versions/create", "POST",
            CreateModelVersion(name="model_1",
                               source="path/to/source",
                               run_id=run_id))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_stage(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 5)
        self.store.update_model_version(model_version=mv, stage="prod")
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(model_version=mv.to_proto(), stage="prod"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_decription(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 5)
        self.store.update_model_version(model_version=mv,
                                        description="test model version")
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(model_version=mv.to_proto(),
                               description="test model version"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_update_model_version_all(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 5)
        self.store.update_model_version(model_version=mv,
                                        stage="5%",
                                        description="A|B test")
        self._verify_requests(
            mock_http, "model-versions/update", "PATCH",
            UpdateModelVersion(model_version=mv.to_proto(),
                               stage="5%",
                               description="A|B test"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_delete_model_version(self, mock_http):
        rm = RegisteredModel("model_1")
        mv = ModelVersion(rm, 12)
        self.store.delete_model_version(model_version=mv)
        self._verify_requests(mock_http, "model-versions/delete", "DELETE",
                              DeleteModelVersion(model_version=mv.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_details(self, mock_http):
        rm = RegisteredModel("model_11")
        mv = ModelVersion(rm, 8)
        self.store.get_model_version_details(model_version=mv)
        self._verify_requests(
            mock_http, "model-versions/get-details", "POST",
            GetModelVersionDetails(model_version=mv.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_download_uri(self, mock_http):
        rm = RegisteredModel("model_11")
        mv = ModelVersion(rm, 8)
        self.store.get_model_version_download_uri(model_version=mv)
        self._verify_requests(
            mock_http, "model-versions/get-download-uri", "POST",
            GetModelVersionDownloadUri(model_version=mv.to_proto()))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_search_model_versions(self, mock_http):
        self.store.search_model_versions(filter_string="name='model_12'")
        self._verify_requests(mock_http, "model-versions/search", "GET",
                              SearchModelVersions(filter="name='model_12'"))

    @mock.patch('mlflow.utils.rest_utils.http_request')
    def test_get_model_version_stages(self, mock_http):
        rm = RegisteredModel("model_11")
        mv = ModelVersion(rm, 8)
        self.store.get_model_version_stages(model_version=mv)
        self._verify_requests(
            mock_http, "model-versions/get-stages", "POST",
            GetModelVersionStages(model_version=mv.to_proto()))