Example #1
0
    def test_log_metric(self):
        """Log one metric key twice and compare the entity view with the SQL rows.

        ``get_run`` is expected to expose a single entry per metric key, while the
        private ``_get_run`` SQL row keeps every logged point, duplicates included.
        """
        created = self._run_factory()
        run_id = created.info.run_uuid

        key, value = 'blahmetric', 100.0
        first_point = entities.Metric(key, value, int(time.time()))
        second_point = entities.Metric(key, value, int(time.time()) + 2)
        for point in (first_point, second_point):
            self.store.log_metric(run_id, point)

        run = self.store.get_run(run_id)
        self.assertTrue(any(m.key == key and m.value == value for m in run.data.metrics))

        # Two logged points -> two SQL rows, but RunData collapses them to one entry.
        with self.store.ManagedSessionMaker() as session:
            history = self.store._get_run(session, run_id).metrics
            self.assertEqual(2, len(history))
            self.assertEqual(1, len(run.data.metrics))
    def test_log_metric(self):
        """Log one metric key twice and check both the SQL rows and the returned run.

        The SQL store keeps every logged point for a key (duplicates included),
        while ``RunData`` returned by ``get_run`` holds one entry per key.
        """
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        metric = entities.Metric(tkey, tval, int(time.time()))
        metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
        self.store.log_metric(run.run_uuid, metric)
        self.store.log_metric(run.run_uuid, metric2)

        # A SQLAlchemy Query object is never None, so asserting on the query itself
        # could never fail; fetch a row so the check actually verifies persistence.
        actual = self.session.query(models.SqlMetric).filter_by(key=tkey,
                                                                value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        # SQL store _get_run method returns full history of recorded metrics,
        # duplicates included; MLflow RunData contains only the last reported
        # value for each metric key.
        sql_run_metrics = self.store._get_run(run.info.run_uuid).metrics
        self.assertEqual(2, len(sql_run_metrics))
        self.assertEqual(1, len(run.data.metrics))

        self.assertTrue(any(m.key == tkey and m.value == tval for m in run.data.metrics))
Example #3
0
    def test_log_metric(self):
        """Log one metric key twice and check it is persisted and visible on the run."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        metric = entities.Metric(tkey, tval, int(time.time()))
        metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
        self.store.log_metric(run.run_uuid, metric)
        self.store.log_metric(run.run_uuid, metric2)

        # A SQLAlchemy Query object is never None, so asserting on the query itself
        # could never fail; fetch a row so the check actually verifies persistence.
        actual = self.session.query(models.SqlMetric).filter_by(key=tkey,
                                                                value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        # NOTE(review): 4 presumably counts metrics pre-populated by _run_factory's
        # default config plus the entries logged here — confirm against the factory.
        self.assertEqual(4, len(run.data.metrics))
        self.assertTrue(any(m.key == tkey and m.value == tval for m in run.data.metrics))
Example #4
0
    def test_log_metric_uniqueness(self):
        """Logging a second value for a metric key/timestamp already recorded must fail.

        Both metrics take their timestamp from ``int(time.time())``, so in practice
        they share a timestamp and differ only in value.
        """
        run_id = self._run_factory().info.run_uuid

        key = 'blahmetric'
        original = entities.Metric(key, 100.0, int(time.time()))
        conflicting = entities.Metric(key, 1.02, int(time.time()))
        self.store.log_metric(run_id, original)

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_metric(run_id, conflicting)
        self.assertIn("must be unique. Metric already logged value", ctx.exception.message)
Example #5
0
    def test_log_metric_uniqueness(self):
        """A second value for an already-logged metric key must raise MlflowException."""
        run = self._run_factory()
        self.session.commit()

        key = 'blahmetric'
        first = entities.Metric(key, 100.0, int(time.time()))
        clash = entities.Metric(key, 1.02, int(time.time()))
        self.store.log_metric(run.run_uuid, first)

        with self.assertRaises(MlflowException):
            self.store.log_metric(run.run_uuid, clash)
    def test_error_logging_to_deleted_run(self):
        """Writes to a deleted run are rejected; after restoring the run they succeed.

        While the run is deleted, logging a param, a metric and a tag must each raise
        MlflowException mentioning the 'active' state. After ``restore_run`` the same
        kinds of writes are accepted and readable via ``get_run``.
        """
        exp = self._experiment_factory('error_logging')
        run_uuid = self._run_factory(
            self._get_run_configs(experiment_id=exp)).run_uuid

        self.store.delete_run(run_uuid)
        self.assertEqual(
            self.store.get_run(run_uuid).info.lifecycle_stage,
            entities.LifecycleStage.DELETED)

        rejected_writes = [
            (self.store.log_param, entities.Param("p1345", "v1")),
            (self.store.log_metric, entities.Metric("m1345", 1.0, 123)),
            (self.store.set_tag, entities.RunTag("t1345", "tv1")),
        ]
        for write, value in rejected_writes:
            with self.assertRaises(MlflowException) as e:
                write(run_uuid, value)
            self.assertIn("must be in 'active' state", e.exception.message)

        # restore this run and try again
        self.store.restore_run(run_uuid)
        self.assertEqual(
            self.store.get_run(run_uuid).info.lifecycle_stage,
            entities.LifecycleStage.ACTIVE)
        self.store.log_param(run_uuid, entities.Param("p1345", "v22"))
        # Timestamp 85 is earlier than the 123 used in the rejected attempt above.
        self.store.log_metric(run_uuid, entities.Metric("m1345", 34.0, 85))
        self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv44"))

        # Nothing writes to the run after this point, so a single fetch suffices.
        # unittest assertions are used instead of bare ``assert`` so these checks
        # still run under ``python -O``.
        run = self.store.get_run(run_uuid)
        self.assertEqual(1, len(run.data.params))
        p = run.data.params[0]
        self.assertEqual(p.key, "p1345")
        self.assertEqual(p.value, "v22")
        self.assertEqual(1, len(run.data.metrics))
        m = run.data.metrics[0]
        self.assertEqual(m.key, "m1345")
        self.assertEqual(m.value, 34.0)

        self.assertEqual([("p1345", "v22")],
                         [(p.key, p.value)
                          for p in run.data.params if p.key == "p1345"])
        self.assertEqual([("m1345", 34.0, 85)],
                         [(m.key, m.value, m.timestamp)
                          for m in run.data.metrics if m.key == "m1345"])
        self.assertEqual([("t1345", "tv44")],
                         [(t.key, t.value)
                          for t in run.data.tags if t.key == "t1345"])
Example #7
0
    def test_log_null_metric(self):
        """Logging a metric whose value is None must fail with an INTERNAL_ERROR code."""
        run_id = self._run_factory().info.run_uuid
        null_metric = entities.Metric('blahmetric', None, int(time.time()))

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_metric(run_id, null_metric)
        assert ctx.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)
Example #8
0
    def test_log_null_metric(self):
        """Logging a metric whose value is None must fail with an INTERNAL_ERROR code.

        Warnings raised while the store rejects the value are silenced only for the
        duration of the call; process-wide warning filters are restored afterwards.
        """
        run = self._run_factory()

        tkey = 'blahmetric'
        tval = None
        metric = entities.Metric(tkey, tval, int(time.time()))

        with warnings.catch_warnings():
            # The filter must be installed *inside* catch_warnings so it is rolled
            # back on exit. Calling simplefilter before entering the context would
            # leave "ignore" active for every later test. A trailing resetwarnings()
            # would also never run, because log_metric is expected to raise.
            warnings.simplefilter("ignore")
            with self.assertRaises(MlflowException) as exception_context:
                self.store.log_metric(run.info.run_uuid, metric)
        assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)
Example #9
0
    def test_log_null_metric(self):
        """A None metric value must be rejected with an IntegrityError-based message."""
        run = self._run_factory()
        self.session.commit()

        null_metric = entities.Metric('blahmetric', None, int(time.time()))

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_metric(run.run_uuid, null_metric)
        message = ctx.exception.message
        for fragment in ("Log metric request failed for run ID=", "IntegrityError"):
            self.assertIn(fragment, message)
Example #10
0
 def test_log_batch_param_overwrite_disallowed_single_req(self):
     """One log_batch call that gives a param key two values must be rejected.

     The error must be INVALID_PARAMETER_VALUE. Afterwards only the first param
     value is recorded; the batch's metric and tag are not.
     """
     run_id = self._run_factory().info.run_uuid
     key = "common-key"
     kept = entities.Param(key, "orig-val")
     overwrite = entities.Param(key, 'newval')
     tag = entities.RunTag("tag-key", "tag-val")
     metric = entities.Metric("metric-key", 3.0, 12345)

     with self.assertRaises(MlflowException) as ctx:
         self.store.log_batch(run_id, metrics=[metric], params=[kept, overwrite], tags=[tag])
     err = ctx.exception
     self.assertIn("Changing param value is not allowed. Param with key=", err.message)
     assert err.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
     self._verify_logged(run_id, metrics=[], params=[kept], tags=[])
    def test_search_full(self):
        """Search with both param and metric expressions; every expression must match."""
        experiment_id = self._experiment_factory('search_params')
        r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid
        r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid

        for run_id in (r1, r2):
            self.store.log_param(run_id, entities.Param('generic_param', 'p_val'))
        self.store.log_param(r1, entities.Param('p_a', 'abc'))
        self.store.log_param(r2, entities.Param('p_b', 'ABC'))

        for run_id in (r1, r2):
            self.store.log_metric(run_id, entities.Metric("common", 1.0, 1))
        self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2))
        self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2))
        self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8))
        self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3))

        generic = ("generic_param", "=", "p_val")
        common = ("common", "=", 1.0)
        cases = [
            # (param spec, metric specs, expected run ids)
            (generic, [common], [r1, r2]),
            # all params and metrics match
            (generic, [common, ("m_a", ">", 1.0)], [r1]),
            # mismatching param
            (("random_bad_name", "=", "p_val"), [common, ("m_a", ">", 1.0)], []),
            # mismatching metric
            (generic, [common, ("m_a", ">", 100.0)], []),
        ]
        for param_spec, metric_specs, expected in cases:
            p_expr = self._param_expression(*param_spec)
            m_exprs = [self._metric_expression(*spec) for spec in metric_specs]
            self.assertSequenceEqual(expected,
                                     self._search(experiment_id,
                                                  param_expressions=[p_expr],
                                                  metrics_expressions=m_exprs))
    def test_log_batch_param_overwrite_disallowed(self):
        """log_batch must refuse to change an existing param and log nothing partial.

        A param is logged first. A batch that reuses its key with a new value must
        fail with INVALID_PARAMETER_VALUE. Afterwards the run holds only the
        original param; no metric or tag from the failed batch is persisted.
        """
        run = self._run_factory()
        self.session.commit()
        run_id = run.run_uuid

        key = 'my-param'
        original = entities.Param(key, 'orig-val')
        self.store.log_param(run_id, original)

        changed = entities.Param(key, 'newval')
        tag = entities.RunTag("tag-key", "tag-val")
        metric = entities.Metric("metric-key", 3.0, 12345)
        with self.assertRaises(MlflowException) as ctx:
            self.store.log_batch(run_id, metrics=[metric], params=[changed], tags=[tag])
        err = ctx.exception
        self.assertIn("Changing param value is not allowed. Param with key=", err.message)
        assert err.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        self._verify_logged(run_id, metrics=[], params=[original], tags=[])
    def test_search_metrics(self):
        """Search runs by single metric expressions across shared and unique keys.

        r1 and r2 both log ``common``. ``measure_a`` has a different value history
        on each run. ``m_a`` and ``m_b`` exist on only one run each. For ``m_b`` the
        cases expect matching against the value at the highest timestamp (4.0 at
        t=8), not the last-logged (8.0) or the largest value.
        """
        experiment_id = self._experiment_factory('search_params')
        r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid
        r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid

        logged = [
            # (run, key, value, timestamp), logged in this order
            (r1, "common", 1.0, 1),
            (r2, "common", 1.0, 1),
            (r1, "measure_a", 1.0, 1),
            (r2, "measure_a", 200.0, 2),
            (r2, "measure_a", 400.0, 3),
            (r1, "m_a", 2.0, 2),
            (r2, "m_b", 3.0, 2),
            (r2, "m_b", 4.0, 8),  # highest timestamp for m_b
            (r2, "m_b", 8.0, 3),
        ]
        for run_id, key, value, timestamp in logged:
            self.store.log_metric(run_id, entities.Metric(key, value, timestamp))

        cases = [
            # (key, comparator, value, expected matching runs)
            ("common", "=", 1.0, [r1, r2]),
            ("common", ">", 0.0, [r1, r2]),
            ("common", ">=", 0.0, [r1, r2]),
            ("common", "<", 4.0, [r1, r2]),
            ("common", "<=", 4.0, [r1, r2]),
            ("common", "!=", 1.0, []),
            ("common", ">=", 3.0, []),
            ("common", "<=", 0.75, []),
            # same metric name across runs with different values and timestamps
            ("measure_a", ">", 0.0, [r1, r2]),
            ("measure_a", "<", 50.0, [r1]),
            ("measure_a", "<", 1000.0, [r1, r2]),
            ("measure_a", "!=", -12.0, [r1, r2]),
            ("measure_a", ">", 50.0, [r2]),
            ("measure_a", "=", 1.0, [r1]),
            ("measure_a", "=", 400.0, [r2]),
            # metric keys present on only one run
            ("m_a", ">", 1.0, [r1]),
            ("m_b", ">", 1.0, [r2]),
            # m_b once recorded a value above this threshold (8.0), but not at its
            # latest timestamp, so no run matches
            ("m_b", ">", 5.0, []),
            # matches the value reported at m_b's latest timestamp
            ("m_b", "=", 4.0, [r2]),
        ]
        for key, comparator, value, expected in cases:
            # These are metric expressions, so pass them as metrics_expressions
            # (the keyword test_search_full uses), not param_expressions.
            expr = self._metric_expression(key, comparator, value)
            self.assertSequenceEqual(
                expected,
                self._search(experiment_id, metrics_expressions=[expr]),
                msg="metric %s %s %r" % (key, comparator, value))