def test_log_metric(self):
    """Two same-key metrics are both stored in SQL but collapse to one in RunData.

    The metrics share key and value and differ only in timestamp. The raw SQL
    run (via ``_get_run``) must hold both records, while the ``RunData`` returned
    by ``get_run`` exposes a single entry for the key.
    """
    run_uuid = self._run_factory().info.run_uuid
    key, value = 'blahmetric', 100.0
    earlier = entities.Metric(key, value, int(time.time()))
    later = entities.Metric(key, value, int(time.time()) + 2)
    for entry in (earlier, later):
        self.store.log_metric(run_uuid, entry)

    fetched = self.store.get_run(run_uuid)
    self.assertTrue(any(m.key == key and m.value == value
                        for m in fetched.data.metrics))

    # _get_run returns the SQL row, i.e. the full metric history (duplicates
    # included); RunData keeps only the last reported value per metric key.
    with self.store.ManagedSessionMaker() as session:
        history = self.store._get_run(session, fetched.info.run_uuid).metrics
        self.assertEqual(2, len(history))
        self.assertEqual(1, len(fetched.data.metrics))
def test_log_metric(self):
    """Log two same-key metrics and check both the SQL rows and the RunData view.

    Both metrics share key and value and differ only in timestamp. The SQL
    metric table must contain a matching row. ``_get_run`` must return both
    records, while ``get_run``'s ``RunData`` exposes a single entry for the key.
    """
    run = self._run_factory()
    self.session.commit()
    tkey = 'blahmetric'
    tval = 100.0
    metric = entities.Metric(tkey, tval, int(time.time()))
    metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
    self.store.log_metric(run.run_uuid, metric)
    self.store.log_metric(run.run_uuid, metric2)

    # A Query object itself is never None; fetch a row so this check can fail.
    actual = self.session.query(models.SqlMetric).filter_by(key=tkey, value=tval).first()
    self.assertIsNotNone(actual)

    run = self.store.get_run(run.run_uuid)
    # SQL store _get_run method returns full history of recorded metrics,
    # duplicates included. MLflow RunData contains only the last reported
    # values for metrics.
    sql_run_metrics = self.store._get_run(run.info.run_uuid).metrics
    self.assertEqual(2, len(sql_run_metrics))
    self.assertEqual(1, len(run.data.metrics))
    self.assertTrue(any(m.key == tkey and m.value == tval for m in run.data.metrics))
def test_log_metric(self):
    """Log the same metric value twice and check it is visible through get_run.

    Both metrics share key and value and differ only in timestamp. The SQL
    metric table is queried directly, then the run is fetched via the store API.
    """
    run = self._run_factory()
    self.session.commit()
    tkey = 'blahmetric'
    tval = 100.0
    metric = entities.Metric(tkey, tval, int(time.time()))
    metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
    self.store.log_metric(run.run_uuid, metric)
    self.store.log_metric(run.run_uuid, metric2)

    # A Query object itself is never None; fetch a row so this check can fail.
    actual = self.session.query(models.SqlMetric).filter_by(key=tkey, value=tval).first()
    self.assertIsNotNone(actual)

    run = self.store.get_run(run.run_uuid)
    # NOTE(review): 4 presumably includes metrics seeded by _run_factory plus
    # the key logged here — confirm against the factory's configuration.
    self.assertEqual(4, len(run.data.metrics))
    self.assertTrue(any(m.key == tkey and m.value == tval for m in run.data.metrics))
def test_log_metric_uniqueness(self):
    """Logging a different value for an already-logged key and timestamp must fail.

    The store is expected to raise an MlflowException whose message reports
    that the metric must be unique.
    """
    run = self._run_factory()
    tkey = 'blahmetric'
    tval = 100.0
    # Read the clock once: two separate time.time() calls can straddle a second
    # boundary, giving the metrics different timestamps so no conflict arises
    # and the test fails spuriously.
    timestamp = int(time.time())
    metric = entities.Metric(tkey, tval, timestamp)
    metric2 = entities.Metric(tkey, 1.02, timestamp)
    self.store.log_metric(run.info.run_uuid, metric)
    with self.assertRaises(MlflowException) as e:
        self.store.log_metric(run.info.run_uuid, metric2)
    self.assertIn("must be unique. Metric already logged value", e.exception.message)
def test_log_metric_uniqueness(self):
    """Logging a different value for an already-logged key and timestamp must fail."""
    run = self._run_factory()
    self.session.commit()
    tkey = 'blahmetric'
    tval = 100.0
    # Read the clock once: two separate time.time() calls can straddle a second
    # boundary, giving the metrics different timestamps so no conflict arises
    # and the test fails spuriously.
    timestamp = int(time.time())
    metric = entities.Metric(tkey, tval, timestamp)
    metric2 = entities.Metric(tkey, 1.02, timestamp)
    self.store.log_metric(run.run_uuid, metric)
    with self.assertRaises(MlflowException):
        self.store.log_metric(run.run_uuid, metric2)
def test_error_logging_to_deleted_run(self):
    """Writes to a deleted run are rejected; after restore_run they succeed.

    While the run is DELETED, log_param, log_metric and set_tag must each raise
    an MlflowException saying the run "must be in 'active' state". After
    restoring the run, the same keys are written with new values and read back.
    """
    exp = self._experiment_factory('error_logging')
    run_uuid = self._run_factory(
        self._get_run_configs(experiment_id=exp)).run_uuid

    self.store.delete_run(run_uuid)
    self.assertEqual(self.store.get_run(run_uuid).info.lifecycle_stage,
                     entities.LifecycleStage.DELETED)

    rejected_writes = [
        lambda: self.store.log_param(run_uuid, entities.Param("p1345", "v1")),
        lambda: self.store.log_metric(run_uuid, entities.Metric("m1345", 1.0, 123)),
        lambda: self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv1")),
    ]
    for write in rejected_writes:
        with self.assertRaises(MlflowException) as e:
            write()
        self.assertIn("must be in 'active' state", e.exception.message)

    # Restore this run and try again.
    self.store.restore_run(run_uuid)
    self.assertEqual(self.store.get_run(run_uuid).info.lifecycle_stage,
                     entities.LifecycleStage.ACTIVE)
    self.store.log_param(run_uuid, entities.Param("p1345", "v22"))
    # Timestamp 85 is earlier than the 123 used by the rejected log_metric above.
    # Expecting 34.0@85 below therefore also indicates that the rejected write
    # was not persisted.
    self.store.log_metric(run_uuid, entities.Metric("m1345", 34.0, 85))
    self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv44"))

    run = self.store.get_run(run_uuid)
    # Unfiltered comparisons: the restored run holds exactly this one param and metric.
    self.assertEqual([("p1345", "v22")],
                     [(p.key, p.value) for p in run.data.params])
    self.assertEqual([("m1345", 34.0, 85)],
                     [(m.key, m.value, m.timestamp) for m in run.data.metrics])
    # Tags are filtered by key because the run may carry other tags.
    self.assertEqual([("t1345", "tv44")],
                     [(t.key, t.value) for t in run.data.tags if t.key == "t1345"])
def test_log_null_metric(self):
    """A metric whose value is None is rejected with an INTERNAL_ERROR code."""
    run_uuid = self._run_factory().info.run_uuid
    null_metric = entities.Metric('blahmetric', None, int(time.time()))
    with self.assertRaises(MlflowException) as raised:
        self.store.log_metric(run_uuid, null_metric)
    expected_code = ErrorCode.Name(INTERNAL_ERROR)
    assert raised.exception.error_code == expected_code
def test_log_null_metric(self):
    """A None metric value is rejected with INTERNAL_ERROR; warnings are silenced.

    Warnings emitted while attempting the write are suppressed only for the
    duration of the call.
    """
    run = self._run_factory()
    tkey = 'blahmetric'
    tval = None
    metric = entities.Metric(tkey, tval, int(time.time()))
    # catch_warnings() snapshots the global warning filters and restores them on
    # exit, even if assertRaises fails. Installing the "ignore" filter inside it
    # keeps the suppression local, instead of leaking it or wiping the runner's
    # own filters with resetwarnings().
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with self.assertRaises(MlflowException) as exception_context:
            self.store.log_metric(run.info.run_uuid, metric)
    assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)
def test_log_null_metric(self):
    """A None metric value makes log_metric fail with an IntegrityError-based message."""
    run = self._run_factory()
    self.session.commit()
    null_metric = entities.Metric('blahmetric', None, int(time.time()))
    with self.assertRaises(MlflowException) as raised:
        self.store.log_metric(run.run_uuid, null_metric)
    message = raised.exception.message
    for fragment in ("Log metric request failed for run ID=", "IntegrityError"):
        self.assertIn(fragment, message)
def test_log_batch_param_overwrite_disallowed_single_req(self):
    """log_batch rejects a request that assigns two values to one param key.

    The error must be INVALID_PARAMETER_VALUE. Afterwards only the first value
    of the conflicting key is expected in the run; the request's metric and tag
    are not.
    """
    run_uuid = self._run_factory().info.run_uuid
    pkey = "common-key"
    first_value = entities.Param(pkey, "orig-val")
    second_value = entities.Param(pkey, 'newval')
    batch = dict(metrics=[entities.Metric("metric-key", 3.0, 12345)],
                 params=[first_value, second_value],
                 tags=[entities.RunTag("tag-key", "tag-val")])
    with self.assertRaises(MlflowException) as raised:
        self.store.log_batch(run_uuid, **batch)
    self.assertIn("Changing param value is not allowed. Param with key=",
                  raised.exception.message)
    assert raised.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    self._verify_logged(run_uuid, metrics=[], params=[first_value], tags=[])
def test_search_full(self):
    """Combined param + metric searches return only runs matching every expression."""
    experiment_id = self._experiment_factory('search_params')
    r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid
    r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid

    for run_id in (r1, r2):
        self.store.log_param(run_id, entities.Param('generic_param', 'p_val'))
    self.store.log_param(r1, entities.Param('p_a', 'abc'))
    self.store.log_param(r2, entities.Param('p_b', 'ABC'))

    for run_id in (r1, r2):
        self.store.log_metric(run_id, entities.Metric("common", 1.0, 1))
    self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2))
    for value, timestamp in ((3.0, 2), (4.0, 8), (8.0, 3)):
        self.store.log_metric(r2, entities.Metric("m_b", value, timestamp))

    # Each scenario: (param expression args, metric expression args, expected runs).
    scenarios = [
        # one shared param and one shared metric: both runs match
        ([("generic_param", "=", "p_val")],
         [("common", "=", 1.0)],
         [r1, r2]),
        # all params and metrics match (only r1 has m_a)
        ([("generic_param", "=", "p_val")],
         [("common", "=", 1.0), ("m_a", ">", 1.0)],
         [r1]),
        # test with mismatch param
        ([("random_bad_name", "=", "p_val")],
         [("common", "=", 1.0), ("m_a", ">", 1.0)],
         []),
        # test with mismatch metric
        ([("generic_param", "=", "p_val")],
         [("common", "=", 1.0), ("m_a", ">", 100.0)],
         []),
    ]
    for param_args, metric_args, expected in scenarios:
        param_exprs = [self._param_expression(*args) for args in param_args]
        metric_exprs = [self._metric_expression(*args) for args in metric_args]
        self.assertSequenceEqual(expected, self._search(
            experiment_id, param_expressions=param_exprs,
            metrics_expressions=metric_exprs))
def test_log_batch_param_overwrite_disallowed(self):
    """log_batch may not overwrite an existing param, and logs nothing on failure.

    A param is logged first. A batch that changes its value (plus a metric and a
    tag) must fail with INVALID_PARAMETER_VALUE and leave only the original param.
    """
    run = self._run_factory()
    self.session.commit()
    key = 'my-param'
    existing = entities.Param(key, 'orig-val')
    self.store.log_param(run.run_uuid, existing)

    replacement = entities.Param(key, 'newval')
    extra_tag = entities.RunTag("tag-key", "tag-val")
    extra_metric = entities.Metric("metric-key", 3.0, 12345)
    with self.assertRaises(MlflowException) as raised:
        self.store.log_batch(run.run_uuid, metrics=[extra_metric],
                             params=[replacement], tags=[extra_tag])
    self.assertIn("Changing param value is not allowed. Param with key=",
                  raised.exception.message)
    assert raised.exception.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    # No partial data: the rejected batch's metric and tag must not appear.
    self._verify_logged(run.run_uuid, metrics=[], params=[existing], tags=[])
def test_search_metrics(self):
    """Metric search expressions compare against each run's latest metric value.

    Two runs are populated with:
    - a metric with equal values in both runs ("common");
    - a metric whose values differ between runs ("measure_a");
    - metrics present in only one run ("m_a", "m_b").
    "m_b" is logged out of timestamp order. The expectations below show that
    search uses the value with the latest timestamp (4.0 at t=8), not the
    last one logged (8.0 at t=3).
    """
    experiment_id = self._experiment_factory('search_params')
    r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).run_uuid
    r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).run_uuid

    self.store.log_metric(r1, entities.Metric("common", 1.0, 1))
    self.store.log_metric(r2, entities.Metric("common", 1.0, 1))

    self.store.log_metric(r1, entities.Metric("measure_a", 1.0, 1))
    self.store.log_metric(r2, entities.Metric("measure_a", 200.0, 2))
    self.store.log_metric(r2, entities.Metric("measure_a", 400.0, 3))

    self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2))
    self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2))
    self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8))  # latest timestamp for m_b
    self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3))

    # (metric key, comparator, value, expected matching runs)
    cases = [
        # same metric value in both runs
        ("common", "=", 1.0, [r1, r2]),
        ("common", ">", 0.0, [r1, r2]),
        ("common", ">=", 0.0, [r1, r2]),
        ("common", "<", 4.0, [r1, r2]),
        ("common", "<=", 4.0, [r1, r2]),
        ("common", "!=", 1.0, []),
        ("common", ">=", 3.0, []),
        ("common", "<=", 0.75, []),
        # same metric name across runs with different values and timestamps
        ("measure_a", ">", 0.0, [r1, r2]),
        ("measure_a", "<", 50.0, [r1]),
        ("measure_a", "<", 1000.0, [r1, r2]),
        ("measure_a", "!=", -12.0, [r1, r2]),
        ("measure_a", ">", 50.0, [r2]),
        ("measure_a", "=", 1.0, [r1]),
        ("measure_a", "=", 400.0, [r2]),
        # metric keys that exist in only one run
        ("m_a", ">", 1.0, [r1]),
        ("m_b", ">", 1.0, [r2]),
        # a value above this threshold was recorded (8.0), but not at the
        # latest timestamp, so it does not match
        ("m_b", ">", 5.0, []),
        # matches the value reported at the latest timestamp for 'm_b'
        ("m_b", "=", 4.0, [r2]),
    ]
    for key, comparator, value, expected in cases:
        expr = self._metric_expression(key, comparator, value)
        # Metric expressions go through metrics_expressions (as in test_search_full).
        self.assertSequenceEqual(
            expected,
            self._search(experiment_id, metrics_expressions=[expr]),
            msg="%s %s %s" % (key, comparator, value))