Пример #1
0
 def delete_run(self, run_id):
     """Send a ``DeleteRun`` request for the given run.

     :param run_id: Value placed in the ``run_id`` field of the request.
     """
     request_message = DeleteRun(run_id=run_id)
     serialized = message_to_json(request_message)
     self._call_endpoint(DeleteRun, serialized)
Пример #2
0
 def restore_run(self, run_id):
     """Send a ``RestoreRun`` request for the given run.

     :param run_id: Value placed in the ``run_id`` field of the request.
     """
     request_message = RestoreRun(run_id=run_id)
     serialized = message_to_json(request_message)
     self._call_endpoint(RestoreRun, serialized)
Пример #3
0
 def restore_experiment(self, experiment_id):
     """Send a ``RestoreExperiment`` request for the given experiment.

     :param experiment_id: Experiment identifier; converted with ``str()``
                           before being placed in the request.
     """
     request_message = RestoreExperiment(experiment_id=str(experiment_id))
     serialized = message_to_json(request_message)
     self._call_endpoint(RestoreExperiment, serialized)
Пример #4
0
 def rename_experiment(self, experiment_id, new_name):
     """Send an ``UpdateExperiment`` request carrying a new experiment name.

     :param experiment_id: Experiment identifier; converted with ``str()``
                           before being placed in the request.
     :param new_name: Value placed in the ``new_name`` field of the request.
     """
     request_message = UpdateExperiment(
         experiment_id=str(experiment_id),
         new_name=new_name,
     )
     serialized = message_to_json(request_message)
     self._call_endpoint(UpdateExperiment, serialized)
Пример #5
0
def _wrap_response(response_message):
    """Return a ``Response`` with JSON mimetype holding the serialized message.

    :param response_message: Message passed to ``message_to_json`` to produce
                             the response body.
    :return: The populated ``Response`` object.
    """
    wrapped = Response(mimetype="application/json")
    payload = message_to_json(response_message)
    wrapped.set_data(payload)
    return wrapped
Пример #6
0
 def delete_experiment(self, experiment_id):
     """Send a ``DeleteExperiment`` request for the given experiment.

     :param experiment_id: Experiment identifier; converted with ``str()``
                           before being placed in the request.
     """
     request_message = DeleteExperiment(experiment_id=str(experiment_id))
     serialized = message_to_json(request_message)
     self._call_endpoint(DeleteExperiment, serialized)
Пример #7
0
    def test_get_experiment_by_name(self, store_class):
        """Exercise ``get_experiment_by_name`` against mocked HTTP responses.

        Four scenarios are covered: a successful lookup, a 404 response carrying
        a ``RESOURCE_DOES_NOT_EXIST`` error (expected to yield ``None``), an
        ``ENDPOINT_NOT_FOUND`` error on the get-by-name endpoint (expected to be
        answered through ``experiments/list``), and a ``REQUEST_LIMIT_EXCEEDED``
        error (expected to propagate after a single request).

        :param store_class: Callable that builds the store under test from a
                            host-credentials provider.
        """
        creds = MlflowHostCreds("https://hello")
        store = store_class(lambda: creds)
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            # Each canned response is a MagicMock *instance*. Assigning attributes
            # on the ``mock.MagicMock`` class would make all three responses the
            # same object and leak ``status_code``/``text`` into every other mock.
            response = mock.MagicMock()
            response.status_code = 200
            experiment = Experiment(
                experiment_id="123",
                name="abc",
                artifact_location="/abc",
                lifecycle_stage=LifecycleStage.ACTIVE,
            )
            response.text = json.dumps({
                "experiment":
                json.loads(message_to_json(experiment.to_proto()))
            })
            mock_http.return_value = response
            result = store.get_experiment_by_name("abc")
            expected_message0 = GetExperimentByName(experiment_name="abc")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_message0),
            )
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage
            # Test GetExperimentByName against nonexistent experiment
            mock_http.reset_mock()
            nonexistent_exp_response = mock.MagicMock()
            nonexistent_exp_response.status_code = 404
            nonexistent_exp_response.text = MlflowException(
                "Exp doesn't exist!",
                RESOURCE_DOES_NOT_EXIST).serialize_as_json()
            mock_http.return_value = nonexistent_exp_response
            assert store.get_experiment_by_name(
                "nonexistent-experiment") is None
            expected_message1 = GetExperimentByName(
                experiment_name="nonexistent-experiment")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_message1),
            )
            assert mock_http.call_count == 1

            # Test REST client behavior against a mocked old server, which has handler for
            # ListExperiments but not GetExperimentByName
            mock_http.reset_mock()
            list_exp_response = mock.MagicMock()
            list_exp_response.text = json.dumps({
                "experiments":
                [json.loads(message_to_json(experiment.to_proto()))]
            })
            list_exp_response.status_code = 200

            def response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                # Fail only the get-by-name endpoint; any other request receives
                # the canned list response.
                if kwargs.get("endpoint"
                              ) == "/api/2.0/mlflow/experiments/get-by-name":
                    raise MlflowException(
                        "GetExperimentByName is not implemented",
                        ENDPOINT_NOT_FOUND)
                else:
                    return list_exp_response

            mock_http.side_effect = response_fn
            result = store.get_experiment_by_name("abc")
            expected_message2 = ListExperiments(view_type=ViewType.ALL)
            self._verify_requests(
                mock_http,
                creds,
                "experiments/get-by-name",
                "GET",
                message_to_json(expected_message0),
            )
            self._verify_requests(mock_http, creds, "experiments/list", "GET",
                                  message_to_json(expected_message2))
            assert result.experiment_id == experiment.experiment_id
            assert result.name == experiment.name
            assert result.artifact_location == experiment.artifact_location
            assert result.lifecycle_stage == experiment.lifecycle_stage

            # Verify that REST client won't fall back to ListExperiments for 429 errors (hitting
            # rate limits)
            mock_http.reset_mock()

            def rate_limit_response_fn(*args, **kwargs):
                # pylint: disable=unused-argument
                raise MlflowException("Hit rate limit on GetExperimentByName",
                                      REQUEST_LIMIT_EXCEEDED)

            mock_http.side_effect = rate_limit_response_fn
            with pytest.raises(MlflowException) as exc_info:
                store.get_experiment_by_name("imspamming")
            assert exc_info.value.error_code == ErrorCode.Name(
                REQUEST_LIMIT_EXCEEDED)
            assert mock_http.call_count == 1
Пример #8
0
 def _get_read_credentials(self, run_id, path=None):
     """Call the ``GetCredentialsForRead`` endpoint and return its result.

     :param run_id: Value placed in the ``run_id`` field of the request.
     :param path: Value placed in the ``path`` field of the request
                  (defaults to ``None``).
     :return: Whatever ``_call_endpoint`` returns for this request.
     """
     request_message = GetCredentialsForRead(run_id=run_id, path=path)
     serialized = message_to_json(request_message)
     return self._call_endpoint(
         DatabricksMlflowArtifactsService,
         GetCredentialsForRead,
         serialized,
     )
Пример #9
0
 def record_logged_model(self, run_id, mlflow_model):
     """Send a ``LogModel`` request containing the model's JSON form.

     :param run_id: Value placed in the ``run_id`` field of the request.
     :param mlflow_model: Object whose ``to_json()`` output becomes the
                          ``model_json`` field of the request.
     """
     request_message = LogModel(
         run_id=run_id,
         model_json=mlflow_model.to_json(),
     )
     serialized = message_to_json(request_message)
     self._call_endpoint(LogModel, serialized)
Пример #10
0
    def test_requestor(self, request):
        """Check the REST request that each ``RestStore`` operation sends.

        Every section patches ``mlflow.utils.rest_utils.http_request``, invokes
        one store (or fluent API) operation and compares the recorded call with
        the endpoint, HTTP method and JSON body built from the matching proto
        message.

        :param request: Mock object supplied to the test; its ``return_value``
                        is set to the canned response.
        """
        # Use a MagicMock *instance*. Setting attributes on the ``mock.MagicMock``
        # class would expose ``status_code``/``text`` on every mock in the
        # process, including ones created by unrelated tests.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = "{}"
        request.return_value = response

        def _patch_http():
            # Hand the canned response to the patched function explicitly rather
            # than relying on attributes leaked onto the MagicMock class.
            return mock.patch("mlflow.utils.rest_utils.http_request",
                              return_value=response)

        creds = MlflowHostCreds("https://hello")
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        source_name_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name)
        source_type_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.LOCAL,
        )
        with _patch_http() as mock_http, mock.patch(
                "mlflow.tracking._tracking_service.utils._get_store",
                return_value=store), mock.patch(
                    "mlflow.tracking.context.default_context._get_user",
                    return_value=user_name), mock.patch(
                        "time.time", return_value=13579
                    ), source_name_patch, source_type_patch:
            with mlflow.start_run(experiment_id="43"):
                cr_body = message_to_json(
                    CreateRun(
                        experiment_id="43",
                        user_id=user_name,
                        start_time=13579000,
                        tags=[
                            ProtoRunTag(key="mlflow.source.name",
                                        value=source_name),
                            ProtoRunTag(key="mlflow.source.type",
                                        value="LOCAL"),
                            ProtoRunTag(key="mlflow.user", value=user_name),
                        ],
                    ))
                expected_kwargs = self._args(creds, "runs/create", "POST",
                                             cr_body)

                assert mock_http.call_count == 1
                actual_kwargs = mock_http.call_args[1]

                # Test the passed tag values separately from the rest of the request
                # Tag order is inconsistent on Python 2 and 3, but the order does not matter
                expected_tags = expected_kwargs["json"].pop("tags")
                actual_tags = actual_kwargs["json"].pop("tags")
                assert sorted(expected_tags, key=lambda t: t["key"]) == sorted(
                    actual_tags, key=lambda t: t["key"])
                assert expected_kwargs == actual_kwargs

        with _patch_http() as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid",
                         run_id="some_uuid",
                         key="k1",
                         value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        with _patch_http() as mock_http:
            store.set_experiment_tag("some_id",
                                     ExperimentTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetExperimentTag(experiment_id="some_id",
                                 key="t1",
                                 value="abcd" * 1000))
            self._verify_requests(mock_http, creds,
                                  "experiments/set-experiment-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid",
                       run_id="some_uuid",
                       key="t1",
                       value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.delete_tag("some_uuid", "t1")
            body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
            self._verify_requests(mock_http, creds, "runs/delete-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(
                LogMetric(run_uuid="u2",
                          run_id="u2",
                          key="m1",
                          value=0.87,
                          timestamp=12345,
                          step=3))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        with _patch_http() as mock_http:
            metrics = [
                Metric("m1", 0.87, 12345, 0),
                Metric("m2", 0.49, 12345, -1),
                Metric("m3", 0.58, 12345, 2),
            ]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2",
                            metrics=metrics,
                            params=params,
                            tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(
                LogBatch(run_id="u2",
                         metrics=metric_protos,
                         params=param_protos,
                         tags=tag_protos))
            self._verify_requests(mock_http, creds, "runs/log-batch", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with _patch_http() as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with _patch_http() as mock_http:
            store.delete_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/delete",
                "POST",
                message_to_json(DeleteExperiment(experiment_id="0")),
            )

        with _patch_http() as mock_http:
            store.restore_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/restore",
                "POST",
                message_to_json(RestoreExperiment(experiment_id="0")),
            )

        with _patch_http() as mock_http:
            # The search needs its own payload, so it gets a dedicated response
            # instance with an explicit success status.
            search_response = mock.MagicMock()
            search_response.status_code = 200
            search_response.text = '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}'
            mock_http.return_value = search_response
            result = store.search_runs(
                ["0", "1"],
                "params.p1 = 'a'",
                ViewType.ACTIVE_ONLY,
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )

            expected_message = SearchRuns(
                experiment_ids=["0", "1"],
                filter="params.p1 = 'a'",
                run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )
            self._verify_requests(mock_http, creds, "runs/search", "POST",
                                  message_to_json(expected_message))
            assert result.token == "67890fghij"

        with _patch_http() as mock_http:
            run_id = "run_id"
            m = Model(artifact_path="model/path",
                      run_id="run_id",
                      flavors={"tf": "flavor body"})
            store.record_logged_model("run_id", m)
            expected_message = LogModel(run_id=run_id, model_json=m.to_json())
            self._verify_requests(mock_http, creds, "runs/log-model", "POST",
                                  message_to_json(expected_message))
Пример #11
0
 def update_run_info(self, run_uuid, run_status, end_time):
     """Send an ``UpdateRun`` request and return the resulting run info.

     :param run_uuid: Value placed in the ``run_uuid`` field of the request.
     :param run_status: Value placed in the ``status`` field of the request.
     :param end_time: Value placed in the ``end_time`` field of the request.
     :return: ``RunInfo`` built from the ``run_info`` field of the response.
     """
     request_message = UpdateRun(
         run_uuid=run_uuid,
         status=run_status,
         end_time=end_time,
     )
     serialized = message_to_json(request_message)
     endpoint_result = self._call_endpoint(UpdateRun, serialized)
     return RunInfo.from_proto(endpoint_result.run_info)
Пример #12
0
 def _get_run_artifact_root(self, run_id):
     """Fetch a run via ``GetRun`` and return its ``artifact_uri``.

     :param run_id: Value placed in the ``run_id`` field of the request.
     :return: ``run.info.artifact_uri`` taken from the endpoint response.
     """
     request_message = GetRun(run_id=run_id)
     serialized = message_to_json(request_message)
     fetched = self._call_endpoint(MlflowService, GetRun, serialized)
     return fetched.run.info.artifact_uri
Пример #13
0
def _file_infos_to_json(file_infos):
    json_list = [
        message_to_json(file_info.to_proto()) for file_info in file_infos
    ]
    return "[" + ", ".join(json_list) + "]"
Пример #14
0
    def test_requestor(self, request):
        """Check the REST request that each ``RestStore`` operation sends.

        Every section patches ``mlflow.store.rest_store.http_request_safe``,
        invokes one store (or fluent API) operation and compares the recorded
        call(s) with the endpoint, HTTP method and JSON body built from the
        matching proto message.

        :param request: Mock object supplied to the test; its ``return_value``
                        is set to the canned response.
        """
        # Use a MagicMock *instance*. Setting attributes on the ``mock.MagicMock``
        # class would expose ``status_code``/``text`` on every mock in the
        # process, including ones created by unrelated tests.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = '{}'
        request.return_value = response

        def _patch_http():
            # Hand the canned response to the patched function explicitly rather
            # than relying on attributes leaked onto the MagicMock class.
            return mock.patch('mlflow.store.rest_store.http_request_safe',
                              return_value=response)

        creds = MlflowHostCreds('https://hello')
        store = RestStore(lambda: creds)

        user_name = "mock user"
        run_name = "rest run"
        source_name = "rest test"

        with _patch_http() as mock_http, \
                mock.patch('mlflow.tracking.utils._get_store', return_value=store), \
                mock.patch('mlflow.tracking.client._get_user_id', return_value=user_name), \
                mock.patch('time.time', return_value=13579):
            with mlflow.start_run(experiment_id="43",
                                  run_name=run_name,
                                  source_name=source_name):
                cr_body = message_to_json(
                    CreateRun(experiment_id="43",
                              run_name='',
                              user_id=user_name,
                              source_type=SourceType.LOCAL,
                              source_name=source_name,
                              start_time=13579000,
                              tags=[
                                  ProtoRunTag(key='mlflow.source.name',
                                              value=source_name),
                                  ProtoRunTag(key='mlflow.source.type',
                                              value='LOCAL')
                              ]))
                st_body = message_to_json(
                    SetTag(run_uuid='',
                           run_id='',
                           key='mlflow.runName',
                           value=run_name))
                assert mock_http.call_count == 2
                exp_calls = [("runs/create", "POST", cr_body),
                             ("runs/set-tag", "POST", st_body)]
                self._verify_request_has_calls(mock_http, creds, exp_calls)

        with _patch_http() as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid",
                         run_id="some_uuid",
                         key="k1",
                         value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        with _patch_http() as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid",
                       run_id="some_uuid",
                       key="t1",
                       value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(
                LogMetric(run_uuid="u2",
                          run_id="u2",
                          key="m1",
                          value=0.87,
                          timestamp=12345,
                          step=3))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        with _patch_http() as mock_http:
            metrics = [
                Metric("m1", 0.87, 12345, 0),
                Metric("m2", 0.49, 12345, -1),
                Metric("m3", 0.58, 12345, 2)
            ]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2",
                            metrics=metrics,
                            params=params,
                            tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(
                LogBatch(run_id="u2",
                         metrics=metric_protos,
                         params=param_protos,
                         tags=tag_protos))
            self._verify_requests(mock_http, creds, "runs/log-batch", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with _patch_http() as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with _patch_http() as mock_http:
            store.delete_experiment("0")
            self._verify_requests(
                mock_http, creds, "experiments/delete", "POST",
                message_to_json(DeleteExperiment(experiment_id="0")))

        with _patch_http() as mock_http:
            store.restore_experiment("0")
            self._verify_requests(
                mock_http, creds, "experiments/restore", "POST",
                message_to_json(RestoreExperiment(experiment_id="0")))
Пример #15
0
 def _verify_requests(self, http_request, endpoint, method, proto_message):
     """Assert that the mocked request function saw the expected call.

     Prints the recorded call list first, then checks that at least one call
     matches the keyword arguments produced by ``self._args``.

     :param http_request: Mock whose ``call_args_list`` is inspected.
     :param endpoint: Endpoint forwarded to ``self._args``.
     :param method: HTTP method forwarded to ``self._args``.
     :param proto_message: Message serialized with ``message_to_json`` to form
                           the expected body.
     """
     print(http_request.call_args_list)
     serialized = message_to_json(proto_message)
     expected_kwargs = self._args(self.creds, endpoint, method, serialized)
     http_request.assert_any_call(**expected_kwargs)
Пример #16
0
 def _jsonify(obj):
     """Return the parsed JSON form of ``obj.to_proto()``.

     :param obj: Object exposing ``to_proto()``.
     :return: Result of ``json.loads`` applied to the serialized proto.
     """
     proto = obj.to_proto()
     serialized = message_to_json(proto)
     return json.loads(serialized)
Пример #17
0
    def test_requestor(self, request):
        """Check the REST request that each ``RestStore`` operation sends.

        Every section patches ``mlflow.store.rest_store.http_request_safe``,
        invokes one store (or fluent API) operation and compares the recorded
        call with the endpoint, HTTP method and JSON body built from the
        matching proto message.

        :param request: Mock object supplied to the test; its ``return_value``
                        is set to the canned response.
        """
        # Use a MagicMock *instance*. Setting attributes on the ``mock.MagicMock``
        # class would expose ``status_code``/``text`` on every mock in the
        # process, including ones created by unrelated tests.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = '{}'
        request.return_value = response

        def _patch_http():
            # Hand the canned response to the patched function explicitly rather
            # than relying on attributes leaked onto the MagicMock class.
            return mock.patch('mlflow.store.rest_store.http_request_safe',
                              return_value=response)

        creds = MlflowHostCreds('https://hello')
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        source_name_patch = mock.patch(
            "mlflow.tracking.context._get_source_name",
            return_value=source_name)
        source_type_patch = mock.patch(
            "mlflow.tracking.context._get_source_type",
            return_value=SourceType.LOCAL)
        with _patch_http() as mock_http, \
                mock.patch('mlflow.tracking.utils._get_store', return_value=store), \
                mock.patch('mlflow.tracking.context._get_user', return_value=user_name), \
                mock.patch('time.time', return_value=13579), \
                source_name_patch, source_type_patch:
            with mlflow.start_run(experiment_id="43"):
                cr_body = message_to_json(
                    CreateRun(experiment_id="43",
                              user_id=user_name,
                              start_time=13579000,
                              tags=[
                                  ProtoRunTag(key='mlflow.source.name',
                                              value=source_name),
                                  ProtoRunTag(key='mlflow.source.type',
                                              value='LOCAL'),
                                  ProtoRunTag(key='mlflow.user',
                                              value=user_name)
                              ]))
                expected_kwargs = self._args(creds, "runs/create", "POST",
                                             cr_body)

                assert mock_http.call_count == 1
                actual_kwargs = mock_http.call_args[1]

                # Test the passed tag values separately from the rest of the request
                # Tag order is inconsistent on Python 2 and 3, but the order does not matter
                expected_tags = expected_kwargs['json'].pop('tags')
                actual_tags = actual_kwargs['json'].pop('tags')
                assert (sorted(expected_tags,
                               key=lambda t: t['key']) == sorted(
                                   actual_tags, key=lambda t: t['key']))
                assert expected_kwargs == actual_kwargs

        with _patch_http() as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid",
                         run_id="some_uuid",
                         key="k1",
                         value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        with _patch_http() as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid",
                       run_id="some_uuid",
                       key="t1",
                       value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(
                LogMetric(run_uuid="u2",
                          run_id="u2",
                          key="m1",
                          value=0.87,
                          timestamp=12345,
                          step=3))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        with _patch_http() as mock_http:
            metrics = [
                Metric("m1", 0.87, 12345, 0),
                Metric("m2", 0.49, 12345, -1),
                Metric("m3", 0.58, 12345, 2)
            ]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2",
                            metrics=metrics,
                            params=params,
                            tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(
                LogBatch(run_id="u2",
                         metrics=metric_protos,
                         params=param_protos,
                         tags=tag_protos))
            self._verify_requests(mock_http, creds, "runs/log-batch", "POST",
                                  body)

        with _patch_http() as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with _patch_http() as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with _patch_http() as mock_http:
            store.delete_experiment("0")
            self._verify_requests(
                mock_http, creds, "experiments/delete", "POST",
                message_to_json(DeleteExperiment(experiment_id="0")))

        with _patch_http() as mock_http:
            store.restore_experiment("0")
            self._verify_requests(
                mock_http, creds, "experiments/restore", "POST",
                message_to_json(RestoreExperiment(experiment_id="0")))