コード例 #1
0
    def test_failed_http_request(self, request):
        """A non-200 response is raised as an MlflowException with its error code and message."""
        # Use a MagicMock *instance*. Assigning attributes on the ``mock.MagicMock``
        # class itself would mutate every MagicMock created afterwards and leak
        # state (e.g. status_code=404) into unrelated tests.
        response = mock.MagicMock()
        response.status_code = 404
        response.text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "No experiment"}'
        request.return_value = response

        store = RestStore(lambda: MlflowHostCreds("https://hello"))
        with pytest.raises(MlflowException) as cm:
            store.list_experiments()
        assert "RESOURCE_DOES_NOT_EXIST: No experiment" in str(cm.value)
コード例 #2
0
    def __init__(self, service_context, host_creds=None, **kwargs):
        """
        Construct an AzureMLRestStore object.

        :param service_context: Service context for the AzureML workspace
        :type service_context: azureml._restclient.service_context.ServiceContext
        :param host_creds: Optional host credentials, forwarded to
            ``AzureMLAbstractRestStore.__init__``.
        :param kwargs: Extra keyword arguments, forwarded to ``RestStore.__init__``.
        """
        logger.debug("Initializing the AzureMLRestStore")
        # Initialize the AzureML base first; RestStore is then given this
        # object's bound ``get_host_creds`` as its credential provider.
        AzureMLAbstractRestStore.__init__(self, service_context, host_creds)
        creds_provider = self.get_host_creds
        RestStore.__init__(self, creds_provider, **kwargs)
コード例 #3
0
def test_get_host_creds_from_default_store_rest_store():
    """The default-store creds helper yields MlflowHostCreds when the store is a RestStore."""
    get_store_target = "mlflow.tracking._tracking_service.utils._get_store"
    with mock.patch(get_store_target) as patched_get_store:
        patched_get_store.return_value = RestStore(lambda: MlflowHostCreds("http://host"))
        creds_provider = _get_host_creds_from_default_store()
        host_creds = creds_provider()
        assert isinstance(host_creds, MlflowHostCreds)
コード例 #4
0
    def test_response_with_unknown_fields(self, request):
        """Unknown fields in a server response are ignored rather than failing the parse."""
        experiment_json = {
            "experiment_id": "1",
            "name": "My experiment",
            "artifact_location": "foo",
            "lifecycle_stage": "deleted",
            "OMG_WHAT_IS_THIS_FIELD": "Hooly cow",
        }

        # Use a MagicMock *instance*; setting attributes on the ``mock.MagicMock``
        # class would leak into every MagicMock created afterwards.
        response = mock.MagicMock()
        response.status_code = 200
        payload = {"experiments": [experiment_json]}
        response.text = json.dumps(payload)
        request.return_value = response

        store = RestStore(lambda: MlflowHostCreds("https://hello"))
        experiments = store.list_experiments()
        assert len(experiments) == 1
        assert experiments[0].name == "My experiment"
コード例 #5
0
ファイル: utils.py プロジェクト: yusufazishty/mlflow
def _get_rest_store(store_uri, **_):
    """Return a RestStore for ``store_uri`` whose credentials come from the environment.

    The credential provider is evaluated lazily: each call re-reads the tracking
    username, password, token and insecure-TLS environment variables.
    Insecure TLS is enabled only when that variable is exactly ``'true'``.
    """
    def _resolve_host_creds():
        env = os.environ
        return rest_utils.MlflowHostCreds(
            host=store_uri,
            username=env.get(_TRACKING_USERNAME_ENV_VAR),
            password=env.get(_TRACKING_PASSWORD_ENV_VAR),
            token=env.get(_TRACKING_TOKEN_ENV_VAR),
            ignore_tls_verification=(env.get(_TRACKING_INSECURE_TLS_ENV_VAR) == 'true'),
        )

    return RestStore(_resolve_host_creds)
コード例 #6
0
    def test_successful_http_request(self, request):
        """A list-experiments call sends the expected GET request and parses the reply."""
        def mock_request(**kwargs):
            # Filter out None arguments
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            assert kwargs == {
                "method": "GET",
                "params": {"view_type": "ACTIVE_ONLY"},
                "url": "https://hello/api/2.0/mlflow/experiments/list",
                "headers": _DEFAULT_HEADERS,
                "verify": True,
            }
            # Use a MagicMock *instance*; setting attributes on the
            # ``mock.MagicMock`` class would leak into every later MagicMock.
            response = mock.MagicMock()
            response.status_code = 200
            response.text = '{"experiments": [{"name": "Exp!", "lifecycle_stage": "active"}]}'
            return response

        request.side_effect = mock_request

        store = RestStore(lambda: MlflowHostCreds("https://hello"))
        experiments = store.list_experiments()
        assert experiments[0].name == "Exp!"
コード例 #7
0
    def get_db_store(self):
        """Return a RestStore pointed at the Databricks tracking server.

        Several mlflow versions are supported by probing for APIs that moved
        between releases, logging VERSION_WARNING when falling back.
        """
        try:
            tracking_uri = mlflow.get_tracking_uri()
        except (ImportError, AttributeError):
            # A missing module-level attribute raises AttributeError, not
            # ImportError, so both are caught to reach the legacy location.
            logger.warning(VERSION_WARNING.format("mlflow.get_tracking_uri"))
            tracking_uri = mlflow.tracking.get_tracking_uri()

        from mlflow.utils.databricks_utils import get_databricks_host_creds

        try:
            # If get_db_info_from_uri exists, it means mlflow 1.10 or above;
            # the import is only a version probe.
            from mlflow.utils.uri import get_db_info_from_uri  # noqa: F401
        except ImportError:
            pass
        else:
            return RestStore(lambda: get_databricks_host_creds(tracking_uri))

        # Older mlflow: credentials are resolved from a Databricks profile.
        try:
            from mlflow.utils.uri import get_db_profile_from_uri
        except ImportError:
            logger.warning(VERSION_WARNING.format("from mlflow"))
            from mlflow.tracking.utils import get_db_profile_from_uri

        profile = get_db_profile_from_uri("databricks")
        logger.info("tracking uri: %s and profile: %s", tracking_uri, profile)
        return RestStore(lambda: get_databricks_host_creds(profile))
コード例 #8
0
    def test_successful_http_request(self, request):
        """A list-experiments call issues the expected GET request and parses the reply."""
        def fake_request(*args, **kwargs):
            assert args == ("GET",
                            "https://hello/api/2.0/mlflow/experiments/list")
            # Filter out None arguments before comparing keyword arguments.
            provided = {key: value for key, value in kwargs.items() if value is not None}
            assert provided == {
                "params": {"view_type": "ACTIVE_ONLY"},
                "headers": DefaultRequestHeaderProvider().request_headers(),
                "verify": True,
                "timeout": 120,
            }
            fake_response = mock.MagicMock()
            fake_response.status_code = 200
            fake_response.text = '{"experiments": [{"name": "Exp!", "lifecycle_stage": "active"}]}'
            return fake_response

        request.side_effect = fake_request

        rest_store = RestStore(lambda: MlflowHostCreds("https://hello"))
        listed = rest_store.list_experiments()
        assert listed[0].name == "Exp!"
コード例 #9
0
ファイル: utils.py プロジェクト: anjosma/mlflow_experiments
def _get_rest_store(store_uri, **_):
    """Return a RestStore whose credential provider is ``_get_default_host_creds(store_uri)``."""
    host_creds_provider = partial(_get_default_host_creds, store_uri)
    return RestStore(host_creds_provider)
コード例 #10
0
    def test_requestor(self, request):
        """Each RestStore operation issues the expected REST call.

        ``mlflow.utils.rest_utils.http_request`` is patched around every store
        operation, and the endpoint, HTTP method and JSON body it receives are
        checked.
        """

        def _ok_response(text="{}"):
            # Build a successful response as a MagicMock *instance*. Assigning
            # attributes on the ``mock.MagicMock`` class itself would mutate
            # every MagicMock created afterwards and leak into other tests.
            response = mock.MagicMock()
            response.status_code = 200
            response.text = text
            return response

        def _patch_http(text="{}"):
            # Patch the low-level HTTP helper so it returns a 200 response whose
            # body is ``text``.
            return mock.patch("mlflow.utils.rest_utils.http_request",
                              return_value=_ok_response(text))

        request.return_value = _ok_response()

        creds = MlflowHostCreds("https://hello")
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        source_name_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name)
        source_type_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.LOCAL,
        )
        with _patch_http() as mock_http, mock.patch(
                "mlflow.tracking._tracking_service.utils._get_store",
                return_value=store), mock.patch(
                    "mlflow.tracking.context.default_context._get_user",
                    return_value=user_name), mock.patch(
                        "time.time", return_value=13579
                    ), source_name_patch, source_type_patch:
            with mlflow.start_run(experiment_id="43"):
                cr_body = message_to_json(
                    CreateRun(
                        experiment_id="43",
                        user_id=user_name,
                        start_time=13579000,
                        tags=[
                            ProtoRunTag(key="mlflow.source.name",
                                        value=source_name),
                            ProtoRunTag(key="mlflow.source.type",
                                        value="LOCAL"),
                            ProtoRunTag(key="mlflow.user", value=user_name),
                        ],
                    ))
                expected_kwargs = self._args(creds, "runs/create", "POST",
                                             cr_body)

                assert mock_http.call_count == 1
                actual_kwargs = mock_http.call_args[1]

                # Test the passed tag values separately from the rest of the request
                # Tag order is inconsistent on Python 2 and 3, but the order does not matter
                expected_tags = expected_kwargs["json"].pop("tags")
                actual_tags = actual_kwargs["json"].pop("tags")
                assert sorted(expected_tags, key=lambda t: t["key"]) == sorted(
                    actual_tags, key=lambda t: t["key"])
                assert expected_kwargs == actual_kwargs

        with _patch_http() as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid",
                         run_id="some_uuid",
                         key="k1",
                         value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        with _patch_http() as mock_http:
            store.set_experiment_tag("some_id",
                                     ExperimentTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetExperimentTag(experiment_id="some_id",
                                 key="t1",
                                 value="abcd" * 1000))
            self._verify_requests(mock_http, creds,
                                  "experiments/set-experiment-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid",
                       run_id="some_uuid",
                       key="t1",
                       value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.delete_tag("some_uuid", "t1")
            body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
            self._verify_requests(mock_http, creds, "runs/delete-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(
                LogMetric(run_uuid="u2",
                          run_id="u2",
                          key="m1",
                          value=0.87,
                          timestamp=12345,
                          step=3))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        with _patch_http() as mock_http:
            metrics = [
                Metric("m1", 0.87, 12345, 0),
                Metric("m2", 0.49, 12345, -1),
                Metric("m3", 0.58, 12345, 2),
            ]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2",
                            metrics=metrics,
                            params=params,
                            tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(
                LogBatch(run_id="u2",
                         metrics=metric_protos,
                         params=param_protos,
                         tags=tag_protos))
            self._verify_requests(mock_http, creds, "runs/log-batch", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with _patch_http() as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with _patch_http() as mock_http:
            store.delete_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/delete",
                "POST",
                message_to_json(DeleteExperiment(experiment_id="0")),
            )

        with _patch_http() as mock_http:
            store.restore_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/restore",
                "POST",
                message_to_json(RestoreExperiment(experiment_id="0")),
            )

        search_response = '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}'
        with _patch_http(search_response) as mock_http:
            result = store.search_runs(
                ["0", "1"],
                "params.p1 = 'a'",
                ViewType.ACTIVE_ONLY,
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )

            expected_message = SearchRuns(
                experiment_ids=["0", "1"],
                filter="params.p1 = 'a'",
                run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )
            self._verify_requests(mock_http, creds, "runs/search", "POST",
                                  message_to_json(expected_message))
            assert result.token == "67890fghij"

        with _patch_http() as mock_http:
            run_id = "run_id"
            m = Model(artifact_path="model/path",
                      run_id=run_id,
                      flavors={"tf": "flavor body"})
            store.record_logged_model(run_id, m)
            expected_message = LogModel(run_id=run_id, model_json=m.to_json())
            self._verify_requests(mock_http, creds, "runs/log-model", "POST",
                                  message_to_json(expected_message))