# Example #1
# 0
def test_model_save_load():
    """Exercise Model schema accessors, equality, and a save/load round-trip."""

    def build_model():
        # Build a brand-new Model each call so the two instances share no
        # nested dicts or signature objects.
        return Model(
            artifact_path="some/path",
            run_id="123",
            flavors={
                "flavor1": {"a": 1, "b": 2},
                "flavor2": {"x": 1, "y": 2},
            },
            signature=ModelSignature(
                inputs=Schema([ColSpec("integer", "x"), ColSpec("integer", "y")]),
                outputs=Schema([ColSpec(name=None, type="double")]),
            ),
            saved_input_example_info={"x": 1, "y": 2},
        )

    original = build_model()
    assert original.get_input_schema() == original.signature.inputs
    assert original.get_output_schema() == original.signature.outputs

    # A model without a signature exposes no schemas.
    bare = Model(artifact_path="some/other/path", run_id="1234")
    assert bare.get_input_schema() is None
    assert bare.get_output_schema() is None

    twin = build_model()
    # Copy the creation timestamp so the comparison only looks at content.
    twin.utc_time_created = original.utc_time_created
    assert original == twin
    twin.signature = None
    assert original != twin

    with TempDir() as tmp:
        original.save(tmp.path("model"))
        loaded = Model.load(tmp.path("model"))
    assert original == loaded
    assert original.to_json() == loaded.to_json()
    assert original.to_yaml() == loaded.to_yaml()
    def test_requestor(self, request):
        """Check that each RestStore operation issues the expected REST request.

        ``request`` is a mock injected by a ``mock.patch`` decorator that is not
        visible here (presumably patching ``requests.Session.request`` -- confirm).
        Each patched ``http_request`` returns an explicit successful response
        object. The test does not rely on attributes set on the ``MagicMock``
        class itself.
        """
        # Use a MagicMock *instance*. Assigning the class (``mock.MagicMock``)
        # and setting ``status_code``/``text`` on it would leak those attributes
        # onto every MagicMock created afterwards, in this test and others.
        response = mock.MagicMock()
        response.status_code = 200
        response.text = "{}"
        request.return_value = response

        def mock_http_request(text="{}"):
            """Patch ``http_request`` to return a 200 response whose body is *text*."""
            return mock.patch(
                "mlflow.utils.rest_utils.http_request",
                return_value=mock.MagicMock(status_code=200, text=text),
            )

        creds = MlflowHostCreds("https://hello")
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        source_name_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_name",
            return_value=source_name)
        source_type_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.LOCAL,
        )
        with mock_http_request() as mock_http, mock.patch(
                "mlflow.tracking._tracking_service.utils._get_store",
                return_value=store), mock.patch(
                    "mlflow.tracking.context.default_context._get_user",
                    return_value=user_name), mock.patch(
                        "time.time", return_value=13579
                    ), source_name_patch, source_type_patch:
            with mlflow.start_run(experiment_id="43"):
                cr_body = message_to_json(
                    CreateRun(
                        experiment_id="43",
                        user_id=user_name,
                        # time.time is patched to 13579 seconds; the request
                        # carries milliseconds.
                        start_time=13579000,
                        tags=[
                            ProtoRunTag(key="mlflow.source.name",
                                        value=source_name),
                            ProtoRunTag(key="mlflow.source.type",
                                        value="LOCAL"),
                            ProtoRunTag(key="mlflow.user", value=user_name),
                        ],
                    ))
                expected_kwargs = self._args(creds, "runs/create", "POST",
                                             cr_body)

                assert mock_http.call_count == 1
                actual_kwargs = mock_http.call_args[1]

                # Compare the tags separately from the rest of the request:
                # their order is not guaranteed, and it does not matter.
                expected_tags = expected_kwargs["json"].pop("tags")
                actual_tags = actual_kwargs["json"].pop("tags")
                assert sorted(expected_tags, key=lambda t: t["key"]) == sorted(
                    actual_tags, key=lambda t: t["key"])
                assert expected_kwargs == actual_kwargs

        with mock_http_request() as mock_http:
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(
                LogParam(run_uuid="some_uuid",
                         run_id="some_uuid",
                         key="k1",
                         value="v1"))
            self._verify_requests(mock_http, creds, "runs/log-parameter",
                                  "POST", body)

        with mock_http_request() as mock_http:
            store.set_experiment_tag("some_id",
                                     ExperimentTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetExperimentTag(experiment_id="some_id",
                                 key="t1",
                                 value="abcd" * 1000))
            self._verify_requests(mock_http, creds,
                                  "experiments/set-experiment-tag", "POST",
                                  body)

        with mock_http_request() as mock_http:
            store.set_tag("some_uuid", RunTag("t1", "abcd" * 1000))
            body = message_to_json(
                SetTag(run_uuid="some_uuid",
                       run_id="some_uuid",
                       key="t1",
                       value="abcd" * 1000))
            self._verify_requests(mock_http, creds, "runs/set-tag", "POST",
                                  body)

        with mock_http_request() as mock_http:
            store.delete_tag("some_uuid", "t1")
            body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
            self._verify_requests(mock_http, creds, "runs/delete-tag", "POST",
                                  body)

        with mock_http_request() as mock_http:
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(
                LogMetric(run_uuid="u2",
                          run_id="u2",
                          key="m1",
                          value=0.87,
                          timestamp=12345,
                          step=3))
            self._verify_requests(mock_http, creds, "runs/log-metric", "POST",
                                  body)

        with mock_http_request() as mock_http:
            metrics = [
                Metric("m1", 0.87, 12345, 0),
                Metric("m2", 0.49, 12345, -1),
                Metric("m3", 0.58, 12345, 2),
            ]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2",
                            metrics=metrics,
                            params=params,
                            tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(
                LogBatch(run_id="u2",
                         metrics=metric_protos,
                         params=param_protos,
                         tags=tag_protos))
            self._verify_requests(mock_http, creds, "runs/log-batch", "POST",
                                  body)

        with mock_http_request() as mock_http:
            store.delete_run("u25")
            self._verify_requests(mock_http, creds, "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with mock_http_request() as mock_http:
            store.restore_run("u76")
            self._verify_requests(mock_http, creds, "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with mock_http_request() as mock_http:
            store.delete_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/delete",
                "POST",
                message_to_json(DeleteExperiment(experiment_id="0")),
            )

        with mock_http_request() as mock_http:
            store.restore_experiment("0")
            self._verify_requests(
                mock_http,
                creds,
                "experiments/restore",
                "POST",
                message_to_json(RestoreExperiment(experiment_id="0")),
            )

        search_response_text = (
            '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}')
        with mock_http_request(text=search_response_text) as mock_http:
            result = store.search_runs(
                ["0", "1"],
                "params.p1 = 'a'",
                ViewType.ACTIVE_ONLY,
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )

            expected_message = SearchRuns(
                experiment_ids=["0", "1"],
                filter="params.p1 = 'a'",
                run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                max_results=10,
                order_by=["a"],
                page_token="12345abcde",
            )
            self._verify_requests(mock_http, creds, "runs/search", "POST",
                                  message_to_json(expected_message))
            assert result.token == "67890fghij"

        with mock_http_request() as mock_http:
            run_id = "run_id"
            m = Model(artifact_path="model/path",
                      run_id=run_id,
                      flavors={"tf": "flavor body"})
            store.record_logged_model(run_id, m)
            expected_message = LogModel(run_id=run_id, model_json=m.to_json())
            self._verify_requests(mock_http, creds, "runs/log-model", "POST",
                                  message_to_json(expected_message))