Exemplo n.º 1
0
def _set_experiment_tag():
    """
    Serve a SetExperimentTag request.

    Parses the incoming request into a SetExperimentTag message, builds an
    ExperimentTag from its key and value, records it on the tracking store for
    the requested experiment, and replies with a default
    SetExperimentTag.Response message serialized as JSON.

    :return: Response object with mimetype ``application/json``
    """
    req = _get_request_message(SetExperimentTag())
    experiment_tag = ExperimentTag(req.key, req.value)
    store = _get_tracking_store()
    store.set_experiment_tag(req.experiment_id, experiment_tag)

    # The reply carries no fields beyond the default response message.
    payload = message_to_json(SetExperimentTag.Response())
    http_response = Response(mimetype='application/json')
    http_response.set_data(payload)
    return http_response
Exemplo n.º 2
0
    def set_experiment_tag(self, experiment_id, tag):
        """
        Set a tag on the specified experiment by sending a SetExperimentTag
        request through ``_call_endpoint``.

        :param experiment_id: String ID of the experiment
        :param tag: ExperimentTag instance whose ``key`` and ``value`` are sent
        """
        request_proto = SetExperimentTag(
            experiment_id=experiment_id,
            key=tag.key,
            value=tag.value,
        )
        self._call_endpoint(SetExperimentTag, message_to_json(request_proto))
Exemplo n.º 3
0
    def test_requestor(self, request):
        """
        Check that each RestStore operation issues the expected REST request.

        For every store call, ``mlflow.store.rest_store.http_request`` is patched
        and the recorded call is compared against the endpoint, HTTP method and
        JSON body built from the corresponding protobuf message.

        :param request: Mock injected by the caller/decorator (not visible here);
                        its return value is set to a canned successful response.
        """
        # Use MagicMock *instances* for canned responses. Assigning attributes to
        # the ``mock.MagicMock`` class itself would leak ``status_code``/``text``
        # into every other mock created in the process (including other tests).
        ok_response = mock.MagicMock()
        ok_response.status_code = 200
        ok_response.text = '{}'
        request.return_value = ok_response

        creds = MlflowHostCreds('https://hello')
        store = RestStore(lambda: creds)

        user_name = "mock user"
        source_name = "rest test"

        source_name_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_name", return_value=source_name
        )
        source_type_patch = mock.patch(
            "mlflow.tracking.context.default_context._get_source_type",
            return_value=SourceType.LOCAL
        )
        with mock.patch('mlflow.store.rest_store.http_request') as mock_http, \
                mock.patch('mlflow.tracking.utils._get_store', return_value=store), \
                mock.patch('mlflow.tracking.context.default_context._get_user',
                           return_value=user_name), \
                mock.patch('time.time', return_value=13579), \
                source_name_patch, source_type_patch:
            # Every patched http_request gets an explicit, well-formed response so
            # the store never depends on attributes set on the MagicMock class.
            mock_http.return_value = ok_response
            with mlflow.start_run(experiment_id="43"):
                cr_body = message_to_json(CreateRun(experiment_id="43",
                                                    user_id=user_name, start_time=13579000,
                                                    tags=[ProtoRunTag(key='mlflow.source.name',
                                                                      value=source_name),
                                                          ProtoRunTag(key='mlflow.source.type',
                                                                      value='LOCAL'),
                                                          ProtoRunTag(key='mlflow.user',
                                                                      value=user_name)]))
                expected_kwargs = self._args(creds, "runs/create", "POST", cr_body)

                assert mock_http.call_count == 1
                actual_kwargs = mock_http.call_args[1]

                # Test the passed tag values separately from the rest of the request
                # Tag order is inconsistent on Python 2 and 3, but the order does not matter
                expected_tags = expected_kwargs['json'].pop('tags')
                actual_tags = actual_kwargs['json'].pop('tags')
                assert (
                    sorted(expected_tags, key=lambda t: t['key']) ==
                    sorted(actual_tags, key=lambda t: t['key'])
                )
                assert expected_kwargs == actual_kwargs

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.log_param("some_uuid", Param("k1", "v1"))
            body = message_to_json(LogParam(
                run_uuid="some_uuid", run_id="some_uuid", key="k1", value="v1"))
            self._verify_requests(mock_http, creds,
                                  "runs/log-parameter", "POST", body)

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.set_experiment_tag("some_id", ExperimentTag("t1", "abcd"*1000))
            body = message_to_json(SetExperimentTag(
                experiment_id="some_id",
                key="t1",
                value="abcd"*1000))
            self._verify_requests(mock_http, creds,
                                  "experiments/set-experiment-tag", "POST", body)

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.set_tag("some_uuid", RunTag("t1", "abcd"*1000))
            body = message_to_json(SetTag(
                run_uuid="some_uuid", run_id="some_uuid", key="t1", value="abcd"*1000))
            self._verify_requests(mock_http, creds,
                                  "runs/set-tag", "POST", body)

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.delete_tag("some_uuid", "t1")
            body = message_to_json(DeleteTag(run_id="some_uuid", key="t1"))
            self._verify_requests(mock_http, creds,
                                  "runs/delete-tag", "POST", body)

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.log_metric("u2", Metric("m1", 0.87, 12345, 3))
            body = message_to_json(LogMetric(
                run_uuid="u2", run_id="u2", key="m1", value=0.87, timestamp=12345, step=3))
            self._verify_requests(mock_http, creds,
                                  "runs/log-metric", "POST", body)

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            metrics = [Metric("m1", 0.87, 12345, 0), Metric("m2", 0.49, 12345, -1),
                       Metric("m3", 0.58, 12345, 2)]
            params = [Param("p1", "p1val"), Param("p2", "p2val")]
            tags = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
            store.log_batch(run_id="u2", metrics=metrics, params=params, tags=tags)
            metric_protos = [metric.to_proto() for metric in metrics]
            param_protos = [param.to_proto() for param in params]
            tag_protos = [tag.to_proto() for tag in tags]
            body = message_to_json(LogBatch(run_id="u2", metrics=metric_protos,
                                            params=param_protos, tags=tag_protos))
            self._verify_requests(mock_http, creds,
                                  "runs/log-batch", "POST", body)

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.delete_run("u25")
            self._verify_requests(mock_http, creds,
                                  "runs/delete", "POST",
                                  message_to_json(DeleteRun(run_id="u25")))

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.restore_run("u76")
            self._verify_requests(mock_http, creds,
                                  "runs/restore", "POST",
                                  message_to_json(RestoreRun(run_id="u76")))

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.delete_experiment("0")
            self._verify_requests(mock_http, creds,
                                  "experiments/delete", "POST",
                                  message_to_json(DeleteExperiment(experiment_id="0")))

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            mock_http.return_value = ok_response
            store.restore_experiment("0")
            self._verify_requests(mock_http, creds,
                                  "experiments/restore", "POST",
                                  message_to_json(RestoreExperiment(experiment_id="0")))

        with mock.patch('mlflow.store.rest_store.http_request') as mock_http:
            # Dedicated response instance: the search reply carries run data and a
            # page token, unlike the empty JSON body used by the other sections.
            search_response = mock.MagicMock()
            search_response.status_code = 200
            search_response.text = '{"runs": ["1a", "2b", "3c"], "next_page_token": "67890fghij"}'
            mock_http.return_value = search_response
            result = store.search_runs(["0", "1"], "params.p1 = 'a'", ViewType.ACTIVE_ONLY,
                                       max_results=10, order_by=["a"], page_token="12345abcde")

            expected_message = SearchRuns(experiment_ids=["0", "1"], filter="params.p1 = 'a'",
                                          run_view_type=ViewType.to_proto(ViewType.ACTIVE_ONLY),
                                          max_results=10, order_by=["a"], page_token="12345abcde")
            self._verify_requests(mock_http, creds,
                                  "runs/search", "POST",
                                  message_to_json(expected_message))
            # The next-page token from the canned reply must surface on the result.
            assert result.token == "67890fghij"