Example #1
def test_search_runs_data():
    import numpy as np
    import pandas as pd

    runs = [
        create_run(
            metrics=[Metric("mse", 0.2, 0, 0)],
            params=[Param("param", "value")],
            tags=[RunTag("tag", "value")],
            start=1564675200000,
            end=1564683035000,
        ),
        create_run(
            metrics=[Metric("mse", 0.6, 0, 0),
                     Metric("loss", 1.2, 0, 5)],
            params=[Param("param2", "val"),
                    Param("k", "v")],
            tags=[RunTag("tag2", "v2")],
            start=1564765200000,
            end=1564783200000,
        ),
    ]
    with mock.patch("mlflow.tracking.fluent._paginate", return_value=runs):
        pdf = search_runs()
        data = {
            "status": [RunStatus.FINISHED] * 2,
            "artifact_uri": [None] * 2,
            "run_id": [""] * 2,
            "experiment_id": [""] * 2,
            "metrics.mse": [0.2, 0.6],
            "metrics.loss": [np.nan, 1.2],
            "params.param": ["value", None],
            "params.param2": [None, "val"],
            "params.k": [None, "v"],
            "tags.tag": ["value", None],
            "tags.tag2": [None, "v2"],
            "start_time": [
                pd.to_datetime(1564675200000, unit="ms", utc=True),
                pd.to_datetime(1564765200000, unit="ms", utc=True),
            ],
            "end_time": [
                pd.to_datetime(1564683035000, unit="ms", utc=True),
                pd.to_datetime(1564783200000, unit="ms", utc=True),
            ],
        }
        validate_search_runs(pdf, data, "pandas")
Example #2
def _get_metric_from_file(parent_path, metric_name):
    _validate_metric_name(metric_name)
    metric_data = read_file_lines(parent_path, metric_name)
    if len(metric_data) == 0:
        raise Exception("Metric '%s' is malformed. No data found." % metric_name)
    last_line = metric_data[-1]
    timestamp, val = last_line.strip().split(" ")
    return Metric(metric_name, float(val), int(timestamp))
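For context, a minimal sketch of the per-metric file layout this reader assumes: one "timestamp value" pair per line, with only the last line kept. The path and sample contents below are hypothetical.

# Hypothetical contents of a FileStore metric file, e.g. mlruns/<exp>/<run_id>/metrics/mse;
# each line holds "<timestamp> <value>" and the reader above keeps only the last entry.
sample_lines = [
    "1564675200 0.31",
    "1564675260 0.27",
    "1564675320 0.24",
]
timestamp, val = sample_lines[-1].strip().split(" ")
print(int(timestamp), float(val))  # -> 1564675320 0.24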
Example #3
    def test_creation_and_hydration(self):
        key = random_str()
        value = 10000
        ts = int(time.time())

        metric = Metric(key, value, ts)
        self._check(metric, key, value, ts)

        as_dict = {"key": key, "value": value, "timestamp": ts}
        self.assertEqual(dict(metric), as_dict)

        proto = metric.to_proto()
        metric2 = metric.from_proto(proto)
        self._check(metric2, key, value, ts)

        metric3 = Metric.from_dictionary(as_dict)
        self._check(metric3, key, value, ts)
Example #4
def test_log_batch_same_metric_repeated_multiple_reqs(self):
    fs = FileStore(self.test_root)
    run = self._create_run(fs)
    metric0 = Metric(key="metric-key", value=1, timestamp=2, step=0)
    metric1 = Metric(key="metric-key", value=2, timestamp=3, step=0)
    fs.log_batch(run.info.run_id, params=[], metrics=[metric0], tags=[])
    self._verify_logged(fs, run.info.run_id, params=[], metrics=[metric0], tags=[])
    fs.log_batch(run.info.run_id, params=[], metrics=[metric1], tags=[])
    self._verify_logged(fs, run.info.run_id, params=[], metrics=[metric0, metric1], tags=[])
Example #5
def _log_metric():
    request_message = _get_request_message(LogMetric())
    metric = Metric(request_message.key, request_message.value, request_message.timestamp)
    _get_store().log_metric(request_message.run_uuid, metric)
    response_message = LogMetric.Response()
    response = Response(mimetype='application/json')
    response.set_data(message_to_json(response_message))
    return response
Example #6
def log_metric(self, run_id, key, value, timestamp=None):
    """Logs a metric against the given run id. If timestamp is not provided, we will
    use the current timestamp.
    """
    _validate_metric_name(key)
    timestamp = timestamp if timestamp is not None else int(time.time())
    metric = Metric(key, value, timestamp)
    self.store.log_metric(run_id, metric)
Example #7
def _log_specialized_estimator_content(fitted_estimator, run_id, fit_args,
                                       fit_kwargs):
    import sklearn

    mlflow_client = MlflowClient()
    name_metric_dict = {}
    try:
        if sklearn.base.is_classifier(fitted_estimator):
            name_metric_dict = _get_classifier_metrics(fitted_estimator,
                                                       fit_args, fit_kwargs)

        elif sklearn.base.is_regressor(fitted_estimator):
            name_metric_dict = _get_regressor_metrics(fitted_estimator,
                                                      fit_args, fit_kwargs)
    except Exception as err:  # pylint: disable=broad-except
        msg = ("Failed to autolog metrics for " +
               fitted_estimator.__class__.__name__ + ". Logging error: " +
               str(err))
        _logger.warning(msg)
    else:
        # batch log all metrics
        try_mlflow_log(
            mlflow_client.log_batch,
            run_id,
            metrics=[
                Metric(key=str(key),
                       value=value,
                       timestamp=int(time.time() * 1000),
                       step=0) for key, value in name_metric_dict.items()
            ],
        )

    if sklearn.base.is_classifier(fitted_estimator):
        try:
            artifacts = _get_classifier_artifacts(fitted_estimator, fit_args,
                                                  fit_kwargs)
        except Exception as e:  # pylint: disable=broad-except
            msg = ("Failed to autolog artifacts for " +
                   fitted_estimator.__class__.__name__ + ". Logging error: " +
                   str(e))
            _logger.warning(msg)
            return

        with TempDir() as tmp_dir:
            for artifact in artifacts:
                try:
                    display = artifact.function(**artifact.arguments)
                    display.ax_.set_title(artifact.title)
                    filepath = tmp_dir.path("{}.png".format(artifact.name))
                    display.figure_.savefig(filepath)
                    import matplotlib.pyplot as plt

                    plt.close(display.figure_)
                except Exception as e:  # pylint: disable=broad-except
                    _log_warning_for_artifacts(artifact.name,
                                               artifact.function, e)

            try_mlflow_log(mlflow_client.log_artifacts, run_id, tmp_dir.path())
Example #8
def test_validate_batch_log_data():
    metrics_with_bad_key = [Metric("good-metric-key", 1.0, 0, 0),
                            Metric("super-long-bad-key" * 1000, 4.0, 0, 0)]
    metrics_with_bad_val = [Metric("good-metric-key", "not-a-double-val", 0, 0)]
    metrics_with_bad_ts = [Metric("good-metric-key", 1.0, "not-a-timestamp", 0)]
    metrics_with_neg_ts = [Metric("good-metric-key", 1.0, -123, 0)]
    metrics_with_bad_step = [Metric("good-metric-key", 1.0, 0, "not-a-step")]
    params_with_bad_key = [Param("good-param-key", "hi"),
                           Param("super-long-bad-key" * 1000, "but-good-val")]
    params_with_bad_val = [Param("good-param-key", "hi"),
                           Param("another-good-key", "but-bad-val" * 1000)]
    tags_with_bad_key = [RunTag("good-tag-key", "hi"),
                         RunTag("super-long-bad-key" * 1000, "but-good-val")]
    tags_with_bad_val = [RunTag("good-tag-key", "hi"),
                         RunTag("another-good-key", "but-bad-val" * 1000)]
    bad_kwargs = {
        "metrics": [metrics_with_bad_key, metrics_with_bad_val, metrics_with_bad_ts,
                    metrics_with_neg_ts, metrics_with_bad_step],
        "params": [params_with_bad_key, params_with_bad_val],
        "tags": [tags_with_bad_key, tags_with_bad_val],
    }
    good_kwargs = {"metrics": [], "params": [], "tags": []}
    for arg_name, arg_values in bad_kwargs.items():
        for arg_value in arg_values:
            final_kwargs = copy.deepcopy(good_kwargs)
            final_kwargs[arg_name] = arg_value
            with pytest.raises(MlflowException):
                _validate_batch_log_data(**final_kwargs)
    # Test that we don't reject entities within the limit
    _validate_batch_log_data(
        metrics=[Metric("metric-key", 1.0, 0, 0)], params=[Param("param-key", "param-val")],
        tags=[RunTag("tag-key", "tag-val")])
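The key/value length limits being exercised here live in mlflow.utils.validation; a rough sketch of the kind of check involved, with the 250-character limit stated as an assumption about the MLflow version these tests target:

from mlflow.exceptions import MlflowException

MAX_ENTITY_KEY_LENGTH = 250  # assumed limit, mirroring mlflow.utils.validation in older releases

def _check_key_length(key):
    # Reject keys longer than the assumed backend limit, as the batch validator does for
    # metric, param, and tag keys.
    if len(key) > MAX_ENTITY_KEY_LENGTH:
        raise MlflowException(
            "Key '%s...' has length %d, which exceeds the limit of %d"
            % (key[:20], len(key), MAX_ENTITY_KEY_LENGTH)
        )

_check_key_length("good-metric-key")              # passes
# _check_key_length("super-long-bad-key" * 1000)  # would raise MlflowException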
Example #9
def test_log_batch_handler_success(mock_get_request_message, tmpdir):
    # Test success cases for the LogBatch API
    def _test_log_batch_helper_success(
            metric_entities, param_entities, tag_entities,
            expected_metrics=None, expected_params=None, expected_tags=None):
        """
        Simulates a LogBatch API request using the provided metrics/params/tags, asserting that it
        succeeds & that the backing store contains either the set of expected metrics/params/tags
        (if provided) or, by default, the metrics/params/tags used in the API request.
        """
        with mlflow.start_run() as active_run:
            run_id = active_run.info.run_uuid
            mock_get_request_message.return_value = LogBatch(
                run_id=run_id,
                metrics=[m.to_proto() for m in metric_entities],
                params=[p.to_proto() for p in param_entities],
                tags=[t.to_proto() for t in tag_entities])
            response = _log_batch()
            print(response, response.get_data())
            assert response.status_code == 200
            json_response = json.loads(response.get_data())
            assert json_response == {}
            _assert_logged_entities(
                run_id, expected_metrics or metric_entities, expected_params or param_entities,
                expected_tags or tag_entities)

    store = FileStore(tmpdir.strpath)
    with mock.patch('mlflow.tracking.utils._get_store', return_value=store):
        mlflow.set_experiment("log-batch-experiment")
        # Log an empty payload
        _test_log_batch_helper_success([], [], [])
        # Log multiple metrics/params/tags
        _test_log_batch_helper_success(
            metric_entities=[Metric(key="m-key", value=3.2 * i, timestamp=i) for i in range(3)],
            param_entities=[Param(key="p-key-%s" % i, value="p-val-%s" % i) for i in range(4)],
            tag_entities=[RunTag(key="t-key-%s" % i, value="t-val-%s" % i) for i in range(5)])
        # Log metrics with the same key
        _test_log_batch_helper_success(
            metric_entities=[Metric(key="m-key", value=3.2 * i, timestamp=3) for i in range(3)],
            param_entities=[], tag_entities=[])
        # Log tags with the same key, verify the last one gets written
        same_key_tags = [RunTag(key="t-key", value="t-val-%s" % i) for i in range(5)]
        _test_log_batch_helper_success(
            metric_entities=[], param_entities=[], tag_entities=same_key_tags,
            expected_tags=[same_key_tags[-1]])
Example #10
def test_search_runs_data():
    runs = [
        create_run(metrics=[Metric("mse", 0.2, 0, 0)],
                   params=[Param("param", "value")],
                   tags=[RunTag("tag", "value")],
                   start=1564675200000,
                   end=1564683035000),
        create_run(
            metrics=[Metric("mse", 0.6, 0, 0),
                     Metric("loss", 1.2, 0, 5)],
            params=[Param("param2", "val"),
                    Param("k", "v")],
            tags=[RunTag("tag2", "v2")],
            start=1564765200000,
            end=1564783200000)
    ]
    with mock.patch('mlflow.tracking.fluent._get_paginated_runs',
                    return_value=runs):
        pdf = search_runs()
        data = {
            'status': [RunStatus.FINISHED] * 2,
            'artifact_uri': [None] * 2,
            'run_id': [''] * 2,
            'experiment_id': [""] * 2,
            'metrics.mse': [0.2, 0.6],
            'metrics.loss': [np.nan, 1.2],
            'params.param': ["value", None],
            'params.param2': [None, "val"],
            'params.k': [None, "v"],
            'tags.tag': ["value", None],
            'tags.tag2': [None, "v2"],
            'start_time': [
                pd.to_datetime(1564675200000, unit="ms", utc=True),
                pd.to_datetime(1564765200000, unit="ms", utc=True)
            ],
            'end_time': [
                pd.to_datetime(1564683035000, unit="ms", utc=True),
                pd.to_datetime(1564783200000, unit="ms", utc=True)
            ]
        }
        expected_df = pd.DataFrame(data)
        pd.testing.assert_frame_equal(pdf,
                                      expected_df,
                                      check_like=True,
                                      check_frame_type=False)
Example #11
def _create():
    metrics = [Metric(key=random_str(10),
                      value=random_int(0, 1000),
                      timestamp=int(time.time()) + random_int(-1e4, 1e4),
                      step=random_int())]
    params = [Param(random_str(10), random_str(random_int(10, 35))) for _ in range(10)]  # noqa
    tags = [RunTag(random_str(10), random_str(random_int(10, 35))) for _ in range(10)]  # noqa
    rd = RunData(metrics=metrics, params=params, tags=tags)
    return rd, metrics, params, tags
Example #12
def copy_run_data_batch(src_client, src_run, log_source_info, dst_client, dst_run_id):
    import time
    from mlflow.entities import Metric, Param, RunTag
    now = int(time.time() + .5)
    params = [Param(k, v) for k, v in src_run.data.params.items()]
    metrics = [Metric(k, v, now, 0) for k, v in src_run.data.metrics.items()]  # TODO: check timestamp and step semantics
    tags = utils.create_tags(src_client, src_run, log_source_info)
    tags = [RunTag(k, v) for k, v in tags.items()]
    dst_client.log_batch(dst_run_id, metrics, params, tags)
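A hedged usage sketch of copying a run between two tracking servers with this helper's approach; the URIs and run id are placeholders, and the tag handling is simplified because utils.create_tags and log_source_info come from the surrounding project:

import time
from mlflow.entities import Metric, Param, RunTag
from mlflow.tracking import MlflowClient

src_client = MlflowClient("http://src-tracking:5000")  # placeholder tracking URIs
dst_client = MlflowClient("http://dst-tracking:5000")

src_run = src_client.get_run("<src-run-id>")           # hypothetical run id
dst_run = dst_client.create_run(experiment_id="0")

now = int(time.time() + .5)
params = [Param(k, v) for k, v in src_run.data.params.items()]
metrics = [Metric(k, v, now, 0) for k, v in src_run.data.metrics.items()]
tags = [RunTag(k, v) for k, v in src_run.data.tags.items()]  # simplified; no utils.create_tags
dst_client.log_batch(dst_run.info.run_id, metrics, params, tags)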
Example #13
def _add_to_queue(key, value, step, time, run_id):
    """
    Add a metric to the metric queue. Flush the queue if it exceeds
    max size.
    """
    met = Metric(key=key, value=value, timestamp=time, step=step)
    _metric_queue.append((run_id, met))
    if len(_metric_queue) > _MAX_METRIC_QUEUE_SIZE:
        _flush_queue()
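_metric_queue, _MAX_METRIC_QUEUE_SIZE, and _flush_queue are module-level helpers not shown above; a minimal sketch of what they might look like, assuming the flush groups queued metrics by run id and sends each group in one log_batch call:

from collections import defaultdict
from mlflow.tracking import MlflowClient

_metric_queue = []
_MAX_METRIC_QUEUE_SIZE = 500  # assumed threshold

def _flush_queue():
    # Group the queued (run_id, Metric) pairs by run id and log each group in a single batch.
    by_run = defaultdict(list)
    for run_id, metric in _metric_queue:
        by_run[run_id].append(metric)
    client = MlflowClient()
    for run_id, metrics in by_run.items():
        client.log_batch(run_id, metrics=metrics)
    _metric_queue.clear()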
Example #14
def test_creation_and_hydration():
    key = random_str()
    value = 10000
    ts = int(1000 * time.time())
    step = random_int()

    metric = Metric(key, value, ts, step)
    _check(metric, key, value, ts, step)

    as_dict = {"key": key, "value": value, "timestamp": ts, "step": step}
    assert dict(metric) == as_dict

    proto = metric.to_proto()
    metric2 = metric.from_proto(proto)
    _check(metric2, key, value, ts, step)

    metric3 = Metric.from_dictionary(as_dict)
    _check(metric3, key, value, ts, step)
Example #15
def log_metric(self, run_id, key, value, timestamp=None):
    """
    Log a metric against the run ID. If timestamp is not provided, uses
    the current timestamp.
    """
    timestamp = timestamp if timestamp is not None else int(time.time())
    _validate_metric(key, value, timestamp)
    metric = Metric(key, value, timestamp)
    self.store.log_metric(run_id, metric)
Example #16
def test_weird_metric_names(self):
    WEIRD_METRIC_NAME = "this is/a weird/but valid metric"
    fs = FileStore(self.test_root)
    run_uuid = self.exp_data[0]["runs"][0]
    fs.log_metric(run_uuid, Metric(WEIRD_METRIC_NAME, 10, 1234))
    metric = fs.get_metric(run_uuid, WEIRD_METRIC_NAME)
    assert metric.key == WEIRD_METRIC_NAME
    assert metric.value == 10
    assert metric.timestamp == 1234
Example #17
def import_run_data_batch(run_dct, run_id):
    import time
    from mlflow.entities import Metric, Param, RunTag
    now = int(time.time() + .5)
    params = [Param(k, v) for k, v in run_dct['params'].items()]
    # TODO: check timestamp and step semantics
    metrics = [Metric(k, v, now, 0) for k, v in run_dct['metrics'].items()]
    tags = [RunTag(k, v) for k, v in run_dct['tags'].items()]
    client.log_batch(run_id, metrics, params, tags)
Example #18
def test_correct_sorting(order_bys, matching_runs):
    runs = [
        Run(run_info=RunInfo(run_uuid="9",
                             run_id="9",
                             experiment_id=0,
                             user_id="user-id",
                             status=RunStatus.to_string(RunStatus.FAILED),
                             start_time=0,
                             end_time=1,
                             lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData(metrics=[Metric("key1", 121, 1, 0)],
                             params=[Param("my_param", "A")],
                             tags=[])),
        Run(run_info=RunInfo(run_uuid="8",
                             run_id="8",
                             experiment_id=0,
                             user_id="user-id",
                             status=RunStatus.to_string(RunStatus.FINISHED),
                             start_time=1,
                             end_time=1,
                             lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData(metrics=[Metric("key1", 123, 1, 0)],
                             params=[Param("my_param", "A")],
                             tags=[RunTag("tag1", "C")])),
        Run(run_info=RunInfo(run_uuid="7",
                             run_id="7",
                             experiment_id=1,
                             user_id="user-id",
                             status=RunStatus.to_string(RunStatus.FAILED),
                             start_time=1,
                             end_time=1,
                             lifecycle_stage=LifecycleStage.ACTIVE),
            run_data=RunData(metrics=[Metric("key1", 125, 1, 0)],
                             params=[Param("my_param", "B")],
                             tags=[RunTag("tag1", "D")])),
    ]
    sorted_runs = SearchUtils.sort(runs, order_bys)
    sorted_run_indices = []
    for run in sorted_runs:
        for i, r in enumerate(runs):
            if r == run:
                sorted_run_indices.append(i)
                break
    assert sorted_run_indices == matching_runs
Example #19
def _log_metric():
    request_message = _get_request_message(LogMetric())
    metric = Metric(request_message.key, request_message.value,
                    request_message.timestamp, request_message.step)
    run_id = request_message.run_id or request_message.run_uuid
    _get_tracking_store().log_metric(run_id, metric)
    response_message = LogMetric.Response()
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
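This handler backs the LogMetric REST endpoint; a hedged sketch of the client-side request it expects, following the public MLflow REST API path and payload fields (the host and run id are placeholders):

import requests

payload = {
    "run_id": "<run-id>",        # hypothetical run id
    "key": "mse",
    "value": 0.24,
    "timestamp": 1564675200000,  # milliseconds since the epoch
    "step": 3,
}
resp = requests.post(
    "http://localhost:5000/api/2.0/mlflow/runs/log-metric",  # placeholder tracking server
    json=payload,
)
resp.raise_for_status()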
Example #20
def _dict_to_run_metric_history(rm):
    metrics = rm["metrics"][::-1]
    return [
        Metric(
            key=rm["key"],
            value=float(m["value"]),
            timestamp=int(m["timestamp"]),
            step=int(m.get("step", i)),
        ) for (i, m) in enumerate(metrics)
    ]
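A hedged usage sketch with a hand-made input; the reversed ordering and the step fallback to the enumeration index follow the function above, but the payload shape itself is an assumption about the caller (assumes the function and mlflow.entities.Metric are in scope):

rm = {
    "key": "loss",
    "metrics": [  # newest first in this hypothetical payload; the function reverses it
        {"value": "0.8", "timestamp": "1564675320000", "step": "2"},
        {"value": "1.0", "timestamp": "1564675260000"},           # no step -> falls back to index
        {"value": "1.2", "timestamp": "1564675200000", "step": "0"},
    ],
}
history = _dict_to_run_metric_history(rm)
assert [m.value for m in history] == [1.2, 1.0, 0.8]
assert [m.step for m in history] == [0, 1, 2]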
Example #21
    def to_mlflow_entity(self):
        """
        Convert DB model to corresponding MLflow entity.

        :return: :py:class:`mlflow.entities.Metric`.
        """
        return Metric(key=self.key,
                      value=self.value if not self.is_nan else float("nan"),
                      timestamp=self.timestamp,
                      step=self.step)
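Going the other direction, a hedged sketch of how such a DB model might encode NaN before persisting; the field names mirror what is read above, but the helper itself is hypothetical:

import math

def to_db_fields(metric):
    # Store a flag plus a placeholder value so the value column stays a plain float.
    is_nan = math.isnan(metric.value)
    return {
        "key": metric.key,
        "value": 0.0 if is_nan else metric.value,
        "timestamp": metric.timestamp,
        "step": metric.step,
        "is_nan": is_nan,
    }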
Example #22
def log_metric(self, run_id, key, value, timestamp=None, step=None):
    """
    Log a metric against the run ID. If timestamp is not provided, uses
    the current timestamp. The metric's step defaults to 0 if unspecified.
    """
    timestamp = timestamp if timestamp is not None else int(time.time())
    step = step if step is not None else 0
    _validate_metric(key, value, timestamp, step)
    metric = Metric(key, value, timestamp, step)
    self.store.log_metric(run_id, metric)
Example #23
def _get_metric_from_line(metric_name, metric_line):
    metric_parts = metric_line.strip().split(" ")
    if len(metric_parts) != 2 and len(metric_parts) != 3:
        raise MlflowException("Metric '%s' is malformed; persisted metric data contained %s "
                              "fields. Expected 2 or 3 fields." %
                              (metric_name, len(metric_parts)), databricks_pb2.INTERNAL_ERROR)
    ts = int(metric_parts[0])
    val = float(metric_parts[1])
    step = int(metric_parts[2]) if len(metric_parts) == 3 else 0
    return Metric(key=metric_name, value=val, timestamp=ts, step=step)
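A quick usage sketch of the two accepted line shapes (the sample lines are made up; assumes the function above is in scope):

m2 = _get_metric_from_line("mse", "1564675200 0.24")     # 2 fields: timestamp value -> step 0
m3 = _get_metric_from_line("mse", "1564675200 0.24 7")   # 3 fields: timestamp value step
assert (m2.step, m3.step) == (0, 7)
assert m2.value == m3.value == 0.24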
Example #24
def import_run_data(self, run_dct, run, src_user_id):
    from mlflow.entities import Metric, Param, RunTag
    now = round(time.time())
    params = [Param(k, v) for k, v in run_dct['params'].items()]
    metrics = [Metric(k, v, now, 0) for k, v in run_dct['metrics'].items()]  # TODO: missing timestamp and step semantics?
    tags = self._create_tags_for_metadata(run_dct['tags'])
    tags = utils.create_tags_for_mlflow_tags(tags, self.import_mlflow_tags)
    # utils.dump_tags("RunImporter.import_run_data", tags)
    # utils.set_dst_user_id(tags, src_user_id, self.use_src_user_id)
    self.client.log_batch(run.info.run_id, metrics, params, tags)
Example #25
def _copy_run_data(self, src_run, dst_run_id):
    from mlflow.entities import Metric, Param, RunTag
    now = int(time.time() + .5)
    params = [Param(k, v) for k, v in src_run.data.params.items()]
    metrics = [Metric(k, v, now, 0) for k, v in src_run.data.metrics.items()]  # TODO: timestamp and step semantics?
    tags = utils.create_tags_for_metadata(self.src_client, src_run, self.export_metadata_tags)
    # tags = [RunTag(k, v) for k, v in tags.items()]
    tags = utils.create_tags_for_mlflow_tags(tags, self.import_mlflow_tags)  # XX
    utils.set_dst_user_id(tags, src_run.info.user_id, self.use_src_user_id)
    self.dst_client.log_batch(dst_run_id, metrics, params, tags)
Example #26
def test_log_batch(self):
    fs = FileStore(self.test_root)
    run = fs.create_run(experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
                        user_id='user',
                        start_time=0,
                        tags=[])
    run_id = run.info.run_id
    metric_entities = [
        Metric("m1", 0.87, 12345, 0),
        Metric("m2", 0.49, 12345, 0)
    ]
    param_entities = [Param("p1", "p1val"), Param("p2", "p2val")]
    tag_entities = [RunTag("t1", "t1val"), RunTag("t2", "t2val")]
    fs.log_batch(run_id=run_id,
                 metrics=metric_entities,
                 params=param_entities,
                 tags=tag_entities)
    self._verify_logged(fs, run_id, metric_entities, param_entities, tag_entities)
Example #27
def log_metric(self, run_id, key, value, timestamp=None, step=None):
    """
    Log a metric against the run ID. The timestamp defaults to the current timestamp.
    The step defaults to 0.
    """
    timestamp = timestamp if timestamp is not None else int(time.time())
    step = step if step is not None else 0
    _validate_metric(key, value, timestamp, step)
    metric = Metric(key, value, timestamp, step)
    self.store.log_metric(run_id, metric)
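The public MlflowClient.log_metric exposes essentially the same signature; a minimal usage sketch with the defaults described in the docstring (current time, step 0):

from mlflow.tracking import MlflowClient

client = MlflowClient()
run = client.create_run(experiment_id="0")
client.log_metric(run.info.run_id, "accuracy", 0.91)          # timestamp defaults to now, step to 0
client.log_metric(run.info.run_id, "accuracy", 0.93, step=1)  # explicit step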
Example #28
def log_metrics(metrics):
    """
    Log multiple metrics for the current run, starting a run if no runs are active.
    :param metrics: Dictionary of metric_name: String -> value: Float
    :returns: None
    """
    run_id = _get_or_start_run().info.run_uuid
    timestamp = int(time.time())
    metrics_arr = [Metric(key, value, timestamp, 0) for key, value in metrics.items()]
    MlflowClient().log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])
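Usage goes through the fluent API; a minimal sketch:

import mlflow

with mlflow.start_run():
    mlflow.log_metrics({"rmse": 0.72, "mae": 0.48})  # one batched call with a shared timestamp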
Example #29
def test_log_batch(tracking_uri_mock, tmpdir):
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {"param-key0": "param-val0", "param-key1": "param-val1"}
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = set([MLFLOW_SOURCE_NAME, MLFLOW_SOURCE_TYPE])

    t = int(time.time())
    sorted_expected_metrics = sorted(expected_metrics.items(),
                                     key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [
        Param(key=key, value=value) for key, value in expected_params.items()
    ]
    tags = [
        RunTag(key=key, value=value)
        for key, value in exact_expected_tags.items()
    ]

    with start_run() as active_run:
        run_uuid = active_run.info.run_uuid
        mlflow.tracking.MlflowClient().log_batch(run_id=run_uuid,
                                                 metrics=metrics,
                                                 params=params,
                                                 tags=tags)
    finished_run = tracking.MlflowClient().get_run(run_uuid)
    # Validate metrics
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    # TODO: use client get_metric_history API here instead once it exists
    fs = FileStore(os.path.join(tmpdir.strpath, "mlruns"))
    metric_history0 = fs.get_metric_history(run_uuid, "metric-key0")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history0]) == set([(1.0, t, 0)])
    metric_history1 = fs.get_metric_history(run_uuid, "metric-key1")
    assert set([(m.value, m.timestamp, m.step) for m in metric_history1]) == set([(4.0, t, 1)])

    # Validate tags (for automatically-set tags)
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == expected_params
Example #30
def test_log_batch_validates_entity_names_and_values():
    with start_run() as active_run:
        run_id = active_run.info.run_id

        metrics = [
            Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)
        ]
        with pytest.raises(MlflowException, match="Invalid metric name") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [
            Metric(key="ok-name",
                   value="non-numerical-value",
                   timestamp=3,
                   step=0)
        ]
        with pytest.raises(MlflowException, match="Got invalid value") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [
            Metric(key="ok-name",
                   value=0.3,
                   timestamp="non-numerical-timestamp",
                   step=0)
        ]
        with pytest.raises(MlflowException,
                           match="Got invalid timestamp") as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        params = [Param(key="../bad/param/name", value="my-val")]
        with pytest.raises(MlflowException,
                           match="Invalid parameter name") as e:
            tracking.MlflowClient().log_batch(run_id, params=params)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        tags = [RunTag(key="../bad/tag/name", value="my-val")]
        with pytest.raises(MlflowException, match="Invalid tag name") as e:
            tracking.MlflowClient().log_batch(run_id, tags=tags)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)