コード例 #1
0
    def test_error_logging_to_deleted_run(self):
        """Check that a deleted run rejects logging until it is restored.

        A fresh run is deleted; ``log_param``, ``log_metric`` and ``set_tag``
        are then each expected to raise ``MlflowException`` with a message
        mentioning the 'active' state. After ``restore_run`` the same three
        calls must succeed and their values must be readable via ``get_run``.
        """
        experiment = self._experiment_factory('error_logging')
        run_configs = self._get_run_configs(experiment_id=experiment)
        run_uuid = self._run_factory(run_configs).run_uuid

        self.store.delete_run(run_uuid)
        deleted_run = self.store.get_run(run_uuid)
        self.assertEqual(deleted_run.info.lifecycle_stage,
                         entities.LifecycleStage.DELETED)

        # Each write is deferred in a lambda so that the entity is built and
        # the store is called inside the assertRaises context, one at a time.
        rejected_writes = (
            lambda: self.store.log_param(
                run_uuid, entities.Param("p1345", "v1")),
            lambda: self.store.log_metric(
                run_uuid, entities.Metric("m1345", 1.0, 123)),
            lambda: self.store.set_tag(
                run_uuid, entities.RunTag("t1345", "tv1")),
        )
        for attempt_write in rejected_writes:
            with self.assertRaises(MlflowException) as ctx:
                attempt_write()
            self.assertIn("must be in 'active' state", ctx.exception.message)

        # Restore the run; the same keys must now be accepted.
        self.store.restore_run(run_uuid)
        restored_run = self.store.get_run(run_uuid)
        self.assertEqual(restored_run.info.lifecycle_stage,
                         entities.LifecycleStage.ACTIVE)
        self.store.log_param(run_uuid, entities.Param("p1345", "v22"))
        # Timestamp 85 is earlier than the 123 used in the rejected attempt.
        accepted_metric = entities.Metric("m1345", 34.0, 85)
        self.store.log_metric(run_uuid, accepted_metric)
        self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv44"))

        # Exactly one param and one metric are expected on the restored run.
        fetched = self.store.get_run(run_uuid)
        assert len(fetched.data.params) == 1
        only_param = fetched.data.params[0]
        self.assertEqual(only_param.key, "p1345")
        self.assertEqual(only_param.value, "v22")
        assert len(fetched.data.metrics) == 1
        only_metric = fetched.data.metrics[0]
        self.assertEqual(only_metric.key, "m1345")
        self.assertEqual(only_metric.value, 34.0)

        # Read the run again and check each entity by key, including the
        # metric timestamp and the tag.
        fetched = self.store.get_run(run_uuid)
        matching_params = [(param.key, param.value)
                           for param in fetched.data.params
                           if param.key == "p1345"]
        self.assertEqual([("p1345", "v22")], matching_params)
        matching_metrics = [(metric.key, metric.value, metric.timestamp)
                            for metric in fetched.data.metrics
                            if metric.key == "m1345"]
        self.assertEqual([("m1345", 34.0, 85)], matching_metrics)
        matching_tags = [(tag.key, tag.value)
                         for tag in fetched.data.tags
                         if tag.key == "t1345"]
        self.assertEqual([("t1345", "tv44")], matching_tags)
コード例 #2
0
 def test_log_batch_param_overwrite_disallowed_single_req(self):
     """Check that one log_batch call cannot set the same param key twice.

     A single ``log_batch`` request carries two params sharing a key but
     with different values, plus one metric and one tag. The call must raise
     ``MlflowException`` with error code ``INVALID_PARAMETER_VALUE``; the
     test then calls ``_verify_logged`` expecting only the first param and
     no metrics or tags.
     """
     run = self._run_factory()
     run_id = run.info.run_uuid
     shared_key = "common-key"
     first_param = entities.Param(shared_key, "orig-val")
     conflicting_param = entities.Param(shared_key, 'newval')
     batch_tag = entities.RunTag("tag-key", "tag-val")
     batch_metric = entities.Metric("metric-key", 3.0, 12345)

     with self.assertRaises(MlflowException) as ctx:
         self.store.log_batch(
             run_id,
             metrics=[batch_metric],
             params=[first_param, conflicting_param],
             tags=[batch_tag])

     raised = ctx.exception
     self.assertIn("Changing param value is not allowed. Param with key=", raised.message)
     assert raised.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
     self._verify_logged(run_id, metrics=[], params=[first_param], tags=[])
コード例 #3
0
    def test_set_tag(self):
        """Check that a tag set on a run is present when the run is re-read."""
        created = self._run_factory()

        expected_key = 'test tag'
        expected_value = 'a boogie'
        self.store.set_tag(created.info.run_uuid,
                           entities.RunTag(expected_key, expected_value))

        reloaded = self.store.get_run(created.info.run_uuid)

        # Passes when at least one tag on the run matches both key and value.
        tag_present = any(
            tag.key == expected_key and tag.value == expected_value
            for tag in reloaded.data.tags)
        self.assertTrue(tag_present)
コード例 #4
0
    def test_log_batch_param_overwrite_disallowed(self):
        """Check that log_batch cannot overwrite a previously logged param.

        A param is first logged with ``log_param``. A later ``log_batch``
        that supplies a different value for the same key (alongside a metric
        and a tag) must raise ``MlflowException`` with error code
        ``INVALID_PARAMETER_VALUE``, and no partial data from the batch
        should be logged: ``_verify_logged`` is called expecting only the
        original param and no metrics or tags.
        """
        run = self._run_factory()
        self.session.commit()
        run_id = run.run_uuid

        param_key = 'my-param'
        original_param = entities.Param(param_key, 'orig-val')
        self.store.log_param(run_id, original_param)

        conflicting_param = entities.Param(param_key, 'newval')
        batch_tag = entities.RunTag("tag-key", "tag-val")
        batch_metric = entities.Metric("metric-key", 3.0, 12345)
        with self.assertRaises(MlflowException) as ctx:
            self.store.log_batch(
                run_id,
                metrics=[batch_metric],
                params=[conflicting_param],
                tags=[batch_tag])

        raised = ctx.exception
        self.assertIn("Changing param value is not allowed. Param with key=", raised.message)
        assert raised.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        self._verify_logged(run_id, metrics=[], params=[original_param], tags=[])
コード例 #5
0
    def test_set_tag(self):
        """Check that set_tag persists a tag row and exposes it via get_run.

        The tag is verified twice: directly through the test's SQL session
        (a ``SqlTag`` row with the expected key and value must exist) and
        through the store API (the re-read run must carry the tag).
        """
        run = self._run_factory()

        self.session.commit()

        tkey = 'test tag'
        tval = 'a boogie'
        tag = entities.RunTag(tkey, tval)
        self.store.set_tag(run.run_uuid, tag)

        # Execute the query with .first(): an un-executed Query object is
        # never None, so asserting on it would pass even if no row had been
        # written. .first() returns None when there is no matching row.
        # NOTE(review): assumes self.session can see rows written by
        # self.store (shared session or committed data) -- confirm in setUp.
        actual = self.session.query(models.SqlTag).filter_by(key=tkey, value=tval).first()

        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        # The tag must also be visible on the run returned by the store.
        found = any(t.key == tkey and t.value == tval for t in run.data.tags)

        self.assertTrue(found)