示例#1
0
def _log_model():
    """Handle a ``LogModel`` request: validate the model JSON and record it.

    The request's ``model_json`` field must decode to a JSON object that
    contains the keys ``artifact_path``, ``flavors``, ``utc_time_created``
    and ``run_id``. The decoded dict is converted with ``Model.from_dict``
    and passed to the tracking store's ``record_logged_model`` together with
    the request's ``run_id``.

    :return: A ``Response`` with mimetype ``application/json`` whose data is
             the JSON form of an empty ``LogModel.Response`` message.
    :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when
             ``model_json`` is not valid JSON, is not a JSON object, or is
             missing one of the mandatory fields.
    """
    request_message = _get_request_message(LogModel())
    try:
        model = json.loads(request_message.model_json)
    except Exception:
        # Deliberately broad: json.JSONDecodeError does not exist on
        # Python 2, so it cannot be named here. ``Exception`` (rather than a
        # bare ``except``) still lets KeyboardInterrupt/SystemExit propagate.
        raise MlflowException(
            "Malformed model info. \n {} \n is not a valid JSON.".format(
                request_message.model_json),
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Valid JSON is not necessarily an object (e.g. ``[]``, ``"x"``, ``3``,
    # ``null``). Reject those explicitly so the field check below does not
    # fail with an AttributeError/TypeError instead of a client error.
    if not isinstance(model, dict):
        raise MlflowException(
            "Malformed model info. \n {} \n is not a JSON object.".format(
                request_message.model_json),
            error_code=INVALID_PARAMETER_VALUE,
        )

    required_fields = ("artifact_path", "flavors", "utc_time_created",
                       "run_id")
    missing_fields = set(required_fields) - set(model.keys())
    if missing_fields:
        raise MlflowException(
            "Model json is missing mandatory fields: {}".format(
                missing_fields),
            error_code=INVALID_PARAMETER_VALUE,
        )
    _get_tracking_store().record_logged_model(
        run_id=request_message.run_id, mlflow_model=Model.from_dict(model))
    response_message = LogModel.Response()
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
示例#2
0
 def record_logged_model(self, run_id, mlflow_model):
     """Send a ``LogModel`` request for ``mlflow_model`` under ``run_id``.

     The model is serialized with its ``to_json`` method, wrapped in a
     ``LogModel`` message, converted with ``message_to_json`` and handed to
     ``self._call_endpoint``. Nothing is returned.
     """
     serialized_model = mlflow_model.to_json()
     log_model_message = LogModel(run_id=run_id, model_json=serialized_model)
     payload = message_to_json(log_model_message)
     self._call_endpoint(LogModel, payload)
    def test_requestor(self, request):
        """Exercise ``RestStore`` methods and check the request each one sends.

        Each section patches ``mlflow.utils.rest_utils.http_request``, calls
        one store method, builds the expected JSON body from the matching
        protobuf message and passes the mock, credentials, endpoint, HTTP
        method and body to ``self._verify_requests`` (the first section
        compares the call kwargs inline instead).

        :param request: A mock object; presumably injected by a patch
                        decorator that is not visible here -- TODO confirm
                        what it replaces.
        """
        # NOTE(review): ``mock.MagicMock`` is referenced without being called,
        # so the attribute assignments below are made on the MagicMock class
        # itself rather than on a single instance. The later sections never
        # configure ``mock_http.return_value``; presumably they depend on
        # fresh MagicMock instances inheriting these class attributes --
        # confirm before changing this to ``mock.MagicMock()``.
        response = mock.MagicMock
        response.status_code = 200
        response.text = "{}"
        request.return_value = response

        # The store is built from a callable that returns the host credentials.
        creds = MlflowHostCreds("https://hello")
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        # Pin the values that the default run context contributes as tags.
        source_name_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name)
        source_type_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.LOCAL,
        )
        # Section: starting a run should POST to "runs/create". The tracking
        # store lookup, the user name and time.time are patched so the
        # expected CreateRun message below can be written out literally.
        with mock.patch(
                "mlflow.utils.rest_utils.http_request"
        ) as mock_http, mock.patch(
                "mlflow.tracking._tracking_service.utils._get_store",
                return_value=store), mock.patch(
                    "mlflow.tracking.context.default_context._get_user",
                    return_value=user_name), mock.patch(
                        "time.time", return_value=13579
                    ), source_name_patch, source_type_patch:
            with mlflow.start_run(experiment_id="43"):
                # time.time is patched to 13579; the expected start_time is
                # 13579000 (presumably seconds converted to milliseconds --
                # confirm against the client code).
                cr_body = message_to_json(
                    CreateRun(
                        experiment_id="43",
                        user_id=user_name,
                        start_time=13579000,
                        tags=[
                            ProtoRunTag(key="mlflow.source.name",
                                        value=source_name),
                            ProtoRunTag(key="mlflow.source.type",
                                        value="LOCAL"),
                            ProtoRunTag(key="mlflow.user", value=user_name),
                        ],
                    ))
                expected_kwargs = self._args(creds, "runs/create", "POST",
                                             cr_body)

                # Checked while the run is still open, i.e. before the
                # ``start_run`` context manager exits.
                assert mock_http.call_count == 1
                actual_kwargs = mock_http.call_args[1]

                # Test the passed tag values separately from the rest of the request
                # Tag order is inconsistent on Python 2 and 3, but the order does not matter
                expected_tags = expected_kwargs["json"].pop("tags")
                actual_tags = actual_kwargs["json"].pop("tags")
                assert sorted(expected_tags, key=lambda t: t["key"]) == sorted(
                    actual_tags, key=lambda t: t["key"])
                assert expected_kwargs == actual_kwargs

        # Section: log_param -> "runs/log-parameter". The expected message
        # carries the run identifier in both ``run_uuid`` and ``run_id``.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid",
                         run_id="some_uuid",
                         key="k1",
                         value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        # Section: set_experiment_tag -> "experiments/set-experiment-tag",
        # using a 4000-character tag value.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.set_experiment_tag("some_id",
                                     ExperimentTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetExperimentTag(experiment_id="some_id",
                                 key="t1",
                                 value="abcd" * 1000))
            self._verify_requests(mock_http, creds,
                                  "experiments/set-experiment-tag", "POST",
                                  body)

        # Section: set_tag -> "runs/set-tag", again with a 4000-character value.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid",
                       run_id="some_uuid",
                       key="t1",
                       value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        # Section: delete_tag -> "runs/delete-tag".
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.delete_tag("some_uuid", "t1")
            body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
            self._verify_requests(mock_http, creds, "runs/delete-tag", "POST",
                                  body)

        # Section: log_metric -> "runs/log-metric". The Metric's positional
        # arguments line up with key, value, timestamp and step in the
        # expected LogMetric message.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(
                LogMetric(run_uuid="u2",
                          run_id="u2",
                          key="m1",
                          value=0.87,
                          timestamp=12345,
                          step=3))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        # Section: log_batch -> "runs/log-batch". Entities are converted with
        # ``to_proto`` to build the expected LogBatch message. The metric
        # steps include 0 and a negative value.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            metrics = [
                Metric("m1", 0.87, 12345, 0),
                Metric("m2", 0.49, 12345, -1),
                Metric("m3", 0.58, 12345, 2),
            ]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2",
                            metrics=metrics,
                            params=params,
                            tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(
                LogBatch(run_id="u2",
                         metrics=metric_protos,
                         params=param_protos,
                         tags=tag_protos))
            self._verify_requests(mock_http, creds, "runs/log-batch", "POST",
                                  body)

        # Section: delete_run -> "runs/delete".
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        # Section: restore_run -> "runs/restore".
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        # Section: delete_experiment -> "experiments/delete".
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.delete_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/delete",
                "POST",
                message_to_json(DeleteExperiment(experiment_id="0")),
            )

        # Section: restore_experiment -> "experiments/restore".
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            store.restore_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/restore",
                "POST",
                message_to_json(RestoreExperiment(experiment_id="0")),
            )

        # Section: search_runs -> "runs/search". This is the only section that
        # sets ``mock_http.return_value`` and inspects the store's result.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            # NOTE(review): as at the top of the test, this is the MagicMock
            # class, not an instance, so ``text`` is overwritten on the class
            # and stays that way for the section that follows.
            response = mock.MagicMock
            response.text = '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}'
            mock_http.return_value = response
            result = store.search_runs(
                ["0", "1"],
                "params.p1 = 'a'",
                ViewType.ACTIVE_ONLY,
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )

            expected_message = SearchRuns(
                experiment_ids=["0", "1"],
                filter="params.p1 = 'a'",
                run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )
            self._verify_requests(mock_http, creds, "runs/search", "POST",
                                  message_to_json(expected_message))
            # The token on the result must match "next_page_token" from the
            # canned response text above.
            assert result.token == "67890fghij"

        # Section: record_logged_model -> "runs/log-model". The expected body
        # embeds the model's own ``to_json`` output; ``result`` is assigned
        # but not inspected.
        with mock.patch("mlflow.utils.rest_utils.http_request") as mock_http:
            run_id = "run_id"
            m = Model(artifact_path="model/path",
                      run_id="run_id",
                      flavors={"tf": "flavor body"})
            result = store.record_logged_model("run_id", m)
            expected_message = LogModel(run_id=run_id, model_json=m.to_json())
            self._verify_requests(mock_http, creds, "runs/log-model", "POST",
                                  message_to_json(expected_message))