예제 #1
0
 def setUp(self):
     """Create a store, rebind it to a test-owned engine/session, then create all model tables."""
     self.store = SqlAlchemyStore(DB_URI)
     engine = sqlalchemy.create_engine(DB_URI)
     self.engine = engine
     self.session = sqlalchemy.orm.sessionmaker(bind=engine)()
     # Swap the store's session/engine for the test's own, so direct queries made by the
     # test see exactly what the store writes.
     self.store.session, self.store.engine = self.session, self.engine
     models.Base.metadata.create_all(engine)
예제 #2
0
 def setUp(self):
     """Create a store, rebind it to a test-owned engine/session, then create all model tables."""
     # Show complete diffs when an assertion fails.
     self.maxDiff = None
     self.store = SqlAlchemyStore(DB_URI)
     engine = sqlalchemy.create_engine(DB_URI)
     self.engine = engine
     self.session = sqlalchemy.orm.sessionmaker(bind=engine)()
     # Swap the store's session/engine for the test's own, so direct queries made by the
     # test see exactly what the store writes.
     self.store.session, self.store.engine = self.session, self.engine
     models.Base.metadata.create_all(engine)
예제 #3
0
def test_store_generated_schema_matches_base(tmpdir, db_url):
    """A freshly initialised SqlAlchemyStore must leave a schema identical to ``Base.metadata``.

    The store is created against ``db_url``. Alembic's ``compare_metadata`` then diffs the live
    database against the declarative models, and any difference fails the test.
    """
    SqlAlchemyStore(db_url, tmpdir.join("ARTIFACTS").strpath)
    engine = sqlalchemy.create_engine(db_url)
    # Close the connection and dispose the engine when done, so neither outlives the test
    # (an open connection would otherwise keep holding the database).
    try:
        with engine.connect() as conn:
            mc = MigrationContext.configure(conn)
            diff = compare_metadata(mc, Base.metadata)
    finally:
        engine.dispose()
    # Report the actual differences on failure instead of only their count.
    assert not diff, diff
예제 #4
0
def dump_sqlalchemy_store_schema(dst_file):
    """Write the schema produced by a freshly created SqlAlchemyStore to ``dst_file``.

    A throwaway SQLite database file is created inside a temporary directory, which also serves
    as the store's artifact root. Its schema is written out with ``dump_db_schema``. The
    directory is always removed afterwards, even when a step fails.
    """
    scratch_dir = tempfile.mkdtemp()
    try:
        sqlite_url = "sqlite:///%s" % os.path.join(scratch_dir, "db_file")
        SqlAlchemyStore(sqlite_url, scratch_dir)
        dump_db_schema(sqlite_url, dst_file)
    finally:
        shutil.rmtree(scratch_dir)
예제 #5
0
def test_sqlalchemystore_idempotently_generates_up_to_date_schema(
        tmpdir, db_url, expected_schema_file):
    """Re-initialising a store on the same DB URL must succeed and keep the expected schema.

    Each round creates a store, dumps the schema and compares the dump with
    ``expected_schema_file``.
    """
    schema_dump = tmpdir.join("generated-schema.sql").strpath
    artifact_root = tmpdir.join("ARTIFACTS").strpath
    rounds = 3
    for _round in range(rounds):
        SqlAlchemyStore(db_url, artifact_root)
        dump_db_schema(db_url, dst_file=schema_dump)
        _assert_schema_files_equal(schema_dump, expected_schema_file)
def test_incorrect_dbfs_instantiation():
    """A dbfs artifact URI paired with a non-RestStore store must be rejected.

    The registered dbfs repository constructor must never be called in that case.
    """
    registry = ArtifactRepositoryRegistry()
    dbfs_factory = mock.Mock()
    registry.register("dbfs", dbfs_factory)

    store = SqlAlchemyStore("sqlite://", "./mlruns")
    expected_error = "must be an instance of RestStore"

    with pytest.raises(mlflow.exceptions.MlflowException, match=expected_error):
        registry.get_artifact_repository(artifact_uri="dbfs://test-path", store=store)

    dbfs_factory.assert_not_called()
def test_sqlalchemy_store_detects_schema_mismatch(tmpdir, db_url):  # pylint: disable=unused-argument
    """``SqlAlchemyStore._verify_schema`` must reject the database until every migration ran.

    Verification must fail on an empty database, on the legacy ``InitialBase`` tables, and
    after each alembic revision except the last. Once ``mlflow db upgrade`` has run, it must
    pass.
    """
    def _assert_invalid_schema(engine):
        with pytest.raises(MlflowException) as ex:
            SqlAlchemyStore._verify_schema(engine)
        # Check the message outside the ``with`` block: statements after the raising call
        # inside it never execute. The expected text omits a trailing period so the check
        # does not depend on what follows "schema" in the message.
        assert "Detected out-of-date database schema" in str(ex.value)

    # Initialize an empty database & verify that we detect a schema mismatch
    engine = sqlalchemy.create_engine(db_url)
    _assert_invalid_schema(engine)
    # Create legacy tables, verify schema is still out of date
    InitialBase.metadata.create_all(engine)
    _assert_invalid_schema(engine)
    # Apply migrations one at a time; until the last one, the schema is still out of date
    config = _get_alembic_config(db_url)
    script = ScriptDirectory.from_config(config)
    # walk_revisions() goes from head back to base; reverse it to apply oldest first
    revisions = list(script.walk_revisions())
    revisions.reverse()
    for rev in revisions[:-1]:
        command.upgrade(config, rev.revision)
        _assert_invalid_schema(engine)
    # Run the remaining migration(s) via the CLI; schema verification should now pass
    invoke_cli_runner(mlflow.db.commands, ['upgrade', db_url])
    SqlAlchemyStore._verify_schema(engine)
예제 #8
0
def _get_store():
    """Return the module-level tracking store, creating and caching it on first use.

    The backend URI and artifact root come from the environment variables named by
    ``BACKEND_STORE_URI_ENV_VAR`` and ``ARTIFACT_ROOT_ENV_VAR``. The store is assigned to the
    global ``_store``, so later calls return the same instance.

    :raises MlflowException: if the backend URI is neither a database URI nor a local URI.
    """
    from mlflow.server import BACKEND_STORE_URI_ENV_VAR, ARTIFACT_ROOT_ENV_VAR
    global _store
    if _store is None:
        store_dir = os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
        artifact_root = os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)
        if _is_database_uri(store_dir):
            from mlflow.store.sqlalchemy_store import SqlAlchemyStore
            # Cache it like the FileStore branch does, so repeated calls reuse one store
            # instead of constructing a new SqlAlchemyStore every time.
            _store = SqlAlchemyStore(store_dir, artifact_root)
        elif _is_local_uri(store_dir):
            from mlflow.store.file_store import FileStore
            _store = FileStore(store_dir, artifact_root)
        else:
            raise MlflowException("Unexpected URI type '{}' for backend store. "
                                  "Expect local file or database type.".format(store_dir))
    return _store
예제 #9
0
파일: utils.py 프로젝트: trungtv/mlflow
def _get_store(store_uri=None):
    """Build a tracking store for ``store_uri``, defaulting to the current tracking URI.

    If no URI is available at all, a default ``FileStore()`` is returned. Otherwise the URI is
    tested against the database, Databricks, local and HTTP predicates in that order. The
    first match selects the store factory, which is called with the URI.

    :raises Exception: if the URI matches none of the supported schemes.
    """
    store_uri = store_uri or get_tracking_uri()
    if store_uri is None:
        return FileStore()

    # (predicate, factory) pairs, checked in order; the first matching predicate wins.
    dispatch = (
        (_is_db_uri, SqlAlchemyStore),
        (_is_databricks_uri, _get_databricks_rest_store),
        (_is_local_uri, _get_file_store),
        (_is_http_uri, _get_rest_store),
    )
    for matches, build_store in dispatch:
        if matches(store_uri):
            return build_store(store_uri)

    raise Exception("Tracking URI must be a local filesystem URI of the form '%s...' or a "
                    "remote URI of the form '%s...'. Update the tracking URI via "
                    "mlflow.set_tracking_uri" % (_LOCAL_FS_URI_PREFIX, _REMOTE_URI_PREFIX))
예제 #10
0
파일: utils.py 프로젝트: zmoon111/mlflow
def _get_store(store_uri=None, artifact_uri=None):
    """Build a tracking store for ``store_uri``, defaulting to the current tracking URI.

    If no URI is available, a file store rooted at ``DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH`` is
    returned. A database URI yields a ``SqlAlchemyStore`` whose artifact root is
    ``artifact_uri``, or that default path when ``artifact_uri`` is falsy. Databricks, local
    and HTTP URIs are handed to their respective helper factories.

    :raises Exception: if the URI matches none of the supported schemes.
    """
    store_uri = store_uri or get_tracking_uri()
    if store_uri is None:
        return _get_file_store(DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH)

    if _is_database_uri(store_uri):
        # Imported inside the branch, presumably so SQLAlchemy is only loaded when a
        # database URI is actually used.
        from mlflow.store.sqlalchemy_store import SqlAlchemyStore
        return SqlAlchemyStore(store_uri, artifact_uri or DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH)
    elif _is_databricks_uri(store_uri):
        return _get_databricks_rest_store(store_uri)
    elif _is_local_uri(store_uri):
        return _get_file_store(store_uri)
    elif _is_http_uri(store_uri):
        return _get_rest_store(store_uri)

    raise Exception(
        "Tracking URI must be a local filesystem URI of the form '%s...' or a "
        "remote URI of the form '%s...'. Update the tracking URI via "
        "mlflow.set_tracking_uri" % (_LOCAL_FS_URI_PREFIX, _REMOTE_URI_PREFIX))
예제 #11
0
class TestSqlAlchemyStoreSqliteInMemory(unittest.TestCase):
    """Tests for ``SqlAlchemyStore`` and the SQLAlchemy models it persists.

    ``setUp`` replaces the store's ``session`` and ``engine`` with ones owned by the test, so
    each test can inspect the tables directly through ``self.session``. ``tearDown`` drops all
    model tables again.
    """

    def setUp(self):
        self.store = SqlAlchemyStore(DB_URI)
        self.engine = sqlalchemy.create_engine(DB_URI)
        Session = sqlalchemy.orm.sessionmaker(bind=self.engine)
        self.session = Session()
        # Rebind the store to the test's engine/session, so what the store writes is
        # visible to the direct queries made by the tests below.
        self.store.session = self.session
        self.store.engine = self.engine
        models.Base.metadata.create_all(self.engine)

    def tearDown(self):
        # Drop every model table so the next test starts from an empty schema.
        models.Base.metadata.drop_all(self.engine)

    def _experiment_factory(self, names):
        """Create experiments through the store.

        :param names: a single experiment name, or a list of names.
        :return: what ``store.create_experiment`` returns for a single name, or a list of those
                 results (in input order) for a list of names.
        """
        if isinstance(names, list):
            return [self.store.create_experiment(name=name) for name in names]
        return self.store.create_experiment(name=names)

    def test_raise_duplicate_experiments(self):
        """Creating two experiments with the same name must raise."""
        with self.assertRaises(Exception):
            self._experiment_factory(['test', 'test'])

    def test_raise_experiment_dont_exist(self):
        """Fetching an experiment id that does not exist must raise."""
        with self.assertRaises(Exception):
            self.store.get_experiment(experiment_id=100)

    def test_delete_experiment(self):
        """A deleted experiment drops out of list_experiments and its row is marked deleted."""
        experiments = self._experiment_factory(
            ['morty', 'rick', 'rick and morty'])
        exp = experiments[0]
        self.store.delete_experiment(exp.experiment_id)

        actual = self.session.query(models.SqlExperiment).get(
            exp.experiment_id)
        self.assertEqual(len(self.store.list_experiments()),
                         len(experiments) - 1)

        self.assertEqual(actual.lifecycle_stage,
                         entities.Experiment.DELETED_LIFECYCLE)

    def test_get_experiment(self):
        """get_experiment returns the experiment created under that id."""
        name = 'goku'
        run_data = self._experiment_factory(name)
        actual = self.store.get_experiment(run_data.experiment_id)
        self.assertEqual(actual.name, run_data.name)
        self.assertEqual(actual.experiment_id, run_data.experiment_id)

    def test_list_experiments(self):
        """list_experiments returns as many experiments as were created, each matching its row."""
        testnames = ['blue', 'red', 'green']

        run_data = self._experiment_factory(testnames)
        actual = self.store.list_experiments()

        self.assertEqual(len(run_data), len(actual))

        for exp in run_data:
            res = self.session.query(models.SqlExperiment).filter_by(
                experiment_id=exp.experiment_id).first()
            self.assertEqual(res.name, exp.name)
            self.assertEqual(res.experiment_id, exp.experiment_id)

    def test_create_experiments(self):
        """create_experiment inserts exactly one row matching the returned experiment."""
        result = self.session.query(models.SqlExperiment).all()
        self.assertEqual(len(result), 0)

        run_data = self.store.create_experiment(name='test experiment')
        result = self.session.query(models.SqlExperiment).all()
        self.assertEqual(len(result), 1)

        actual = result[0]

        self.assertEqual(actual.experiment_id, run_data.experiment_id)
        self.assertEqual(actual.name, run_data.name)

    def test_run_tag_model(self):
        """A committed SqlTag converts to an entity with the same key and value."""
        run_data = models.SqlTag(run_uuid='tuuid', key='test', value='val')
        self.session.add(run_data)
        self.session.commit()
        tags = self.session.query(models.SqlTag).all()
        self.assertEqual(len(tags), 1)

        actual = tags[0].to_mlflow_entity()

        self.assertEqual(actual.value, run_data.value)
        self.assertEqual(actual.key, run_data.key)

    def test_metric_model(self):
        """A committed SqlMetric converts to an entity with the same key and value."""
        run_data = models.SqlMetric(run_uuid='testuid',
                                    key='accuracy',
                                    value=0.89)
        self.session.add(run_data)
        self.session.commit()
        metrics = self.session.query(models.SqlMetric).all()
        self.assertEqual(len(metrics), 1)

        actual = metrics[0].to_mlflow_entity()

        self.assertEqual(actual.value, run_data.value)
        self.assertEqual(actual.key, run_data.key)

    def test_param_model(self):
        """A committed SqlParam converts to an entity with the same key and value."""
        run_data = models.SqlParam(run_uuid='test',
                                   key='accuracy',
                                   value='test param')
        self.session.add(run_data)
        self.session.commit()
        params = self.session.query(models.SqlParam).all()
        self.assertEqual(len(params), 1)

        actual = params[0].to_mlflow_entity()

        self.assertEqual(actual.value, run_data.value)
        self.assertEqual(actual.key, run_data.key)

    def test_run_needs_uuid(self):
        """Committing a SqlRun without a run_uuid must raise IntegrityError."""
        run = models.SqlRun()
        self.session.add(run)

        # Silence warnings only while committing. catch_warnings() restores the previous
        # filter state on exit even though the commit raises. Calling simplefilter() outside
        # it would leave a global "ignore" filter active for every later test.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                self.session.commit()

    def test_run_data_model(self):
        """Params and metrics attached to a SqlRun are persisted with it."""
        m1 = models.SqlMetric(key='accuracy', value=0.89)
        m2 = models.SqlMetric(key='recal', value=0.89)
        p1 = models.SqlParam(key='loss', value='test param')
        p2 = models.SqlParam(key='blue', value='test param')

        self.session.add_all([m1, m2, p1, p2])

        run_data = models.SqlRun(run_uuid=uuid.uuid4().hex)
        run_data.params.append(p1)
        run_data.params.append(p2)
        run_data.metrics.append(m1)
        run_data.metrics.append(m2)

        self.session.add(run_data)
        self.session.commit()

        run_datums = self.session.query(models.SqlRun).all()
        actual = run_datums[0]
        self.assertEqual(len(run_datums), 1)
        self.assertEqual(len(actual.params), 2)
        self.assertEqual(len(actual.metrics), 2)

    def test_run_info(self):
        """Every field passed to SqlRun appears unchanged on the converted entity's RunInfo."""
        experiment = self._experiment_factory('test exp')
        config = {
            'experiment_id': experiment.experiment_id,
            'name': 'test run',
            'user_id': 'Anderson',
            'run_uuid': 'test',
            # NOTE(review): status is set to a lifecycle constant, not a RunStatus value;
            # looks unintentional -- confirm.
            'status': entities.RunInfo.ACTIVE_LIFECYCLE,
            'source_type': entities.SourceType.LOCAL,
            'source_name': 'Python application',
            'entry_point_name': 'main.py',
            'start_time': int(time.time()),
            'end_time': int(time.time()),
            'source_version': mlflow.__version__,
            'lifecycle_stage': entities.RunInfo.ACTIVE_LIFECYCLE,
            'artifact_uri': '//'
        }
        run = models.SqlRun(**config).to_mlflow_entity()

        for k, v in config.items():
            self.assertEqual(v, getattr(run.info, k))

    def _run_factory(self, name='test', experiment_id=None, config=None):
        """Build a SqlRun with two metrics and two params and add it to ``self.session``.

        The run is added but not committed; callers commit when they need it persisted.

        :param name: run name.
        :param experiment_id: experiment to attach the run to. If falsy, a new experiment
                              named 'test exp' is created.
        :param config: currently ignored; the run's columns are always the ones built below.
        :return: the (uncommitted) ``models.SqlRun``.
        """
        m1 = models.SqlMetric(key='accuracy', value=0.89)
        m2 = models.SqlMetric(key='recal', value=0.89)
        p1 = models.SqlParam(key='loss', value='test param')
        p2 = models.SqlParam(key='blue', value='test param')

        if not experiment_id:
            experiment = self._experiment_factory('test exp')
            experiment_id = experiment.experiment_id

        run_config = {
            'experiment_id': experiment_id,
            'name': name,
            'user_id': 'Anderson',
            'run_uuid': uuid.uuid4().hex,
            'status': entities.RunStatus.to_string(entities.RunStatus.SCHEDULED),
            'source_type': entities.SourceType.to_string(entities.SourceType.NOTEBOOK),
            'source_name': 'Python application',
            'entry_point_name': 'main.py',
            'start_time': int(time.time()),
            'end_time': int(time.time()),
            'source_version': mlflow.__version__,
            'lifecycle_stage': entities.RunInfo.ACTIVE_LIFECYCLE,
            'artifact_uri': '//'
        }

        run = models.SqlRun(**run_config)

        run.params.append(p1)
        run.params.append(p2)
        run.metrics.append(m1)
        run.metrics.append(m2)
        self.session.add(run)

        return run

    def test_create_run(self):
        """create_run stores the given fields and tags and appends a run-name tag."""
        expected = self._run_factory()
        name = 'booyya'
        expected.tags.append(models.SqlTag(key='3', value='4'))
        expected.tags.append(models.SqlTag(key='1', value='2'))

        tags = [t.to_mlflow_entity() for t in expected.tags]
        actual = self.store.create_run(
            expected.experiment_id, expected.user_id, name,
            entities.SourceType.from_string(expected.source_type),
            expected.source_name, expected.entry_point_name,
            expected.start_time, expected.source_version, tags, None)

        self.assertEqual(actual.info.experiment_id, expected.experiment_id)
        self.assertEqual(actual.info.user_id, expected.user_id)
        self.assertEqual(actual.info.name, name)
        self.assertEqual(actual.info.source_type, expected.source_type)
        self.assertEqual(actual.info.source_name, expected.source_name)
        self.assertEqual(actual.info.source_version, expected.source_version)
        self.assertEqual(actual.info.entry_point_name,
                         expected.entry_point_name)
        self.assertEqual(actual.info.start_time, expected.start_time)
        self.assertEqual(len(actual.data.tags), 3)

        name_tag = models.SqlTag(key='mlflow.runName',
                                 value=name).to_mlflow_entity()
        self.assertListEqual(actual.data.tags, tags + [name_tag])

    def test_create_run_with_parent_id(self):
        """With a parent run id, create_run also adds a parentRunId tag before the name tag."""
        expected = self._run_factory()
        name = 'booyya'
        expected.tags.append(models.SqlTag(key='3', value='4'))
        expected.tags.append(models.SqlTag(key='1', value='2'))

        tags = [t.to_mlflow_entity() for t in expected.tags]
        actual = self.store.create_run(
            expected.experiment_id, expected.user_id, name,
            entities.SourceType.from_string(expected.source_type),
            expected.source_name, expected.entry_point_name,
            expected.start_time, expected.source_version, tags,
            "parent_uuid_5")

        self.assertEqual(actual.info.experiment_id, expected.experiment_id)
        self.assertEqual(actual.info.user_id, expected.user_id)
        self.assertEqual(actual.info.name, name)
        self.assertEqual(actual.info.source_type, expected.source_type)
        self.assertEqual(actual.info.source_name, expected.source_name)
        self.assertEqual(actual.info.source_version, expected.source_version)
        self.assertEqual(actual.info.entry_point_name,
                         expected.entry_point_name)
        self.assertEqual(actual.info.start_time, expected.start_time)
        self.assertEqual(len(actual.data.tags), 4)

        name_tag = models.SqlTag(key='mlflow.runName',
                                 value=name).to_mlflow_entity()
        parent_id_tag = models.SqlTag(
            key='mlflow.parentRunId',
            value='parent_uuid_5').to_mlflow_entity()
        self.assertListEqual(actual.data.tags,
                             tags + [parent_id_tag, name_tag])

    def test_to_mlflow_entity(self):
        """SqlRun.to_mlflow_entity yields RunInfo/RunData holding Metric, Param and RunTag entities."""
        run = self._run_factory()
        run = run.to_mlflow_entity()

        self.assertIsInstance(run.info, entities.RunInfo)
        self.assertIsInstance(run.data, entities.RunData)

        for metric in run.data.metrics:
            self.assertIsInstance(metric, entities.Metric)

        for param in run.data.params:
            self.assertIsInstance(param, entities.Param)

        for tag in run.data.tags:
            self.assertIsInstance(tag, entities.RunTag)

    def test_delete_run(self):
        """delete_run marks the run deleted, but get_run can still fetch it."""
        run = self._run_factory()
        self.session.commit()

        run_uuid = run.run_uuid
        self.store.delete_run(run_uuid)
        actual = self.session.query(
            models.SqlRun).filter_by(run_uuid=run_uuid).first()
        self.assertEqual(actual.lifecycle_stage,
                         entities.RunInfo.DELETED_LIFECYCLE)

        deleted_run = self.store.get_run(run_uuid)
        self.assertEqual(actual.run_uuid, deleted_run.info.run_uuid)

    def test_log_metric(self):
        """Metrics logged through the store are persisted and returned by get_run."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        metric = entities.Metric(tkey, tval, int(time.time()))
        metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
        self.store.log_metric(run.run_uuid, metric)
        self.store.log_metric(run.run_uuid, metric2)

        # .first() actually runs the query; an unexecuted Query object is never None, so
        # asserting on it would check nothing.
        actual = self.session.query(models.SqlMetric).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        self.assertEqual(4, len(run.data.metrics))
        self.assertTrue(any(m.key == tkey and m.value == tval
                            for m in run.data.metrics))

    def test_log_metric_uniqueness(self):
        """Logging a second value for the same key and timestamp must raise MlflowException."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        metric = entities.Metric(tkey, tval, int(time.time()))
        metric2 = entities.Metric(tkey, 1.02, int(time.time()))
        self.store.log_metric(run.run_uuid, metric)

        with self.assertRaises(MlflowException):
            self.store.log_metric(run.run_uuid, metric2)

    def test_log_param(self):
        """Params logged through the store are persisted and returned by get_run."""
        run = self._run_factory('test')

        self.session.commit()

        tkey = 'blahmetric'
        tval = '100.0'
        param = entities.Param(tkey, tval)
        param2 = entities.Param('new param', 'new key')
        self.store.log_param(run.run_uuid, param)
        self.store.log_param(run.run_uuid, param2)

        # Execute the query; an unexecuted Query object is never None.
        actual = self.session.query(models.SqlParam).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)
        self.assertEqual(4, len(run.data.params))

        self.assertTrue(any(p.key == tkey and p.value == tval
                            for p in run.data.params))

    def test_log_param_uniqueness(self):
        """Logging a different value for an existing param key must raise MlflowException."""
        run = self._run_factory('test')

        self.session.commit()

        tkey = 'blahmetric'
        tval = '100.0'
        param = entities.Param(tkey, tval)
        param2 = entities.Param(tkey, 'newval')
        self.store.log_param(run.run_uuid, param)

        with self.assertRaises(MlflowException):
            self.store.log_param(run.run_uuid, param2)

    def test_set_tag(self):
        """Tags set through the store are persisted and returned by get_run."""
        run = self._run_factory('test')

        self.session.commit()

        tkey = 'test tag'
        tval = 'a boogie'
        tag = entities.RunTag(tkey, tval)
        self.store.set_tag(run.run_uuid, tag)

        # Execute the query; an unexecuted Query object is never None.
        actual = self.session.query(models.SqlTag).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        self.assertTrue(any(t.key == tkey and t.value == tval
                            for t in run.data.tags))

    def test_get_metric(self):
        """get_metric returns the stored value for each of the run's metric keys."""
        run = self._run_factory('test')
        self.session.commit()

        for expected in run.metrics:
            actual = self.store.get_metric(run.run_uuid, expected.key)
            self.assertEqual(expected.value, actual)

    def test_get_param(self):
        """get_param returns the stored value for each of the run's param keys."""
        run = self._run_factory('test')
        self.session.commit()

        for expected in run.params:
            actual = self.store.get_param(run.run_uuid, expected.key)
            self.assertEqual(expected.value, actual)

    def test_get_metric_history(self):
        """get_metric_history returns one entry per value logged for the key."""
        run = self._run_factory('test')
        self.session.commit()
        key = 'test'
        expected = [
            models.SqlMetric(key=key, value=0.6,
                             timestamp=1).to_mlflow_entity(),
            models.SqlMetric(key=key, value=0.7,
                             timestamp=2).to_mlflow_entity()
        ]

        for metric in expected:
            self.store.log_metric(run.run_uuid, metric)

        actual = self.store.get_metric_history(run.run_uuid, key)

        self.assertEqual(len(expected), len(actual))

    def test_list_run_infos(self):
        """list_run_infos returns one RunInfo per run in the experiment."""
        exp1 = self._experiment_factory('test_exp')
        runs = [
            self._run_factory('t1', exp1.experiment_id).to_mlflow_entity(),
            self._run_factory('t2', exp1.experiment_id).to_mlflow_entity(),
        ]

        expected = [run.info for run in runs]

        actual = self.store.list_run_infos(exp1.experiment_id)

        self.assertEqual(len(expected), len(actual))

    def test_rename_experiment(self):
        """rename_experiment changes the name returned by get_experiment."""
        new_name = 'new name'
        experiment = self._experiment_factory('test name')
        self.store.rename_experiment(experiment.experiment_id, new_name)

        renamed_experiment = self.store.get_experiment(
            experiment.experiment_id)

        self.assertEqual(renamed_experiment.name, new_name)

    def test_update_run_info(self):
        """update_run_info sets the run's status string and end time."""
        run = self._run_factory()
        new_status = entities.RunStatus.FINISHED
        endtime = int(time.time())

        actual = self.store.update_run_info(run.run_uuid, new_status, endtime)

        self.assertEqual(actual.status,
                         entities.RunStatus.to_string(new_status))
        self.assertEqual(actual.end_time, endtime)

    def test_restore_experiment(self):
        """An experiment goes active -> deleted -> active via delete/restore_experiment."""
        exp = self._experiment_factory('helloexp')
        self.assertEqual(exp.lifecycle_stage,
                         entities.Experiment.ACTIVE_LIFECYCLE)

        experiment_id = exp.experiment_id
        self.store.delete_experiment(experiment_id)

        deleted = self.store.get_experiment(experiment_id)
        self.assertEqual(deleted.experiment_id, experiment_id)
        self.assertEqual(deleted.lifecycle_stage,
                         entities.Experiment.DELETED_LIFECYCLE)

        self.store.restore_experiment(exp.experiment_id)
        restored = self.store.get_experiment(exp.experiment_id)
        self.assertEqual(restored.experiment_id, experiment_id)
        self.assertEqual(restored.lifecycle_stage,
                         entities.Experiment.ACTIVE_LIFECYCLE)

    def test_restore_run(self):
        """A run goes active -> deleted -> active via delete_run/restore_run."""
        run = self._run_factory()
        self.assertEqual(run.lifecycle_stage,
                         entities.RunInfo.ACTIVE_LIFECYCLE)

        run_uuid = run.run_uuid
        self.store.delete_run(run_uuid)

        deleted = self.store.get_run(run_uuid)
        self.assertEqual(deleted.info.run_uuid, run_uuid)
        self.assertEqual(deleted.info.lifecycle_stage,
                         entities.RunInfo.DELETED_LIFECYCLE)

        self.store.restore_run(run_uuid)
        restored = self.store.get_run(run_uuid)
        self.assertEqual(restored.info.run_uuid, run_uuid)
        self.assertEqual(restored.info.lifecycle_stage,
                         entities.RunInfo.ACTIVE_LIFECYCLE)
예제 #12
0
class TestSqlAlchemyStoreSqliteInMemory(unittest.TestCase):
    def setUp(self):
        """Create a store, rebind it to a test-owned engine/session, then create all tables."""
        # Show complete diffs when an assertion fails.
        self.maxDiff = None
        self.store = SqlAlchemyStore(DB_URI)
        engine = sqlalchemy.create_engine(DB_URI)
        self.engine = engine
        self.session = sqlalchemy.orm.sessionmaker(bind=engine)()
        # Swap the store's session/engine for the test's own, so direct queries made by
        # the tests see exactly what the store writes.
        self.store.session, self.store.engine = self.session, self.engine
        models.Base.metadata.create_all(engine)

    def tearDown(self):
        """Drop every model table so the next test starts from an empty schema."""
        metadata = models.Base.metadata
        metadata.drop_all(self.engine)

    def _experiment_factory(self, names):
        """Create experiments through the store.

        :param names: a single experiment name, or a list of names.
        :return: what ``store.create_experiment`` returns for a single name, or a list of those
                 results (in input order) for a list of names.
        """
        if type(names) is not list:
            return self.store.create_experiment(name=names)
        return [self.store.create_experiment(name=each) for each in names]

    def test_raise_duplicate_experiments(self):
        """Creating two experiments with the same name must raise."""
        duplicate_names = ['test', 'test']
        with self.assertRaises(Exception):
            self._experiment_factory(duplicate_names)

    def test_raise_experiment_dont_exist(self):
        """Fetching an experiment id that does not exist must raise."""
        missing_id = 100
        with self.assertRaises(Exception):
            self.store.get_experiment(experiment_id=missing_id)

    def test_delete_experiment(self):
        """A deleted experiment drops out of list_experiments and its row is marked DELETED."""
        created = self._experiment_factory(['morty', 'rick', 'rick and morty'])
        victim_id = created[0].experiment_id
        self.store.delete_experiment(victim_id)

        row = self.session.query(models.SqlExperiment).get(victim_id)
        remaining = self.store.list_experiments()
        self.assertEqual(len(remaining), len(created) - 1)
        self.assertEqual(row.lifecycle_stage, entities.LifecycleStage.DELETED)

    def test_get_experiment(self):
        """get_experiment returns the experiment created under that id."""
        created = self._experiment_factory('goku')
        fetched = self.store.get_experiment(created.experiment_id)
        for attr in ('name', 'experiment_id'):
            self.assertEqual(getattr(fetched, attr), getattr(created, attr))

    def test_list_experiments(self):
        """list_experiments returns as many experiments as were created, each matching its row."""
        created = self._experiment_factory(['blue', 'red', 'green'])
        listed = self.store.list_experiments()

        self.assertEqual(len(created), len(listed))

        base_query = self.session.query(models.SqlExperiment)
        for exp in created:
            row = base_query.filter_by(experiment_id=exp.experiment_id).first()
            self.assertEqual(row.name, exp.name)
            self.assertEqual(row.experiment_id, exp.experiment_id)

    def test_create_experiments(self):
        """create_experiment inserts exactly one row matching the returned experiment."""
        def all_rows():
            return self.session.query(models.SqlExperiment).all()

        self.assertEqual(len(all_rows()), 0)

        created = self.store.create_experiment(name='test experiment')
        rows = all_rows()
        self.assertEqual(len(rows), 1)

        stored = rows[0]
        self.assertEqual(stored.experiment_id, created.experiment_id)
        self.assertEqual(stored.name, created.name)

    def test_run_tag_model(self):
        """A committed SqlTag converts to an entity with the same key and value."""
        tag_row = models.SqlTag(run_uuid='tuuid', key='test', value='val')
        self.session.add(tag_row)
        self.session.commit()

        stored = self.session.query(models.SqlTag).all()
        self.assertEqual(len(stored), 1)

        entity = stored[0].to_mlflow_entity()
        self.assertEqual(entity.value, tag_row.value)
        self.assertEqual(entity.key, tag_row.key)

    def test_metric_model(self):
        """A committed SqlMetric converts to an entity with the same key and value."""
        metric_row = models.SqlMetric(run_uuid='testuid', key='accuracy', value=0.89)
        self.session.add(metric_row)
        self.session.commit()

        stored = self.session.query(models.SqlMetric).all()
        self.assertEqual(len(stored), 1)

        entity = stored[0].to_mlflow_entity()
        self.assertEqual(entity.value, metric_row.value)
        self.assertEqual(entity.key, metric_row.key)

    def test_param_model(self):
        """A committed SqlParam converts to an entity with the same key and value."""
        param_row = models.SqlParam(run_uuid='test', key='accuracy', value='test param')
        self.session.add(param_row)
        self.session.commit()

        stored = self.session.query(models.SqlParam).all()
        self.assertEqual(len(stored), 1)

        entity = stored[0].to_mlflow_entity()
        self.assertEqual(entity.value, param_row.value)
        self.assertEqual(entity.key, param_row.key)

    def test_run_needs_uuid(self):
        """Committing a SqlRun without a run_uuid must raise IntegrityError."""
        run = models.SqlRun()
        self.session.add(run)

        # Silence warnings only while committing. catch_warnings() restores the previous
        # filter state on exit even though the commit raises. Calling simplefilter() outside
        # it (with resetwarnings() after the raising line, which never runs) would leave a
        # global "ignore" filter active for every later test.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.assertRaises(sqlalchemy.exc.IntegrityError):
                self.session.commit()

    def test_run_data_model(self):
        """Params and metrics attached to a SqlRun are persisted with it."""
        metrics = [models.SqlMetric(key='accuracy', value=0.89),
                   models.SqlMetric(key='recal', value=0.89)]
        params = [models.SqlParam(key='loss', value='test param'),
                  models.SqlParam(key='blue', value='test param')]

        self.session.add_all(metrics + params)

        run_row = models.SqlRun(run_uuid=uuid.uuid4().hex)
        for param in params:
            run_row.params.append(param)
        for metric in metrics:
            run_row.metrics.append(metric)

        self.session.add(run_row)
        self.session.commit()

        runs = self.session.query(models.SqlRun).all()
        first = runs[0]
        self.assertEqual(len(runs), 1)
        self.assertEqual(len(first.params), 2)
        self.assertEqual(len(first.metrics), 2)

    def test_run_info(self):
        """Every field passed to SqlRun appears unchanged on the converted entity's RunInfo."""
        experiment = self._experiment_factory('test exp')
        config = dict(
            experiment_id=experiment.experiment_id,
            name='test run',
            user_id='Anderson',
            run_uuid='test',
            # NOTE(review): status is set to a lifecycle constant, not a RunStatus value;
            # looks unintentional -- confirm.
            status=entities.LifecycleStage.ACTIVE,
            source_type=entities.SourceType.LOCAL,
            source_name='Python application',
            entry_point_name='main.py',
            start_time=int(time.time()),
            end_time=int(time.time()),
            source_version=mlflow.__version__,
            lifecycle_stage=entities.LifecycleStage.ACTIVE,
            artifact_uri='//',
        )
        info = models.SqlRun(**config).to_mlflow_entity().info

        for field_name, expected_value in config.items():
            self.assertEqual(expected_value, getattr(info, field_name))

    def _run_factory(self, name='test', experiment_id=None, config=None):
        """Build a SqlRun with two metrics and two params and add it to ``self.session``.

        The run is added but not committed; callers commit when they need it persisted.

        :param name: run name.
        :param experiment_id: experiment to attach the run to. If falsy, a new experiment
                              named 'test exp' is created.
        :param config: currently ignored; it is overwritten by the column values below.
        :return: the (uncommitted) ``models.SqlRun``.
        """
        metrics = [models.SqlMetric(key='accuracy', value=0.89),
                   models.SqlMetric(key='recal', value=0.89)]
        params = [models.SqlParam(key='loss', value='test param'),
                  models.SqlParam(key='blue', value='test param')]

        if not experiment_id:
            experiment_id = self._experiment_factory('test exp').experiment_id

        config = dict(
            experiment_id=experiment_id,
            name=name,
            user_id='Anderson',
            run_uuid=uuid.uuid4().hex,
            status=entities.RunStatus.to_string(entities.RunStatus.SCHEDULED),
            source_type=entities.SourceType.to_string(entities.SourceType.NOTEBOOK),
            source_name='Python application',
            entry_point_name='main.py',
            start_time=int(time.time()),
            end_time=int(time.time()),
            source_version=mlflow.__version__,
            lifecycle_stage=entities.LifecycleStage.ACTIVE,
            artifact_uri='//',
        )
        run = models.SqlRun(**config)

        for param in params:
            run.params.append(param)
        for metric in metrics:
            run.metrics.append(metric)
        self.session.add(run)
        return run

    def test_create_run(self):
        """create_run echoes its arguments and appends an ``mlflow.runName`` tag."""
        template = self._run_factory()
        run_name = 'booyya'
        template.tags.append(models.SqlTag(key='3', value='4'))
        template.tags.append(models.SqlTag(key='1', value='2'))
        tag_entities = [sql_tag.to_mlflow_entity() for sql_tag in template.tags]

        created = self.store.create_run(
            template.experiment_id,
            template.user_id,
            run_name,
            entities.SourceType.from_string(template.source_type),
            template.source_name,
            template.entry_point_name,
            template.start_time,
            template.source_version,
            tag_entities,
            None,  # no parent run id
        )
        info = created.info

        self.assertEqual(info.experiment_id, template.experiment_id)
        self.assertEqual(info.user_id, template.user_id)
        self.assertEqual(info.name, run_name)
        self.assertEqual(info.source_type, template.source_type)
        self.assertEqual(info.source_name, template.source_name)
        self.assertEqual(info.source_version, template.source_version)
        self.assertEqual(info.entry_point_name, template.entry_point_name)
        self.assertEqual(info.start_time, template.start_time)
        self.assertEqual(len(created.data.tags), 3)

        run_name_tag = models.SqlTag(key='mlflow.runName',
                                     value=run_name).to_mlflow_entity()
        self.assertListEqual(created.data.tags, tag_entities + [run_name_tag])

    def test_create_run_with_parent_id(self):
        """A parent run id passed to create_run is stored as ``mlflow.parentRunId``."""
        template = self._run_factory()
        run_name = 'booyya'
        parent_uuid = "parent_uuid_5"
        template.tags.append(models.SqlTag(key='3', value='4'))
        template.tags.append(models.SqlTag(key='1', value='2'))
        tag_entities = [sql_tag.to_mlflow_entity() for sql_tag in template.tags]

        created = self.store.create_run(
            template.experiment_id,
            template.user_id,
            run_name,
            entities.SourceType.from_string(template.source_type),
            template.source_name,
            template.entry_point_name,
            template.start_time,
            template.source_version,
            tag_entities,
            parent_uuid,
        )
        info = created.info

        self.assertEqual(info.experiment_id, template.experiment_id)
        self.assertEqual(info.user_id, template.user_id)
        self.assertEqual(info.name, run_name)
        self.assertEqual(info.source_type, template.source_type)
        self.assertEqual(info.source_name, template.source_name)
        self.assertEqual(info.source_version, template.source_version)
        self.assertEqual(info.entry_point_name, template.entry_point_name)
        self.assertEqual(info.start_time, template.start_time)
        self.assertEqual(len(created.data.tags), 4)

        run_name_tag = models.SqlTag(key='mlflow.runName',
                                     value=run_name).to_mlflow_entity()
        parent_tag = models.SqlTag(key='mlflow.parentRunId',
                                   value=parent_uuid).to_mlflow_entity()
        expected_tags = tag_entities + [parent_tag, run_name_tag]
        self.assertListEqual(created.data.tags, expected_tags)

    def test_to_mlflow_entity(self):
        """to_mlflow_entity converts the run and all of its children to entities."""
        entity = self._run_factory().to_mlflow_entity()

        self.assertIsInstance(entity.info, entities.RunInfo)
        self.assertIsInstance(entity.data, entities.RunData)

        expectations = (
            (entity.data.metrics, entities.Metric),
            (entity.data.params, entities.Param),
            (entity.data.tags, entities.RunTag),
        )
        for children, expected_type in expectations:
            for child in children:
                self.assertIsInstance(child, expected_type)

    def test_delete_run(self):
        """delete_run marks the row DELETED; get_run still returns it."""
        sql_run = self._run_factory()
        self.session.commit()

        target_uuid = sql_run.run_uuid
        self.store.delete_run(target_uuid)

        row = (self.session.query(models.SqlRun)
               .filter_by(run_uuid=target_uuid)
               .first())
        self.assertEqual(row.lifecycle_stage, entities.LifecycleStage.DELETED)

        fetched = self.store.get_run(target_uuid)
        self.assertEqual(row.run_uuid, fetched.info.run_uuid)

    def test_log_metric(self):
        """log_metric keeps every value; RunData exposes only the latest per key."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        metric = entities.Metric(tkey, tval, int(time.time()))
        metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
        self.store.log_metric(run.run_uuid, metric)
        self.store.log_metric(run.run_uuid, metric2)

        # Materialize the row with .first(): an unevaluated Query object is
        # never None, so asserting on the query itself could never fail.
        actual = self.session.query(models.SqlMetric).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        # SQL store _get_run returns the full history of recorded metrics,
        # duplicates included: 2 factory metrics + 2 logged above. RunData only
        # carries the last reported value per key: accuracy, recal, blahmetric.
        sql_run_metrics = self.store._get_run(run.info.run_uuid,
                                              ViewType.ALL).metrics
        self.assertEqual(4, len(sql_run_metrics))
        self.assertEqual(3, len(run.data.metrics))

        self.assertTrue(any(m.key == tkey and m.value == tval
                            for m in run.data.metrics))

    def test_log_metric_uniqueness(self):
        """Re-logging a key with another value at (normally) the same second raises."""
        sql_run = self._run_factory()
        self.session.commit()

        metric_key = 'blahmetric'
        first_metric = entities.Metric(metric_key, 100.0, int(time.time()))
        conflicting_metric = entities.Metric(metric_key, 1.02, int(time.time()))
        self.store.log_metric(sql_run.run_uuid, first_metric)

        with self.assertRaises(MlflowException):
            self.store.log_metric(sql_run.run_uuid, conflicting_metric)

    def test_log_param(self):
        """log_param stores new params alongside the factory-created ones."""
        run = self._run_factory('test')

        self.session.commit()

        tkey = 'blahmetric'
        tval = '100.0'
        param = entities.Param(tkey, tval)
        param2 = entities.Param('new param', 'new key')
        self.store.log_param(run.run_uuid, param)
        self.store.log_param(run.run_uuid, param2)

        # Materialize the row with .first(): an unevaluated Query object is
        # never None, so asserting on the query itself could never fail.
        actual = self.session.query(models.SqlParam).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)
        # Two params from _run_factory plus the two logged above.
        self.assertEqual(4, len(run.data.params))

        self.assertTrue(any(p.key == tkey and p.value == tval
                            for p in run.data.params))

    def test_log_param_uniqueness(self):
        """A param key cannot be logged again with a different value."""
        sql_run = self._run_factory('test')
        self.session.commit()

        param_key = 'blahmetric'
        original_param = entities.Param(param_key, '100.0')
        changed_param = entities.Param(param_key, 'newval')
        self.store.log_param(sql_run.run_uuid, original_param)

        with self.assertRaises(MlflowException):
            self.store.log_param(sql_run.run_uuid, changed_param)

    def test_set_tag(self):
        """set_tag persists the tag and get_run reports it."""
        run = self._run_factory('test')

        self.session.commit()

        tkey = 'test tag'
        tval = 'a boogie'
        tag = entities.RunTag(tkey, tval)
        self.store.set_tag(run.run_uuid, tag)

        # Materialize the row with .first(): an unevaluated Query object is
        # never None, so asserting on the query itself could never fail.
        actual = self.session.query(models.SqlTag).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        self.assertTrue(any(t.key == tkey and t.value == tval
                            for t in run.data.tags))

    def test_get_metric(self):
        """get_metric returns the stored value of each factory metric."""
        sql_run = self._run_factory('test')
        self.session.commit()

        for sql_metric in sql_run.metrics:
            fetched_value = self.store.get_metric(sql_run.run_uuid, sql_metric.key)
            self.assertEqual(sql_metric.value, fetched_value)

    def test_get_param(self):
        """get_param returns the stored value of each factory param."""
        sql_run = self._run_factory('test')
        self.session.commit()

        for sql_param in sql_run.params:
            fetched_value = self.store.get_param(sql_run.run_uuid, sql_param.key)
            self.assertEqual(sql_param.value, fetched_value)

    def test_get_metric_history(self):
        """get_metric_history returns the logged values of one key, in order."""
        sql_run = self._run_factory('test')
        self.session.commit()
        metric_key = 'test'
        logged = [
            models.SqlMetric(key=metric_key, value=value,
                             timestamp=step).to_mlflow_entity()
            for step, value in ((1, 0.6), (2, 0.7))
        ]

        for entity in logged:
            self.store.log_metric(sql_run.run_uuid, entity)

        history = self.store.get_metric_history(sql_run.run_uuid, metric_key)

        self.assertSequenceEqual([entity.value for entity in logged], history)

    def test_list_run_infos(self):
        """list_run_infos filters by ViewType before and after delete_run."""
        experiment_id = self._experiment_factory('test_exp').experiment_id
        first_uuid = self._run_factory('t1', experiment_id).run_uuid
        second_uuid = self._run_factory('t2', experiment_id).run_uuid

        def listed(view_type):
            infos = self.store.list_run_infos(experiment_id, view_type)
            return [info.run_uuid for info in infos]

        both = [first_uuid, second_uuid]
        self.assertSequenceEqual(both, listed(ViewType.ALL))
        self.assertSequenceEqual(both, listed(ViewType.ACTIVE_ONLY))
        self.assertEqual(0, len(listed(ViewType.DELETED_ONLY)))

        self.store.delete_run(first_uuid)
        self.assertSequenceEqual(both, listed(ViewType.ALL))
        self.assertSequenceEqual([second_uuid], listed(ViewType.ACTIVE_ONLY))
        self.assertSequenceEqual([first_uuid], listed(ViewType.DELETED_ONLY))

    def test_rename_experiment(self):
        """rename_experiment changes the name returned by get_experiment."""
        target_name = 'new name'
        experiment_id = self._experiment_factory('test name').experiment_id

        self.store.rename_experiment(experiment_id, target_name)

        fetched = self.store.get_experiment(experiment_id)
        self.assertEqual(fetched.name, target_name)

    def test_update_run_info(self):
        """update_run_info records the new status (as a string) and end time."""
        sql_run = self._run_factory()
        finished = entities.RunStatus.FINISHED
        end_time = int(time.time())

        updated = self.store.update_run_info(sql_run.run_uuid, finished, end_time)

        expected_status = entities.RunStatus.to_string(finished)
        self.assertEqual(updated.status, expected_status)
        self.assertEqual(updated.end_time, end_time)

    def test_restore_experiment(self):
        """An experiment goes ACTIVE -> DELETED -> ACTIVE via delete/restore."""
        experiment = self._experiment_factory('helloexp')
        self.assertEqual(experiment.lifecycle_stage, entities.LifecycleStage.ACTIVE)
        experiment_id = experiment.experiment_id

        self.store.delete_experiment(experiment_id)
        after_delete = self.store.get_experiment(experiment_id)
        self.assertEqual(after_delete.experiment_id, experiment_id)
        self.assertEqual(after_delete.lifecycle_stage,
                         entities.LifecycleStage.DELETED)

        self.store.restore_experiment(experiment_id)
        after_restore = self.store.get_experiment(experiment_id)
        self.assertEqual(after_restore.experiment_id, experiment_id)
        self.assertEqual(after_restore.lifecycle_stage,
                         entities.LifecycleStage.ACTIVE)

    def test_restore_run(self):
        """A run goes ACTIVE -> DELETED -> ACTIVE via delete_run/restore_run."""
        sql_run = self._run_factory()
        self.assertEqual(sql_run.lifecycle_stage, entities.LifecycleStage.ACTIVE)
        target_uuid = sql_run.run_uuid

        self.store.delete_run(target_uuid)
        after_delete = self.store.get_run(target_uuid).info
        self.assertEqual(after_delete.run_uuid, target_uuid)
        self.assertEqual(after_delete.lifecycle_stage,
                         entities.LifecycleStage.DELETED)

        self.store.restore_run(target_uuid)
        after_restore = self.store.get_run(target_uuid).info
        self.assertEqual(after_restore.run_uuid, target_uuid)
        self.assertEqual(after_restore.lifecycle_stage,
                         entities.LifecycleStage.ACTIVE)

    # Tests for Search API
    def _search(self,
                experiment_id,
                metrics_expressions=None,
                param_expressions=None,
                run_view_type=ViewType.ALL):
        """Search one experiment and return the uuids of the matching runs.

        Metric expressions come before parameter expressions in the condition
        list handed to ``search_runs``.
        """
        metric_part = metrics_expressions or []
        param_part = param_expressions or []
        found = self.store.search_runs([experiment_id], metric_part + param_part,
                                       run_view_type)
        return [match.info.run_uuid for match in found]

    def _param_expression(self, key, comparator, val):
        """Build a SearchExpression comparing parameter ``key`` to string ``val``."""
        expression = SearchExpression()
        clause = expression.parameter
        clause.key = key
        clause.string.comparator = comparator
        clause.string.value = val
        return expression

    def _metric_expression(self, key, comparator, val):
        """Build a SearchExpression comparing metric ``key`` to double ``val``."""
        expression = SearchExpression()
        clause = expression.metric
        clause.key = key
        clause.double.comparator = comparator
        clause.double.value = val
        return expression

    def test_search_vanilla(self):
        """Unfiltered search honours the view type through delete and restore."""
        experiment_id = self._experiment_factory('search_vanilla').experiment_id
        run_uuids = [self._run_factory('r_%d' % idx, experiment_id).run_uuid
                     for idx in range(3)]
        first_uuid = run_uuids[0]

        def check_views(expected_all, expected_active, expected_deleted):
            self.assertSequenceEqual(
                expected_all,
                self._search(experiment_id, run_view_type=ViewType.ALL))
            self.assertSequenceEqual(
                expected_active,
                self._search(experiment_id, run_view_type=ViewType.ACTIVE_ONLY))
            self.assertSequenceEqual(
                expected_deleted,
                self._search(experiment_id, run_view_type=ViewType.DELETED_ONLY))

        check_views(run_uuids, run_uuids, [])

        self.store.delete_run(first_uuid)
        check_views(run_uuids, run_uuids[1:], [first_uuid])

        self.store.restore_run(first_uuid)
        check_views(run_uuids, run_uuids, [])

    def test_search_params(self):
        """Parameter search with the = and != comparators."""
        experiment_id = self._experiment_factory('search_params').experiment_id
        r1 = self._run_factory('r1',
                               experiment_id).to_mlflow_entity().info.run_uuid
        r2 = self._run_factory('r2',
                               experiment_id).to_mlflow_entity().info.run_uuid

        logged_params = [
            (r1, 'generic_param', 'p_val'),
            (r2, 'generic_param', 'p_val'),
            (r1, 'generic_2', 'some value'),
            (r2, 'generic_2', 'another value'),
            (r1, 'p_a', 'abc'),
            (r2, 'p_b', 'ABC'),
        ]
        for run_uuid, key, value in logged_params:
            self.store.log_param(run_uuid, entities.Param(key, value))

        # (key, comparator, value, expected run uuids), checked in order.
        cases = [
            # same key and value on both runs
            ("generic_param", "=", "p_val", [r1, r2]),
            # same key, a different value per run
            ("generic_2", "=", "some value", [r1]),
            ("generic_2", "=", "another value", [r2]),
            ("generic_param", "=", "wrong_val", []),
            ("generic_param", "!=", "p_val", []),
            ("generic_param", "!=", "wrong_val", [r1, r2]),
            ("generic_2", "!=", "wrong_val", [r1, r2]),
            ("p_a", "=", "abc", [r1]),
            ("p_b", "=", "ABC", [r2]),
        ]
        for key, comparator, value, expected_runs in cases:
            expression = self._param_expression(key, comparator, value)
            self.assertSequenceEqual(
                expected_runs,
                self._search(experiment_id, param_expressions=[expression]))

    def test_search_metrics(self):
        """Metric search compares against each run's value at the latest timestamp.

        Metric expressions are passed through ``metrics_expressions``; they were
        previously routed through ``param_expressions``, which only worked
        because ``_search`` concatenates both lists.
        """
        experiment_id = self._experiment_factory('search_params').experiment_id
        r1 = self._run_factory('r1',
                               experiment_id).to_mlflow_entity().info.run_uuid
        r2 = self._run_factory('r2',
                               experiment_id).to_mlflow_entity().info.run_uuid

        # (run, key, value, timestamp), logged in this order.
        logged_metrics = [
            (r1, "common", 1.0, 1),
            (r2, "common", 1.0, 1),
            (r1, "measure_a", 1.0, 1),
            (r2, "measure_a", 200.0, 2),
            (r2, "measure_a", 400.0, 3),
            (r1, "m_a", 2.0, 2),
            (r2, "m_b", 3.0, 2),
            (r2, "m_b", 4.0, 8),  # latest timestamp for m_b
            (r2, "m_b", 8.0, 3),
        ]
        for run_uuid, key, value, timestamp in logged_metrics:
            self.store.log_metric(run_uuid,
                                  entities.Metric(key, value, timestamp))

        # (key, comparator, value, expected run uuids), checked in order.
        cases = [
            ("common", "=", 1.0, [r1, r2]),
            ("common", ">", 0.0, [r1, r2]),
            ("common", ">=", 0.0, [r1, r2]),
            ("common", "<", 4.0, [r1, r2]),
            ("common", "<=", 4.0, [r1, r2]),
            ("common", "!=", 1.0, []),
            ("common", ">=", 3.0, []),
            ("common", "<=", 0.75, []),
            # same metric name across runs with different values and timestamps
            ("measure_a", ">", 0.0, [r1, r2]),
            ("measure_a", "<", 50.0, [r1]),
            ("measure_a", "<", 1000.0, [r1, r2]),
            ("measure_a", "!=", -12.0, [r1, r2]),
            ("measure_a", ">", 50.0, [r2]),
            ("measure_a", "=", 1.0, [r1]),
            ("measure_a", "=", 400.0, [r2]),
            # metric keys logged by a single run
            ("m_a", ">", 1.0, [r1]),
            ("m_b", ">", 1.0, [r2]),
            # m_b has a recorded value above 5.0 (8.0), but not at its latest timestamp
            ("m_b", ">", 5.0, []),
            # m_b's value at its latest timestamp (8) is 4.0
            ("m_b", "=", 4.0, [r2]),
        ]
        for key, comparator, value, expected_runs in cases:
            expression = self._metric_expression(key, comparator, value)
            self.assertSequenceEqual(
                expected_runs,
                self._search(experiment_id, metrics_expressions=[expression]),
                msg="metric %s %s %r" % (key, comparator, value))

    def test_search_full(self):
        """Combined param and metric search: every expression must match."""
        experiment_id = self._experiment_factory('search_params').experiment_id
        r1 = self._run_factory('r1',
                               experiment_id).to_mlflow_entity().info.run_uuid
        r2 = self._run_factory('r2',
                               experiment_id).to_mlflow_entity().info.run_uuid

        for run_uuid, key, value in [(r1, 'generic_param', 'p_val'),
                                     (r2, 'generic_param', 'p_val'),
                                     (r1, 'p_a', 'abc'),
                                     (r2, 'p_b', 'ABC')]:
            self.store.log_param(run_uuid, entities.Param(key, value))

        for run_uuid, key, value, timestamp in [(r1, "common", 1.0, 1),
                                                (r2, "common", 1.0, 1),
                                                (r1, "m_a", 2.0, 2),
                                                (r2, "m_b", 3.0, 2),
                                                (r2, "m_b", 4.0, 8),
                                                (r2, "m_b", 8.0, 3)]:
            self.store.log_metric(run_uuid,
                                  entities.Metric(key, value, timestamp))

        # (param specs, metric specs, expected run uuids), checked in order.
        cases = [
            # one param and one metric satisfied by both runs
            ([("generic_param", "=", "p_val")],
             [("common", "=", 1.0)],
             [r1, r2]),
            # all params and metrics match, but only r1 has m_a
            ([("generic_param", "=", "p_val")],
             [("common", "=", 1.0), ("m_a", ">", 1.0)],
             [r1]),
            # mismatching param
            ([("random_bad_name", "=", "p_val")],
             [("common", "=", 1.0), ("m_a", ">", 1.0)],
             []),
            # mismatching metric
            ([("generic_param", "=", "p_val")],
             [("common", "=", 1.0), ("m_a", ">", 100.0)],
             []),
        ]
        for param_specs, metric_specs, expected_runs in cases:
            param_exprs = [self._param_expression(*spec) for spec in param_specs]
            metric_exprs = [self._metric_expression(*spec) for spec in metric_specs]
            self.assertSequenceEqual(
                expected_runs,
                self._search(experiment_id,
                             param_expressions=param_exprs,
                             metrics_expressions=metric_exprs))
예제 #13
0
def _get_sqlalchemy_store(store_uri, artifact_uri):
    """Return a ``SqlAlchemyStore`` built from ``store_uri`` and ``artifact_uri``.

    Both arguments are forwarded positionally to the constructor. The import
    is done at call time (presumably so SQLAlchemy is only loaded when this
    store type is requested).
    """
    from mlflow.store import sqlalchemy_store
    return sqlalchemy_store.SqlAlchemyStore(store_uri, artifact_uri)
예제 #14
0
class TestSqlAlchemyStoreSqliteInMemory(unittest.TestCase):

    def _setup_database(self, filename=''):
        """(Re)create ``self.store`` on ``DB_URI`` with *filename* appended.

        Passing a fixed file name lets a test open a second store on the same
        database file to check that data is retained between store instances.
        """
        db_uri = DB_URI + filename
        self.store = SqlAlchemyStore(db_uri, ARTIFACT_URI)

    def setUp(self):
        """Build a fresh store before every test."""
        # Show complete diffs when an assertion fails.
        self.maxDiff = None
        # Reset first so tearDown sees a defined attribute even if setup fails.
        self.store = None
        self._setup_database()

    def tearDown(self):
        """Drop all tables if a store is still attached, then remove the artifact root."""
        store = self.store
        if store:
            models.Base.metadata.drop_all(store.engine)
        shutil.rmtree(ARTIFACT_URI)

    def _experiment_factory(self, names):
        """Create one or several experiments and return their ids.

        :param names: a single experiment name, or a list/tuple of names.
        :return: the new experiment id for a single name, or a list of ids (in the
            same order as *names*) when a list/tuple is given.
        """
        # isinstance (rather than ``type(...) is list``) also accepts tuples and list
        # subclasses instead of passing them through as a single "name".
        if isinstance(names, (list, tuple)):
            return [self.store.create_experiment(name=name) for name in names]
        return self.store.create_experiment(name=names)

    def _verify_logged(self, run_uuid, metrics, params, tags):
        """Assert that run *run_uuid* holds exactly the given metrics and params.

        Metrics are compared against the full history of every metric key on the run,
        as (key, value, timestamp) triples. Params must match exactly. *tags* only
        needs to be a subset of the run's tags, because run creation adds its own
        tags (e.g. the run name).
        """
        run = self.store.get_run(run_uuid)
        # Flatten the per-key histories with a nested comprehension;
        # ``sum(list_of_lists, [])`` re-copies the accumulator on every step (quadratic).
        all_metrics = [logged
                       for m in run.data.metrics
                       for logged in self.store.get_metric_history(run_uuid, m.key)]
        assert len(all_metrics) == len(metrics)
        logged_metrics = {(m.key, m.value, m.timestamp) for m in all_metrics}
        assert logged_metrics == {(m.key, m.value, m.timestamp) for m in metrics}
        logged_tags = {(tag.key, tag.value) for tag in run.data.tags}
        assert {(tag.key, tag.value) for tag in tags} <= logged_tags
        assert len(run.data.params) == len(params)
        logged_params = {(param.key, param.value) for param in run.data.params}
        assert logged_params == {(param.key, param.value) for param in params}

    def test_default_experiment(self):
        """A fresh store lists exactly one experiment: 'Default', with id 0."""
        experiments = self.store.list_experiments()
        self.assertEqual(len(experiments), 1)

        (default,) = experiments
        self.assertEqual(default.experiment_id, 0)
        self.assertEqual(default.name, "Default")

    def test_default_experiment_lifecycle(self):
        """Deleting the Default experiment (id 0) persists across store re-creation.

        The store is pointed at a ``.db`` file inside a temporary directory. A second
        store instance then reopens the same file and checks that the Default
        experiment is not re-activated.
        """
        with TempDir(chdr=True) as tmp:
            # Time-stamped file name, appended to DB_URI by _setup_database.
            tmp_file_name = "sqlite_file_to_lifecycle_test_{}.db".format(int(time.time()))
            self._setup_database("/" + tmp.path(tmp_file_name))

            # A new store starts with an ACTIVE Default experiment.
            default_experiment = self.store.get_experiment(experiment_id=0)
            self.assertEqual(default_experiment.name, Experiment.DEFAULT_EXPERIMENT_NAME)
            self.assertEqual(default_experiment.lifecycle_stage, entities.LifecycleStage.ACTIVE)

            self._experiment_factory('aNothEr')
            all_experiments = [e.name for e in self.store.list_experiments()]
            six.assertCountEqual(self, set(['aNothEr', 'Default']), set(all_experiments))

            self.store.delete_experiment(0)

            # The deleted Default experiment drops out of the default listing;
            # 'aNothEr' is experiment id 1.
            six.assertCountEqual(self, ['aNothEr'], [e.name for e in self.store.list_experiments()])
            another = self.store.get_experiment(1)
            self.assertEqual('aNothEr', another.name)

            # A deleted experiment can still be fetched by id; its stage is DELETED.
            default_experiment = self.store.get_experiment(experiment_id=0)
            self.assertEqual(default_experiment.name, Experiment.DEFAULT_EXPERIMENT_NAME)
            self.assertEqual(default_experiment.lifecycle_stage, entities.LifecycleStage.DELETED)

            # destroy SqlStore and make a new one
            del self.store
            self._setup_database("/" + tmp.path(tmp_file_name))

            # test that default experiment is not reactivated
            default_experiment = self.store.get_experiment(experiment_id=0)
            self.assertEqual(default_experiment.name, Experiment.DEFAULT_EXPERIMENT_NAME)
            self.assertEqual(default_experiment.lifecycle_stage, entities.LifecycleStage.DELETED)

            # Only ViewType.ALL still includes the deleted Default experiment.
            six.assertCountEqual(self, ['aNothEr'], [e.name for e in self.store.list_experiments()])
            all_experiments = [e.name for e in self.store.list_experiments(ViewType.ALL)]
            six.assertCountEqual(self, set(['aNothEr', 'Default']), set(all_experiments))

            # ensure that experiment ID for active experiment is unchanged
            another = self.store.get_experiment(1)
            self.assertEqual('aNothEr', another.name)

            # Detach the store so tearDown skips drop_all on this temporary database.
            self.store = None

    def test_raise_duplicate_experiments(self):
        """Creating two experiments with the same name must raise."""
        duplicated_names = ['test', 'test']
        with self.assertRaises(Exception):
            self._experiment_factory(duplicated_names)

    def test_raise_experiment_dont_exist(self):
        """Fetching an experiment id that was never created must raise."""
        missing_id = 100
        with self.assertRaises(Exception):
            self.store.get_experiment(experiment_id=missing_id)

    def test_delete_experiment(self):
        """Deleting an experiment marks it DELETED and removes it from the default listing."""
        created_ids = self._experiment_factory(['morty', 'rick', 'rick and morty'])

        listed_before = self.store.list_experiments()
        # +1 for the 'Default' experiment every store starts with.
        self.assertEqual(len(listed_before), len(created_ids) + 1)

        target_id = created_ids[0]
        self.store.delete_experiment(target_id)

        self.assertEqual(self.store.get_experiment(target_id).lifecycle_stage,
                         entities.LifecycleStage.DELETED)

        self.assertEqual(len(self.store.list_experiments()), len(listed_before) - 1)

    def test_get_experiment(self):
        """An experiment can be looked up both by id and by name."""
        name = 'goku'
        experiment_id = self._experiment_factory(name)

        # Look up by id first, then by name; both must describe the same experiment.
        lookups = (lambda: self.store.get_experiment(experiment_id),
                   lambda: self.store.get_experiment_by_name(name))
        for lookup in lookups:
            fetched = lookup()
            self.assertEqual(fetched.name, name)
            self.assertEqual(fetched.experiment_id, experiment_id)

    def test_list_experiments(self):
        """list_experiments returns the created experiments plus 'Default'."""
        names = ['blue', 'red', 'green']

        created_ids = self._experiment_factory(names)
        listed = self.store.list_experiments()

        # +1 accounts for the 'Default' experiment.
        self.assertEqual(len(created_ids) + 1, len(listed))

        with self.store.ManagedSessionMaker() as session:
            base_query = session.query(models.SqlExperiment)
            for exp_id in created_ids:
                row = base_query.filter_by(experiment_id=exp_id).first()
                self.assertIn(row.name, names)
                self.assertEqual(row.experiment_id, exp_id)

    def test_create_experiments(self):
        """create_experiment inserts exactly one row that round-trips via get_experiment."""
        def _count_rows(session):
            return len(session.query(models.SqlExperiment).all())

        with self.store.ManagedSessionMaker() as session:
            self.assertEqual(_count_rows(session), 1)

        exp_name = 'test exp'
        experiment_id = self.store.create_experiment(name=exp_name)

        with self.store.ManagedSessionMaker() as session:
            self.assertEqual(_count_rows(session), 2)

            row = session.query(models.SqlExperiment).filter_by(name=exp_name).first()
            self.assertEqual(row.experiment_id, experiment_id)
            self.assertEqual(row.name, exp_name)

        fetched = self.store.get_experiment(experiment_id)
        self.assertEqual(fetched.experiment_id, experiment_id)
        self.assertEqual(fetched.name, exp_name)

    def test_run_tag_model(self):
        """A SqlTag row persists and converts back to an entity with the same value."""
        # `run_uuid` is a foreign key in the tags table, so a real run must exist
        # in the runs table before a tag referencing it can be inserted.
        run_uuid = self._run_factory().info.run_uuid
        with self.store.ManagedSessionMaker() as session:
            tag_row = models.SqlTag(run_uuid=run_uuid, key='test', value='val')
            session.add(tag_row)
            session.commit()
            matching = [t for t in session.query(models.SqlTag).all() if t.key == tag_row.key]
            self.assertEqual(len(matching), 1)
            entity = matching[0].to_mlflow_entity()
            self.assertEqual(entity.value, tag_row.value)

    def test_metric_model(self):
        """A SqlMetric row persists and converts back to an entity with the same key/value."""
        # Create a run whose UUID we can reference when creating metric models.
        # `run_uuid` is a foreign key in the metrics table; therefore, in order
        # to insert a metric with a given run UUID, the UUID must be present in
        # the runs table
        run = self._run_factory()
        with self.store.ManagedSessionMaker() as session:
            new_metric = models.SqlMetric(run_uuid=run.info.run_uuid, key='accuracy', value=0.89)
            session.add(new_metric)
            session.commit()
            metrics = session.query(models.SqlMetric).all()
            self.assertEqual(len(metrics), 1)

            added_metric = metrics[0].to_mlflow_entity()
            self.assertEqual(added_metric.value, new_metric.value)
            self.assertEqual(added_metric.key, new_metric.key)

    def test_param_model(self):
        """A SqlParam row persists and converts back to an entity with the same key/value."""
        # Create a run whose UUID we can reference when creating parameter models.
        # `run_uuid` is a foreign key in the params table; therefore, in order
        # to insert a parameter with a given run UUID, the UUID must be present in
        # the runs table
        run = self._run_factory()
        with self.store.ManagedSessionMaker() as session:
            new_param = models.SqlParam(
                run_uuid=run.info.run_uuid, key='accuracy', value='test param')
            session.add(new_param)
            session.commit()
            params = session.query(models.SqlParam).all()
            self.assertEqual(len(params), 1)

            added_param = params[0].to_mlflow_entity()
            self.assertEqual(added_param.value, new_param.value)
            self.assertEqual(added_param.key, new_param.key)

    def test_run_needs_uuid(self):
        """Committing a SqlRun without a run_uuid raises MlflowException(INTERNAL_ERROR)."""
        # Depending on the backend, a NULL identity key fails with different SQLAlchemy
        # errors (e.g. IntegrityError on SQLite, FlushError on MySQL). The session
        # manager surfaces them as MlflowException, so that is what we check for.
        with self.assertRaises(MlflowException) as exception_context:
            # Install the "ignore" filter only inside catch_warnings so it is restored on
            # exit and cannot leak into other tests. It wraps the whole session context,
            # because the failing flush presumably happens when that context exits.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                with self.store.ManagedSessionMaker() as session:
                    session.add(models.SqlRun())
        assert exception_context.exception.error_code == ErrorCode.Name(INTERNAL_ERROR)

    def test_run_data_model(self):
        """A SqlRun owns the params and metrics attached through its collections."""
        with self.store.ManagedSessionMaker() as session:
            metric_rows = [models.SqlMetric(key='accuracy', value=0.89),
                           models.SqlMetric(key='recal', value=0.89)]
            param_rows = [models.SqlParam(key='loss', value='test param'),
                          models.SqlParam(key='blue', value='test param')]

            session.add_all(metric_rows + param_rows)

            run_row = models.SqlRun(run_uuid=uuid.uuid4().hex)
            run_row.params.extend(param_rows)
            run_row.metrics.extend(metric_rows)

            session.add(run_row)
            session.commit()

            stored_runs = session.query(models.SqlRun).all()
            stored = stored_runs[0]
            self.assertEqual(len(stored_runs), 1)
            self.assertEqual(len(stored.params), 2)
            self.assertEqual(len(stored.metrics), 2)

    def test_run_info(self):
        """SqlRun.to_mlflow_entity copies every configured column into RunInfo."""
        experiment_id = self._experiment_factory('test exp')
        config = {
            'experiment_id': experiment_id,
            'name': 'test run',
            'user_id': 'Anderson',
            'run_uuid': 'test',
            'status': RunStatus.to_string(RunStatus.SCHEDULED),
            'source_type': SourceType.to_string(SourceType.LOCAL),
            'source_name': 'Python application',
            'entry_point_name': 'main.py',
            'start_time': int(time.time()),
            'end_time': int(time.time()),
            'source_version': mlflow.__version__,
            'lifecycle_stage': entities.LifecycleStage.ACTIVE,
            'artifact_uri': '//'
        }
        info = models.SqlRun(**config).to_mlflow_entity().info

        # config stores status/source_type as strings; convert the RunInfo values
        # with the matching to_string before comparing.
        converters = {
            'source_type': SourceType.to_string,
            'status': RunStatus.to_string,
        }
        for field, expected in config.items():
            actual = getattr(info, field)
            if field in converters:
                actual = converters[field](actual)
            self.assertEqual(expected, actual)

    def _get_run_configs(self, name='test', experiment_id=None, tags=(), parent_run_id=None):
        """Return keyword arguments for ``store.create_run`` describing a notebook run."""
        return dict(
            experiment_id=experiment_id,
            run_name=name,
            user_id='Anderson',
            source_type=SourceType.NOTEBOOK,
            source_name='Python application',
            entry_point_name='main.py',
            start_time=int(time.time()),
            source_version=mlflow.__version__,
            tags=tags,
            parent_run_id=parent_run_id,
        )

    def _run_factory(self, config=None):
        """Create a run through ``store.create_run`` and return it.

        :param config: keyword arguments for ``create_run`` (see ``_get_run_configs``).
            Defaults to ``_get_run_configs()``. If its ``experiment_id`` is missing or
            None, a new experiment named 'test exp' is created and its id is written
            back into *config*.
        """
        if not config:
            config = self._get_run_configs()

        # Compare against None: id 0 (the Default experiment) is a valid id but falsy,
        # and must not trigger creation of a new experiment.
        if config.get("experiment_id", None) is None:
            config["experiment_id"] = self._experiment_factory('test exp')

        return self.store.create_run(**config)

    def test_create_run_with_tags(self):
        """create_run stores run info and user tags, and appends a run-name tag."""
        run_name = "test-run-1"
        experiment_id = self._experiment_factory('test_create_run')
        tags = [RunTag('3', '4'), RunTag('1', '2')]
        expected = self._get_run_configs(name=run_name, experiment_id=experiment_id, tags=tags)

        actual = self.store.create_run(**expected)
        info = actual.info

        self.assertEqual(info.experiment_id, experiment_id)
        self.assertEqual(info.user_id, expected["user_id"])
        self.assertEqual(info.name, run_name)
        for field in ("source_type", "source_name", "source_version",
                      "entry_point_name", "start_time"):
            self.assertEqual(getattr(info, field), expected[field])

        # The user tags come first, followed by the run-name tag added by create_run.
        self.assertEqual(len(actual.data.tags), len(tags) + 1)
        run_name_tag = models.SqlTag(key=MLFLOW_RUN_NAME, value=run_name).to_mlflow_entity()
        self.assertListEqual(actual.data.tags, tags + [run_name_tag])

    def test_create_run_with_parent_id(self):
        """create_run stores run info and adds parent-run-id and run-name tags."""
        run_name = "test-run-1"
        parent_run_id = "parent_uuid_5"
        experiment_id = self._experiment_factory('test_create_run')
        expected = self._get_run_configs(
            name=run_name, experiment_id=experiment_id, parent_run_id=parent_run_id)

        actual = self.store.create_run(**expected)
        info = actual.info

        self.assertEqual(info.experiment_id, experiment_id)
        self.assertEqual(info.user_id, expected["user_id"])
        self.assertEqual(info.name, run_name)
        for field in ("source_type", "source_name", "source_version",
                      "entry_point_name", "start_time"):
            self.assertEqual(getattr(info, field), expected[field])

        def _as_entity(key, value):
            return models.SqlTag(key=key, value=value).to_mlflow_entity()

        # No user tags were passed, so the only tags are the two create_run adds:
        # parent run id first, then run name.
        self.assertEqual(len(actual.data.tags), 2)
        self.assertListEqual(actual.data.tags,
                             [_as_entity(MLFLOW_PARENT_RUN_ID, parent_run_id),
                              _as_entity(MLFLOW_RUN_NAME, run_name)])

    def test_to_mlflow_entity(self):
        """A created run exposes RunInfo/RunData and entity-typed metrics, params and tags."""
        run = self._run_factory()

        self.assertIsInstance(run.info, entities.RunInfo)
        self.assertIsInstance(run.data, entities.RunData)

        expectations = ((run.data.metrics, entities.Metric),
                        (run.data.params, entities.Param),
                        (run.data.tags, entities.RunTag))
        for items, expected_type in expectations:
            for item in items:
                self.assertIsInstance(item, expected_type)

    def test_delete_run(self):
        """delete_run marks the run row DELETED, and get_run still returns it."""
        run_uuid = self._run_factory().info.run_uuid

        self.store.delete_run(run_uuid)

        with self.store.ManagedSessionMaker() as session:
            row = session.query(models.SqlRun).filter_by(run_uuid=run_uuid).first()
            self.assertEqual(row.lifecycle_stage, entities.LifecycleStage.DELETED)

            fetched = self.store.get_run(run_uuid)
            self.assertEqual(row.run_uuid, fetched.info.run_uuid)

    def test_log_metric(self):
        """SQL keeps every logged metric row; RunData shows only the latest value per key."""
        run_uuid = self._run_factory().info.run_uuid

        key = 'blahmetric'
        value = 100.0
        earlier = entities.Metric(key, value, int(time.time()))
        later = entities.Metric(key, value, int(time.time()) + 2)
        for metric in (earlier, later):
            self.store.log_metric(run_uuid, metric)

        run = self.store.get_run(run_uuid)
        self.assertTrue(any(m.key == key and m.value == value for m in run.data.metrics))

        # The SQL-level _get_run returns the full metric history, duplicates included.
        # RunData from get_run keeps only the last reported value per key.
        with self.store.ManagedSessionMaker() as session:
            history = self.store._get_run(session, run_uuid).metrics
            self.assertEqual(2, len(history))
            self.assertEqual(1, len(run.data.metrics))

    def test_log_metric_uniqueness(self):
        """Logging a second value for the same metric key and timestamp must fail."""
        run = self._run_factory()

        tkey = 'blahmetric'
        tval = 100.0
        # Take the timestamp once. Two separate time.time() calls can straddle a
        # second boundary, giving distinct timestamps and no collision (flaky test).
        timestamp = int(time.time())
        metric = entities.Metric(tkey, tval, timestamp)
        metric2 = entities.Metric(tkey, 1.02, timestamp)
        self.store.log_metric(run.info.run_uuid, metric)

        with self.assertRaises(MlflowException) as e:
            self.store.log_metric(run.info.run_uuid, metric2)
        self.assertIn("must be unique. Metric already logged value", e.exception.message)

    def test_log_null_metric(self):
        """Logging a metric whose value is None fails with INTERNAL_ERROR."""
        run_uuid = self._run_factory().info.run_uuid

        null_metric = entities.Metric('blahmetric', None, int(time.time()))

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_metric(run_uuid, null_metric)
        self.assertEqual(ctx.exception.error_code, ErrorCode.Name(INTERNAL_ERROR))

    def test_log_param(self):
        """Two distinct params are both stored and readable via get_run."""
        run_uuid = self._run_factory().info.run_uuid

        key, value = 'blahmetric', '100.0'
        for param in (entities.Param(key, value), entities.Param('new param', 'new key')):
            self.store.log_param(run_uuid, param)

        params = self.store.get_run(run_uuid).data.params
        self.assertEqual(2, len(params))
        self.assertTrue(any(p.key == key and p.value == value for p in params))

    def test_log_param_uniqueness(self):
        """Re-logging an existing param key with a different value must fail."""
        run_uuid = self._run_factory().info.run_uuid

        key = 'blahmetric'
        first = entities.Param(key, '100.0')
        changed = entities.Param(key, 'newval')
        self.store.log_param(run_uuid, first)

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_param(run_uuid, changed)
        self.assertIn("Changing param value is not allowed. Param with key=",
                      ctx.exception.message)

    def test_log_empty_str(self):
        """An empty-string param value is stored and returned as ''."""
        run_uuid = self._run_factory().info.run_uuid

        key, value = 'blahmetric', ''
        for param in (entities.Param(key, value), entities.Param('new param', 'new key')):
            self.store.log_param(run_uuid, param)

        params = self.store.get_run(run_uuid).data.params
        self.assertEqual(2, len(params))
        self.assertTrue(any(p.key == key and p.value == value for p in params))

    def test_log_null_param(self):
        """Logging a param whose value is None fails with INTERNAL_ERROR."""
        run_uuid = self._run_factory().info.run_uuid

        null_param = entities.Param('blahmetric', None)

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_param(run_uuid, null_param)
        self.assertEqual(ctx.exception.error_code, ErrorCode.Name(INTERNAL_ERROR))

    def test_set_tag(self):
        """A tag set on a run is returned by get_run."""
        run_uuid = self._run_factory().info.run_uuid

        key, value = 'test tag', 'a boogie'
        self.store.set_tag(run_uuid, entities.RunTag(key, value))

        tags = self.store.get_run(run_uuid).data.tags
        self.assertTrue(any(t.key == key and t.value == value for t in tags))

    def test_get_metric_history(self):
        """get_metric_history returns every logged value for a key (order-insensitive)."""
        run_uuid = self._run_factory().info.run_uuid

        key = 'test'
        expected = [models.SqlMetric(key=key, value=value, timestamp=ts).to_mlflow_entity()
                    for value, ts in ((0.6, 1), (0.7, 2))]

        for metric in expected:
            self.store.log_metric(run_uuid, metric)

        history = self.store.get_metric_history(run_uuid, key)

        def _triples(items):
            return [(m.key, m.value, m.timestamp) for m in items]

        six.assertCountEqual(self, _triples(expected), _triples(history))

    def test_list_run_infos(self):
        """list_run_infos honours ViewType before and after a run is deleted."""
        experiment_id = self._experiment_factory('test_exp')
        r1, r2 = [
            self._run_factory(config=self._get_run_configs(name, experiment_id)).info.run_uuid
            for name in ('t1', 't2')
        ]

        def _listed(view_type):
            return [info.run_uuid for info in self.store.list_run_infos(experiment_id, view_type)]

        six.assertCountEqual(self, [r1, r2], _listed(ViewType.ALL))
        six.assertCountEqual(self, [r1, r2], _listed(ViewType.ACTIVE_ONLY))
        self.assertEqual(0, len(_listed(ViewType.DELETED_ONLY)))

        # Deleting r1 moves it from the active view to the deleted view.
        self.store.delete_run(r1)
        six.assertCountEqual(self, [r1, r2], _listed(ViewType.ALL))
        six.assertCountEqual(self, [r2], _listed(ViewType.ACTIVE_ONLY))
        six.assertCountEqual(self, [r1], _listed(ViewType.DELETED_ONLY))

    def test_rename_experiment(self):
        """rename_experiment changes the name returned by get_experiment."""
        experiment_id = self._experiment_factory('test name')
        new_name = 'new name'

        self.store.rename_experiment(experiment_id, new_name)

        self.assertEqual(self.store.get_experiment(experiment_id).name, new_name)

    def test_update_run_info(self):
        """update_run_info returns run info carrying the new status and end time."""
        run_uuid = self._run_factory().info.run_uuid

        status = entities.RunStatus.FINISHED
        end_time = int(time.time())

        updated = self.store.update_run_info(run_uuid, status, end_time)

        self.assertEqual(updated.status, status)
        self.assertEqual(updated.end_time, end_time)

    def test_restore_experiment(self):
        """An experiment goes ACTIVE -> DELETED -> ACTIVE and keeps its id throughout."""
        created = self.store.get_experiment(self._experiment_factory('helloexp'))
        self.assertEqual(created.lifecycle_stage, entities.LifecycleStage.ACTIVE)
        exp_id = created.experiment_id

        self.store.delete_experiment(exp_id)
        deleted = self.store.get_experiment(exp_id)
        self.assertEqual(deleted.experiment_id, exp_id)
        self.assertEqual(deleted.lifecycle_stage, entities.LifecycleStage.DELETED)

        self.store.restore_experiment(exp_id)
        restored = self.store.get_experiment(exp_id)
        self.assertEqual(restored.experiment_id, exp_id)
        self.assertEqual(restored.lifecycle_stage, entities.LifecycleStage.ACTIVE)

    def test_delete_restore_run(self):
        """delete_run/restore_run enforce state preconditions and toggle the lifecycle stage.

        Restoring an active run and deleting an already-deleted run must both raise
        MlflowException with a message naming the required state.
        """
        run = self._run_factory()
        self.assertEqual(run.info.lifecycle_stage, entities.LifecycleStage.ACTIVE)

        with self.assertRaises(MlflowException) as e:
            self.store.restore_run(run.info.run_uuid)
        self.assertIn("must be in 'deleted' state", e.exception.message)

        self.store.delete_run(run.info.run_uuid)
        with self.assertRaises(MlflowException) as e:
            self.store.delete_run(run.info.run_uuid)
        self.assertIn("must be in 'active' state", e.exception.message)

        deleted = self.store.get_run(run.info.run_uuid)
        self.assertEqual(deleted.info.run_uuid, run.info.run_uuid)
        self.assertEqual(deleted.info.lifecycle_stage, entities.LifecycleStage.DELETED)

        self.store.restore_run(run.info.run_uuid)
        with self.assertRaises(MlflowException) as e:
            self.store.restore_run(run.info.run_uuid)
        # Must sit outside the `with`: code after the raising call never runs.
        self.assertIn("must be in 'deleted' state", e.exception.message)
        restored = self.store.get_run(run.info.run_uuid)
        self.assertEqual(restored.info.run_uuid, run.info.run_uuid)
        self.assertEqual(restored.info.lifecycle_stage, entities.LifecycleStage.ACTIVE)

    def test_error_logging_to_deleted_run(self):
        """Logging to a deleted run fails; after restore_run the same writes succeed.

        Checks log_param, log_metric and set_tag, and that none of the rejected writes
        were stored.
        """
        exp = self._experiment_factory('error_logging')
        run_uuid = self._run_factory(self._get_run_configs(experiment_id=exp)).info.run_uuid

        self.store.delete_run(run_uuid)
        self.assertEqual(self.store.get_run(run_uuid).info.lifecycle_stage,
                         entities.LifecycleStage.DELETED)
        # Each write API must reject a run that is not in the 'active' state.
        with self.assertRaises(MlflowException) as e:
            self.store.log_param(run_uuid, entities.Param("p1345", "v1"))
        self.assertIn("must be in 'active' state", e.exception.message)

        with self.assertRaises(MlflowException) as e:
            self.store.log_metric(run_uuid, entities.Metric("m1345", 1.0, 123))
        self.assertIn("must be in 'active' state", e.exception.message)

        with self.assertRaises(MlflowException) as e:
            self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv1"))
        self.assertIn("must be in 'active' state", e.exception.message)

        # restore this run and try again
        self.store.restore_run(run_uuid)
        self.assertEqual(self.store.get_run(run_uuid).info.lifecycle_stage,
                         entities.LifecycleStage.ACTIVE)
        self.store.log_param(run_uuid, entities.Param("p1345", "v22"))
        self.store.log_metric(run_uuid, entities.Metric("m1345", 34.0, 85))  # earlier timestamp than the rejected attempt (123)
        self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv44"))

        # Only the values written after the restore are present.
        run = self.store.get_run(run_uuid)
        assert len(run.data.params) == 1
        p = run.data.params[0]
        self.assertEqual(p.key, "p1345")
        self.assertEqual(p.value, "v22")
        assert len(run.data.metrics) == 1
        m = run.data.metrics[0]
        self.assertEqual(m.key, "m1345")
        self.assertEqual(m.value, 34.0)
        # Re-fetch and check each key by filtering, including the metric timestamp and the tag.
        run = self.store.get_run(run_uuid)
        self.assertEqual([("p1345", "v22")],
                         [(p.key, p.value) for p in run.data.params if p.key == "p1345"])
        self.assertEqual([("m1345", 34.0, 85)],
                         [(m.key, m.value, m.timestamp)
                          for m in run.data.metrics if m.key == "m1345"])
        self.assertEqual([("t1345", "tv44")],
                         [(t.key, t.value) for t in run.data.tags if t.key == "t1345"])

    # Tests for Search API
    def _search(self, experiment_id, metrics_expressions=None, param_expressions=None,
                run_view_type=ViewType.ALL):
        """Search one experiment with all expressions AND-ed together; return run UUIDs.

        Metric expressions are added before param expressions. Either may be None.
        """
        spec = SearchRuns()
        for expressions in (metrics_expressions, param_expressions):
            spec.anded_expressions.extend(expressions or [])
        matches = self.store.search_runs([experiment_id], SearchFilter(spec), run_view_type)
        return [match.info.run_uuid for match in matches]

    def _param_expression(self, key, comparator, val):
        """Build a string-valued parameter ``SearchExpression``: ``key <comparator> val``."""
        expr = SearchExpression()
        param = expr.parameter
        param.key = key
        param.string.comparator = comparator
        param.string.value = val
        return expr

    def _metric_expression(self, key, comparator, val):
        """Build a double-valued metric ``SearchExpression``: ``key <comparator> val``."""
        expr = SearchExpression()
        metric = expr.metric
        metric.key = key
        bound = metric.double
        bound.comparator = comparator
        bound.value = val
        return expr

    def test_search_vanilla(self):
        """With no expressions, search filters only by view type across delete/restore."""
        exp = self._experiment_factory('search_vanilla')
        runs = [self._run_factory(self._get_run_configs('r_%d' % i, exp)).info.run_uuid
                for i in range(3)]

        def _check_views(all_runs, active_runs, deleted_runs):
            for view_type, expected in ((ViewType.ALL, all_runs),
                                        (ViewType.ACTIVE_ONLY, active_runs),
                                        (ViewType.DELETED_ONLY, deleted_runs)):
                six.assertCountEqual(self, expected, self._search(exp, run_view_type=view_type))

        _check_views(runs, runs, [])

        first = runs[0]

        self.store.delete_run(first)
        _check_views(runs, runs[1:], [first])

        self.store.restore_run(first)
        _check_views(runs, runs, [])

    def test_search_params(self):
        """Parameter search supports '=' and '!=' and is case-sensitive on values."""
        experiment_id = self._experiment_factory('search_params')
        r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid
        r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).info.run_uuid

        logged = [
            (r1, 'generic_param', 'p_val'),
            (r2, 'generic_param', 'p_val'),
            (r1, 'generic_2', 'some value'),
            (r2, 'generic_2', 'another value'),
            (r1, 'p_a', 'abc'),
            (r2, 'p_b', 'ABC'),
        ]
        for run_uuid, key, value in logged:
            self.store.log_param(run_uuid, entities.Param(key, value))

        cases = [
            # shared key and value -> both runs
            (("generic_param", "=", "p_val"), [r1, r2]),
            # same key, different values per run
            (("generic_2", "=", "some value"), [r1]),
            (("generic_2", "=", "another value"), [r2]),
            (("generic_param", "=", "wrong_val"), []),
            (("generic_param", "!=", "p_val"), []),
            (("generic_param", "!=", "wrong_val"), [r1, r2]),
            (("generic_2", "!=", "wrong_val"), [r1, r2]),
            # keys present on only one run
            (("p_a", "=", "abc"), [r1]),
            (("p_b", "=", "ABC"), [r2]),
        ]
        for (key, comparator, value), expected in cases:
            expr = self._param_expression(key, comparator, value)
            six.assertCountEqual(self, expected,
                                 self._search(experiment_id, param_expressions=[expr]))

    def test_search_metrics(self):
        """Search runs by metric expressions.

        Comparisons are evaluated against the value recorded at the latest timestamp for
        each run (see the 'm_b' cases), so older values that would match are ignored.
        """
        experiment_id = self._experiment_factory('search_metrics')
        r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid
        r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).info.run_uuid

        self.store.log_metric(r1, entities.Metric("common", 1.0, 1))
        self.store.log_metric(r2, entities.Metric("common", 1.0, 1))

        self.store.log_metric(r1, entities.Metric("measure_a", 1.0, 1))
        self.store.log_metric(r2, entities.Metric("measure_a", 200.0, 2))
        self.store.log_metric(r2, entities.Metric("measure_a", 400.0, 3))

        self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2))
        self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2))
        self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8))  # this is last timestamp
        self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3))

        def assert_search(key, comparator, value, expected_runs):
            # Metric expressions belong in ``metrics_expressions``, not ``param_expressions``.
            expr = self._metric_expression(key, comparator, value)
            six.assertCountEqual(
                self, expected_runs, self._search(experiment_id, metrics_expressions=[expr]))

        # same metric value logged for both runs
        assert_search("common", "=", 1.0, [r1, r2])
        assert_search("common", ">", 0.0, [r1, r2])
        assert_search("common", ">=", 0.0, [r1, r2])
        assert_search("common", "<", 4.0, [r1, r2])
        assert_search("common", "<=", 4.0, [r1, r2])
        assert_search("common", "!=", 1.0, [])
        assert_search("common", ">=", 3.0, [])
        assert_search("common", "<=", 0.75, [])

        # same metric name across runs with different values and timestamps
        assert_search("measure_a", ">", 0.0, [r1, r2])
        assert_search("measure_a", "<", 50.0, [r1])
        assert_search("measure_a", "<", 1000.0, [r1, r2])
        assert_search("measure_a", "!=", -12.0, [r1, r2])
        assert_search("measure_a", ">", 50.0, [r2])
        assert_search("measure_a", "=", 1.0, [r1])
        assert_search("measure_a", "=", 400.0, [r2])

        # metric keys logged by only one run
        assert_search("m_a", ">", 1.0, [r1])
        assert_search("m_b", ">", 1.0, [r2])

        # a value above this threshold was recorded, but not at the last timestamp
        assert_search("m_b", ">", 5.0, [])

        # matches the value at the last reported timestamp for 'm_b'
        assert_search("m_b", "=", 4.0, [r2])

    def test_search_full(self):
        """Search with param and metric expressions combined; a run must satisfy all of them."""
        experiment_id = self._experiment_factory('search_full')
        r1 = self._run_factory(self._get_run_configs('r1', experiment_id)).info.run_uuid
        r2 = self._run_factory(self._get_run_configs('r2', experiment_id)).info.run_uuid

        self.store.log_param(r1, entities.Param('generic_param', 'p_val'))
        self.store.log_param(r2, entities.Param('generic_param', 'p_val'))

        self.store.log_param(r1, entities.Param('p_a', 'abc'))
        self.store.log_param(r2, entities.Param('p_b', 'ABC'))

        self.store.log_metric(r1, entities.Metric("common", 1.0, 1))
        self.store.log_metric(r2, entities.Metric("common", 1.0, 1))

        self.store.log_metric(r1, entities.Metric("m_a", 2.0, 2))
        self.store.log_metric(r2, entities.Metric("m_b", 3.0, 2))
        self.store.log_metric(r2, entities.Metric("m_b", 4.0, 8))
        self.store.log_metric(r2, entities.Metric("m_b", 8.0, 3))

        p_expr = self._param_expression("generic_param", "=", "p_val")
        m_expr = self._metric_expression("common", "=", 1.0)
        six.assertCountEqual(self, [r1, r2], self._search(experiment_id,
                                                          param_expressions=[p_expr],
                                                          metrics_expressions=[m_expr]))

        # all params and metrics match
        p_expr = self._param_expression("generic_param", "=", "p_val")
        m1_expr = self._metric_expression("common", "=", 1.0)
        m2_expr = self._metric_expression("m_a", ">", 1.0)
        six.assertCountEqual(self, [r1], self._search(experiment_id,
                                                      param_expressions=[p_expr],
                                                      metrics_expressions=[m1_expr, m2_expr]))

        # test with mismatch param
        p_expr = self._param_expression("random_bad_name", "=", "p_val")
        m1_expr = self._metric_expression("common", "=", 1.0)
        m2_expr = self._metric_expression("m_a", ">", 1.0)
        six.assertCountEqual(self, [], self._search(experiment_id,
                                                    param_expressions=[p_expr],
                                                    metrics_expressions=[m1_expr, m2_expr]))

        # test with mismatch metric
        p_expr = self._param_expression("generic_param", "=", "p_val")
        m1_expr = self._metric_expression("common", "=", 1.0)
        m2_expr = self._metric_expression("m_a", ">", 100.0)
        six.assertCountEqual(self, [], self._search(experiment_id,
                                                    param_expressions=[p_expr],
                                                    metrics_expressions=[m1_expr, m2_expr]))

    def test_log_batch(self):
        """log_batch persists every metric, param and tag it is given."""
        exp_id = self._experiment_factory('log_batch')
        run_id = self._run_factory(self._get_run_configs('r1', exp_id)).info.run_uuid
        self.store.log_batch(
            run_id=run_id,
            metrics=[Metric("m1", 0.87, 12345), Metric("m2", 0.49, 12345)],
            params=[Param("p1", "p1val"), Param("p2", "p2val")],
            tags=[RunTag("t1", "t1val"), RunTag("t2", "t2val")],
        )
        data = self.store.get_run(run_id).data
        logged_tags = {(tag.key, tag.value) for tag in data.tags}
        logged_metrics = {(m.key, m.value, m.timestamp) for m in data.metrics}
        logged_params = {(p.key, p.value) for p in data.params}
        # The run may carry additional tags, so tags are checked as a subset.
        assert {("t1", "t1val"), ("t2", "t2val")} <= logged_tags
        assert logged_metrics == {("m1", 0.87, 12345), ("m2", 0.49, 12345)}
        assert logged_params == {("p1", "p1val"), ("p2", "p2val")}

    def test_log_batch_limits(self):
        """A log_batch request at the maximum allowed size (1000 metrics here) succeeds,
        i.e. it does not run into SQL limitations."""
        exp_id = self._experiment_factory('log_batch_limits')
        run_id = self._run_factory(self._get_run_configs('r1', exp_id)).info.run_uuid
        expected = [("m%s" % idx, idx, 12345) for idx in range(1000)]
        batch = [Metric(key, value, ts) for key, value, ts in expected]
        self.store.log_batch(run_id=run_id, metrics=batch, params=[], tags=[])
        logged = {(m.key, m.value, m.timestamp)
                  for m in self.store.get_run(run_id).data.metrics}
        assert logged == set(expected)

    def test_log_batch_param_overwrite_disallowed(self):
        """Overwriting an existing param via log_batch raises, and none of the batch is kept."""
        run_id = self._run_factory().info.run_uuid
        original = entities.Param('my-param', 'orig-val')
        self.store.log_param(run_id, original)

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_batch(run_id,
                                 metrics=[entities.Metric("metric-key", 3.0, 12345)],
                                 params=[entities.Param('my-param', 'newval')],
                                 tags=[entities.RunTag("tag-key", "tag-val")])
        error = ctx.exception
        self.assertIn("Changing param value is not allowed. Param with key=", error.message)
        assert error.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        # Only the pre-existing param remains; the batch's metric and tag were not logged.
        self._verify_logged(run_id, metrics=[], params=[original], tags=[])

    def test_log_batch_param_overwrite_disallowed_single_req(self):
        """Two different values for one param key in a single log_batch request are rejected.

        Afterwards only the first value is reported as logged; the batch's metric and tag are
        absent.
        """
        run_id = self._run_factory().info.run_uuid
        first_param = entities.Param("common-key", "orig-val")
        conflicting = [first_param, entities.Param("common-key", 'newval')]

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_batch(run_id,
                                 metrics=[entities.Metric("metric-key", 3.0, 12345)],
                                 params=conflicting,
                                 tags=[entities.RunTag("tag-key", "tag-val")])
        error = ctx.exception
        self.assertIn("Changing param value is not allowed. Param with key=", error.message)
        assert error.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        self._verify_logged(run_id, metrics=[], params=[first_param], tags=[])

    def test_log_batch_accepts_empty_payload(self):
        """A log_batch call with no metrics, params or tags succeeds and logs nothing."""
        run_id = self._run_factory().info.run_uuid
        self.store.log_batch(run_id, metrics=[], params=[], tags=[])
        self._verify_logged(run_id, metrics=[], params=[], tags=[])

    def test_log_batch_internal_error(self):
        """Internal errors during the DB save step of log_batch surface as MlflowExceptions.

        log_metric, log_param and set_tag are patched to raise. log_batch is then called with
        a payload holding only metrics, only params, or only tags.
        """
        run = self._run_factory()

        def _raise_exception_fn(*args, **kwargs):  # pylint: disable=unused-argument
            raise Exception("Some internal error")

        store_path = "mlflow.store.sqlalchemy_store.SqlAlchemyStore"
        with mock.patch(store_path + ".log_metric") as metric_mock, \
                mock.patch(store_path + ".log_param") as param_mock, \
                mock.patch(store_path + ".set_tag") as tags_mock:
            metric_mock.side_effect = _raise_exception_fn
            param_mock.side_effect = _raise_exception_fn
            tags_mock.side_effect = _raise_exception_fn
            for kwargs in [{"metrics": [Metric("a", 3, 1)]}, {"params": [Param("b", "c")]},
                           {"tags": [RunTag("c", "d")]}]:
                log_batch_kwargs = {"metrics": [], "params": [], "tags": []}
                log_batch_kwargs.update(kwargs)
                with self.assertRaises(MlflowException) as e:
                    self.store.log_batch(run.info.run_uuid, **log_batch_kwargs)
                # assertIn(member, container): the surfaced message must contain the
                # original internal error text.
                self.assertIn("Some internal error", str(e.exception.message))

    def test_log_batch_nonexistent_run(self):
        """log_batch against an unknown run id fails with RESOURCE_DOES_NOT_EXIST."""
        with self.assertRaises(MlflowException) as ctx:
            self.store.log_batch("bad-run-uuid", [], [], [])
        error = ctx.exception
        assert error.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
        assert "Run with id=bad-run-uuid not found" in error.message

    def test_log_batch_params_idempotency(self):
        """Re-sending an identical param through log_batch is accepted."""
        run_id = self._run_factory().info.run_uuid
        params = [Param("p-key", "p-val")]
        for _ in range(2):
            self.store.log_batch(run_id, metrics=[], params=params, tags=[])
        self._verify_logged(run_id, metrics=[], params=params, tags=[])

    def test_log_batch_tags_idempotency(self):
        """Re-sending an identical tag through log_batch is accepted."""
        run_id = self._run_factory().info.run_uuid
        for _ in range(2):
            self.store.log_batch(run_id, metrics=[], params=[], tags=[RunTag("t-key", "t-val")])
        self._verify_logged(run_id, metrics=[], params=[], tags=[RunTag("t-key", "t-val")])

    def test_log_batch_allows_tag_overwrite(self):
        """A later log_batch request may replace a tag's value."""
        run_id = self._run_factory().info.run_uuid
        for tag_value in ("val", "newval"):
            self.store.log_batch(run_id, metrics=[], params=[], tags=[RunTag("t-key", tag_value)])
        self._verify_logged(run_id, metrics=[], params=[], tags=[RunTag("t-key", "newval")])

    def test_log_batch_allows_tag_overwrite_single_req(self):
        """Within one log_batch request, the last value given for a tag key wins."""
        run_id = self._run_factory().info.run_uuid
        earlier, latest = RunTag("t-key", "val"), RunTag("t-key", "newval")
        self.store.log_batch(run_id, metrics=[], params=[], tags=[earlier, latest])
        self._verify_logged(run_id, metrics=[], params=[], tags=[latest])

    def test_log_batch_same_metric_repeated_single_req(self):
        """Several values for one metric key in a single request are all kept."""
        run_id = self._run_factory().info.run_uuid
        history = [Metric(key="metric-key", value=1, timestamp=2),
                   Metric(key="metric-key", value=2, timestamp=3)]
        self.store.log_batch(run_id, params=[], metrics=list(history), tags=[])
        self._verify_logged(run_id, params=[], metrics=list(history), tags=[])

    def test_log_batch_same_metric_repeated_multiple_reqs(self):
        """Values for one metric key sent in separate requests accumulate."""
        run_id = self._run_factory().info.run_uuid
        sent_so_far = []
        for metric in (Metric(key="metric-key", value=1, timestamp=2),
                       Metric(key="metric-key", value=2, timestamp=3)):
            self.store.log_batch(run_id, params=[], metrics=[metric], tags=[])
            sent_so_far.append(metric)
            # After each request, every metric sent so far must be reported.
            self._verify_logged(run_id, params=[], metrics=list(sent_so_far), tags=[])
예제 #15
0
 def _assert_invalid_schema(engine):
     """Assert that ``SqlAlchemyStore._verify_schema(engine)`` rejects the schema as out of date."""
     with pytest.raises(MlflowException) as ex:
         SqlAlchemyStore._verify_schema(engine)
     # Check the message after the ``with`` block: a statement placed after the raising
     # call inside the block never executes. ``ex`` is pytest's ExceptionInfo, and the
     # raised MlflowException is ``ex.value``.
     assert "Detected out-of-date database schema." in ex.value.message
예제 #16
0
class TestSqlAlchemyStoreSqliteInMemory(unittest.TestCase):
    def _setup_database(self, filename=''):
        """Create ``self.store`` on ``DB_URI + filename`` and expose its session as ``self.session``.

        An empty ``filename`` uses ``DB_URI`` unchanged. Passing a static SQLite file path
        lets a test reopen the same database to check that data is retained.
        """
        db_uri = DB_URI + filename
        store = SqlAlchemyStore(db_uri, ARTIFACT_URI)
        self.store = store
        self.session = store.session

    def setUp(self):
        self.maxDiff = None  # print all differences on assert failures
        self.store = None
        self.session = None
        self._setup_database()

    def tearDown(self):
        if self.store:
            models.Base.metadata.drop_all(self.store.engine)

    def _experiment_factory(self, names):
        if type(names) is list:
            return [self.store.create_experiment(name=name) for name in names]

        return self.store.create_experiment(name=names)

    def test_default_experiment(self):
        """A fresh store lists exactly one experiment: "Default" with id 0."""
        listed = self.store.list_experiments()
        self.assertEqual(1, len(listed))

        default_exp = listed[0]
        self.assertEqual(0, default_exp.experiment_id)
        self.assertEqual("Default", default_exp.name)

    def test_default_experiment_lifecycle(self):
        """A deleted "Default" experiment stays deleted when the store is reopened.

        Uses a SQLite file in a temp dir (rather than the default ``DB_URI`` database from
        setUp) so that a second SqlAlchemyStore can be opened on the same data.
        """
        with TempDir(chdr=True) as tmp:
            tmp_file_name = "sqlite_file_to_lifecycle_test_{}.db".format(
                int(time.time()))
            self._setup_database("/" + tmp.path(tmp_file_name))
            default = self.session.query(
                models.SqlExperiment).filter_by(name='Default').first()
            self.assertEqual(default.experiment_id, 0)
            self.assertEqual(default.lifecycle_stage,
                             entities.LifecycleStage.ACTIVE)

            self._experiment_factory('aNothEr')
            all_experiments = [e.name for e in self.store.list_experiments()]

            # Sets are unordered and cannot be indexed, so compare them with assertSetEqual.
            self.assertSetEqual(set(['aNothEr', 'Default']),
                                set(all_experiments))

            self.store.delete_experiment(0)

            self.assertSequenceEqual(
                ['aNothEr'], [e.name for e in self.store.list_experiments()])
            another = self.store.get_experiment(1)
            self.assertEqual('aNothEr', another.name)

            default = self.session.query(
                models.SqlExperiment).filter_by(name='Default').first()
            self.assertEqual(default.experiment_id, 0)
            self.assertEqual(default.lifecycle_stage,
                             entities.LifecycleStage.DELETED)

            # destroy SqlStore and make a new one on the same database file
            del self.store
            self._setup_database("/" + tmp.path(tmp_file_name))

            # test that default experiment is not reactivated
            default = self.session.query(
                models.SqlExperiment).filter_by(name='Default').first()
            self.assertEqual(default.experiment_id, 0)
            self.assertEqual(default.lifecycle_stage,
                             entities.LifecycleStage.DELETED)

            self.assertSequenceEqual(
                ['aNothEr'], [e.name for e in self.store.list_experiments()])
            all_experiments = [
                e.name for e in self.store.list_experiments(ViewType.ALL)
            ]
            self.assertSetEqual(set(['aNothEr', 'Default']),
                                set(all_experiments))

            # ensure that experiment ID for active experiment is unchanged
            another = self.store.get_experiment(1)
            self.assertEqual('aNothEr', another.name)

            # Release the store so tearDown skips drop_all on the temp-file database.
            self.session.close()
            self.store = None

    def test_raise_duplicate_experiments(self):
        """Creating two experiments with the same name raises."""
        self.assertRaises(Exception, self._experiment_factory, ['test', 'test'])

    def test_raise_experiment_dont_exist(self):
        """Fetching an experiment id that was never created raises."""
        self.assertRaises(Exception, self.store.get_experiment, experiment_id=100)

    def test_delete_experiment(self):
        """A deleted experiment drops out of the listing and is marked DELETED in the DB."""
        exp_ids = self._experiment_factory(['morty', 'rick', 'rick and morty'])

        before = self.store.list_experiments()
        # One more than created: the "Default" experiment is always present.
        self.assertEqual(len(exp_ids) + 1, len(before))

        target_id = exp_ids[0]
        self.store.delete_experiment(target_id)

        deleted_row = self.session.query(models.SqlExperiment).get(target_id)
        self.assertEqual(len(before) - 1, len(self.store.list_experiments()))
        self.assertEqual(entities.LifecycleStage.DELETED, deleted_row.lifecycle_stage)

    def test_get_experiment(self):
        """get_experiment returns the experiment with the name it was created with."""
        exp_name = 'goku'
        exp_id = self._experiment_factory(exp_name)
        fetched = self.store.get_experiment(exp_id)
        self.assertEqual(exp_name, fetched.name)
        self.assertEqual(exp_id, fetched.experiment_id)

    def test_list_experiments(self):
        """list_experiments includes every created experiment plus "Default"."""
        names = ['blue', 'red', 'green']
        exp_ids = self._experiment_factory(names)
        listed = self.store.list_experiments()

        # +1 for the "Default" experiment.
        self.assertEqual(len(exp_ids) + 1, len(listed))

        experiments_query = self.session.query(models.SqlExperiment)
        for exp_id in exp_ids:
            row = experiments_query.filter_by(experiment_id=exp_id).first()
            self.assertIn(row.name, names)
            self.assertEqual(row.experiment_id, exp_id)

    def test_create_experiments(self):
        """create_experiment adds one row that both the session and the store can read."""
        def experiment_rows():
            return self.session.query(models.SqlExperiment).all()

        self.assertEqual(1, len(experiment_rows()))

        exp_id = self.store.create_experiment(name='test exp')
        self.assertEqual(2, len(experiment_rows()))

        row = self.session.query(models.SqlExperiment).filter_by(name='test exp').first()
        self.assertEqual(exp_id, row.experiment_id)
        self.assertEqual('test exp', row.name)

        fetched = self.store.get_experiment(exp_id)
        self.assertEqual(exp_id, fetched.experiment_id)
        self.assertEqual('test exp', fetched.name)

    def test_run_tag_model(self):
        """A committed SqlTag converts to an entity with the same key and value."""
        tag_row = models.SqlTag(run_uuid='tuuid', key='test', value='val')
        self.session.add(tag_row)
        self.session.commit()

        stored = self.session.query(models.SqlTag).all()
        self.assertEqual(1, len(stored))

        entity = stored[0].to_mlflow_entity()
        self.assertEqual(tag_row.value, entity.value)
        self.assertEqual(tag_row.key, entity.key)

    def test_metric_model(self):
        """A committed SqlMetric converts to an entity with the same key and value."""
        metric_row = models.SqlMetric(run_uuid='testuid', key='accuracy', value=0.89)
        self.session.add(metric_row)
        self.session.commit()

        stored = self.session.query(models.SqlMetric).all()
        self.assertEqual(1, len(stored))

        entity = stored[0].to_mlflow_entity()
        self.assertEqual(metric_row.value, entity.value)
        self.assertEqual(metric_row.key, entity.key)

    def test_param_model(self):
        """A committed SqlParam converts to an entity with the same key and value."""
        param_row = models.SqlParam(run_uuid='test', key='accuracy', value='test param')
        self.session.add(param_row)
        self.session.commit()

        stored = self.session.query(models.SqlParam).all()
        self.assertEqual(1, len(stored))

        entity = stored[0].to_mlflow_entity()
        self.assertEqual(param_row.value, entity.value)
        self.assertEqual(param_row.key, entity.key)

    def test_run_needs_uuid(self):
        """Committing a SqlRun without a run_uuid raises an IntegrityError."""
        run = models.SqlRun()
        self.session.add(run)

        with self.assertRaises(sqlalchemy.exc.IntegrityError):
            # Install the "ignore" filter inside catch_warnings() so the previous warning
            # filters are restored when the block exits, even though commit() raises.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.session.commit()

    def test_run_data_model(self):
        """Params and metrics attached to a SqlRun are persisted with it."""
        metrics = [models.SqlMetric(key='accuracy', value=0.89),
                   models.SqlMetric(key='recal', value=0.89)]
        params = [models.SqlParam(key='loss', value='test param'),
                  models.SqlParam(key='blue', value='test param')]

        self.session.add_all(metrics + params)

        run_row = models.SqlRun(run_uuid=uuid.uuid4().hex)
        for param in params:
            run_row.params.append(param)
        for metric in metrics:
            run_row.metrics.append(metric)

        self.session.add(run_row)
        self.session.commit()

        stored_runs = self.session.query(models.SqlRun).all()
        first_run = stored_runs[0]
        self.assertEqual(1, len(stored_runs))
        self.assertEqual(2, len(first_run.params))
        self.assertEqual(2, len(first_run.metrics))

    def test_run_info(self):
        """SqlRun.to_mlflow_entity carries every configured field into RunInfo."""
        exp_id = self._experiment_factory('test exp')
        config = dict(
            experiment_id=exp_id,
            name='test run',
            user_id='Anderson',
            run_uuid='test',
            status=RunStatus.to_string(RunStatus.SCHEDULED),
            source_type=SourceType.to_string(SourceType.LOCAL),
            source_name='Python application',
            entry_point_name='main.py',
            start_time=int(time.time()),
            end_time=int(time.time()),
            source_version=mlflow.__version__,
            lifecycle_stage=entities.LifecycleStage.ACTIVE,
            artifact_uri='//',
        )
        info = models.SqlRun(**config).to_mlflow_entity().info

        # status and source_type are configured as strings; convert RunInfo's values back
        # with to_string before comparing.
        string_converters = {
            'source_type': SourceType.to_string,
            'status': RunStatus.to_string,
        }
        for field, expected in config.items():
            actual = getattr(info, field)
            converter = string_converters.get(field)
            if converter is not None:
                actual = converter(actual)
            self.assertEqual(expected, actual)

    def _get_run_configs(self, name='test', experiment_id=None):
        """Build keyword arguments for ``models.SqlRun`` describing a SCHEDULED notebook run.

        Every call gets a fresh random ``run_uuid``. Start and end times are
        ``int(time.time())`` at call time.
        """
        return dict(
            experiment_id=experiment_id,
            name=name,
            user_id='Anderson',
            run_uuid=uuid.uuid4().hex,
            status=RunStatus.to_string(RunStatus.SCHEDULED),
            source_type=SourceType.to_string(SourceType.NOTEBOOK),
            source_name='Python application',
            entry_point_name='main.py',
            start_time=int(time.time()),
            end_time=int(time.time()),
            source_version=mlflow.__version__,
            lifecycle_stage=entities.LifecycleStage.ACTIVE,
            artifact_uri='//',
        )

    def _run_factory(self, config=None):
        """Build a ``models.SqlRun`` from ``config`` and add it to the session (not committed).

        :param config: keyword arguments for ``models.SqlRun``; defaults to
            ``self._get_run_configs()``. When ``experiment_id`` is missing or None, a
            'test exp' experiment is created and its id is written into ``config`` in place.
        :return: the pending ``SqlRun`` instance.
        """
        if not config:
            config = self._get_run_configs()

        # Compare against None rather than truthiness: 0 is the valid id of the
        # "Default" experiment and must not be replaced.
        if config.get("experiment_id") is None:
            config["experiment_id"] = self._experiment_factory('test exp')

        run = models.SqlRun(**config)
        self.session.add(run)

        return run

    def test_create_run(self):
        """create_run stores the run info and appends a run-name tag after the given tags."""
        exp_id = self._experiment_factory('test_create_run')
        expected = self._get_run_configs('booyya', experiment_id=exp_id)
        tags = [RunTag('3', '4'), RunTag('1', '2')]
        source_type = SourceType.from_string(expected["source_type"])

        created = self.store.create_run(
            expected["experiment_id"], expected["user_id"], expected["name"], source_type,
            expected["source_name"], expected["entry_point_name"], expected["start_time"],
            expected["source_version"], tags, None)

        info = created.info
        self.assertEqual(info.experiment_id, expected["experiment_id"])
        self.assertEqual(info.user_id, expected["user_id"])
        self.assertEqual(info.name, 'booyya')
        self.assertEqual(info.source_type, source_type)
        for field in ("source_name", "source_version", "entry_point_name", "start_time"):
            self.assertEqual(getattr(info, field), expected[field])
        # The two explicit tags plus the run-name tag.
        self.assertEqual(len(created.data.tags), 3)

        name_tag = models.SqlTag(key='mlflow.runName', value='booyya').to_mlflow_entity()
        self.assertListEqual(created.data.tags, tags + [name_tag])

    def test_create_run_with_parent_id(self):
        """create_run with a parent id appends parentRunId and runName tags after the given tags."""
        exp_id = self._experiment_factory('test_create_run_with_parent_id')
        expected = self._get_run_configs('booyya', experiment_id=exp_id)
        tags = [RunTag('3', '4'), RunTag('1', '2')]
        source_type = SourceType.from_string(expected["source_type"])

        created = self.store.create_run(
            expected["experiment_id"], expected["user_id"], expected["name"], source_type,
            expected["source_name"], expected["entry_point_name"], expected["start_time"],
            expected["source_version"], tags, "parent_uuid_5")

        info = created.info
        self.assertEqual(info.experiment_id, expected["experiment_id"])
        self.assertEqual(info.user_id, expected["user_id"])
        self.assertEqual(info.name, 'booyya')
        self.assertEqual(info.source_type, source_type)
        for field in ("source_name", "source_version", "entry_point_name", "start_time"):
            self.assertEqual(getattr(info, field), expected[field])
        # The two explicit tags plus the parent-id and run-name tags.
        self.assertEqual(len(created.data.tags), 4)

        name_tag = models.SqlTag(key='mlflow.runName', value='booyya').to_mlflow_entity()
        parent_id_tag = models.SqlTag(key='mlflow.parentRunId',
                                      value='parent_uuid_5').to_mlflow_entity()
        self.assertListEqual(created.data.tags, tags + [parent_id_tag, name_tag])

    def test_to_mlflow_entity(self):
        """SqlRun.to_mlflow_entity yields entity types for info, data and their items."""
        entity = self._run_factory().to_mlflow_entity()

        self.assertIsInstance(entity.info, entities.RunInfo)
        self.assertIsInstance(entity.data, entities.RunData)

        expected_item_types = (
            (entity.data.metrics, entities.Metric),
            (entity.data.params, entities.Param),
            (entity.data.tags, entities.RunTag),
        )
        for items, item_type in expected_item_types:
            for item in items:
                self.assertIsInstance(item, item_type)

    def test_delete_run(self):
        """delete_run marks the run DELETED; get_run can still fetch it."""
        run_row = self._run_factory()
        self.session.commit()

        run_id = run_row.run_uuid
        self.store.delete_run(run_id)
        stored = self.session.query(models.SqlRun).filter_by(run_uuid=run_id).first()
        self.assertEqual(entities.LifecycleStage.DELETED, stored.lifecycle_stage)

        fetched = self.store.get_run(run_id)
        self.assertEqual(stored.run_uuid, fetched.info.run_uuid)

    def test_log_metric(self):
        """Logged metric history is stored in full; RunData reports one value per key."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        metric = entities.Metric(tkey, tval, int(time.time()))
        metric2 = entities.Metric(tkey, tval, int(time.time()) + 2)
        self.store.log_metric(run.run_uuid, metric)
        self.store.log_metric(run.run_uuid, metric2)

        # Fetch an actual row: a Query object is never None, so asserting on the
        # query itself would always pass.
        actual = self.session.query(models.SqlMetric).filter_by(key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        # SQL store _get_run method returns full history of recorded metrics.
        # Should return duplicates as well
        # MLflow RunData contains only the last reported values for metrics.
        sql_run_metrics = self.store._get_run(run.info.run_uuid).metrics
        self.assertEqual(2, len(sql_run_metrics))
        self.assertEqual(1, len(run.data.metrics))

        self.assertTrue(any(m.key == tkey and m.value == tval for m in run.data.metrics))

    def test_log_metric_uniqueness(self):
        """A second value for the same metric key and timestamp is rejected."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = 100.0
        # Share a single timestamp: two separate time.time() calls can fall in different
        # seconds, which would make the two metrics distinct and the test flaky.
        timestamp = int(time.time())
        metric = entities.Metric(tkey, tval, timestamp)
        metric2 = entities.Metric(tkey, 1.02, timestamp)
        self.store.log_metric(run.run_uuid, metric)

        with self.assertRaises(MlflowException) as e:
            self.store.log_metric(run.run_uuid, metric2)
        self.assertIn("must be unique. Metric already logged value",
                      e.exception.message)

    def test_log_null_metric(self):
        """A metric whose value is None is rejected with a wrapped IntegrityError."""
        run = self._run_factory()
        self.session.commit()

        null_metric = entities.Metric('blahmetric', None, int(time.time()))

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_metric(run.run_uuid, null_metric)

        message = ctx.exception.message
        self.assertIn("Log metric request failed for run ID=", message)
        self.assertIn("IntegrityError", message)

    def test_log_param(self):
        """Logged params are persisted and both are returned by get_run."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'blahmetric'
        tval = '100.0'
        param = entities.Param(tkey, tval)
        param2 = entities.Param('new param', 'new key')
        self.store.log_param(run.run_uuid, param)
        self.store.log_param(run.run_uuid, param2)

        # Materialize a row: a bare Query object is never None, so asserting
        # on it would always pass.
        actual = self.session.query(models.SqlParam).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)
        self.assertEqual(2, len(run.data.params))

        self.assertTrue(any(p.key == tkey and p.value == tval
                            for p in run.data.params))

    def test_log_param_uniqueness(self):
        """Re-logging an existing param key with a new value is rejected."""
        run = self._run_factory()
        self.session.commit()

        key = 'blahmetric'
        original_param = entities.Param(key, '100.0')
        changed_param = entities.Param(key, 'newval')
        self.store.log_param(run.run_uuid, original_param)

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_param(run.run_uuid, changed_param)
        self.assertIn("Changing param value is not allowed. Param with key=",
                      ctx.exception.message)

    def test_log_null_param(self):
        """A param whose value is None is rejected with a wrapped IntegrityError."""
        run = self._run_factory()
        self.session.commit()

        null_param = entities.Param('blahmetric', None)

        with self.assertRaises(MlflowException) as ctx:
            self.store.log_param(run.run_uuid, null_param)

        message = ctx.exception.message
        self.assertIn("Log param request failed for run ID=", message)
        self.assertIn("IntegrityError", message)

    def test_set_tag(self):
        """A tag set on a run is persisted and returned by get_run."""
        run = self._run_factory()

        self.session.commit()

        tkey = 'test tag'
        tval = 'a boogie'
        tag = entities.RunTag(tkey, tval)
        self.store.set_tag(run.run_uuid, tag)

        # Fetch a row; the Query object returned by filter_by is never None.
        actual = self.session.query(models.SqlTag).filter_by(
            key=tkey, value=tval).first()
        self.assertIsNotNone(actual)

        run = self.store.get_run(run.run_uuid)

        self.assertTrue(any(t.key == tkey and t.value == tval
                            for t in run.data.tags))

    def test_get_metric(self):
        """get_metric returns key, value and timestamp for each metric of the run."""
        run = self._run_factory()
        self.session.commit()

        for want in run.metrics:
            got = self.store.get_metric(run.run_uuid, want.key)
            self.assertEqual(want.key, got.key)
            self.assertEqual(want.value, got.value)
            self.assertEqual(want.timestamp, got.timestamp)

    def test_get_metric_history(self):
        """get_metric_history returns every logged value for a key, in order."""
        run = self._run_factory()
        self.session.commit()
        key = 'test'
        logged = [
            models.SqlMetric(key=key, value=value,
                             timestamp=ts).to_mlflow_entity()
            for value, ts in ((0.6, 1), (0.7, 2))
        ]

        for metric in logged:
            self.store.log_metric(run.run_uuid, metric)

        history = self.store.get_metric_history(run.run_uuid, key)

        def as_tuples(metrics):
            return [(m.key, m.value, m.timestamp) for m in metrics]

        self.assertSequenceEqual(as_tuples(logged), as_tuples(history))

    def test_get_param(self):
        """get_param returns the key and value of each param on the run."""
        run = self._run_factory()
        self.session.commit()

        for logged in run.params:
            fetched = self.store.get_param(run.run_uuid, logged.key)
            self.assertEqual(logged.key, fetched.key)
            self.assertEqual(logged.value, fetched.value)

    def test_list_run_infos(self):
        """list_run_infos filters by ViewType before and after deleting a run."""
        experiment_id = self._experiment_factory('test_exp')
        r1, r2 = [
            self._run_factory(self._get_run_configs(name,
                                                    experiment_id)).run_uuid
            for name in ('t1', 't2')
        ]

        def _runs(view_type):
            infos = self.store.list_run_infos(experiment_id, view_type)
            return [info.run_uuid for info in infos]

        # Before deletion: everything is active, nothing is deleted.
        self.assertSequenceEqual([r1, r2], _runs(ViewType.ALL))
        self.assertSequenceEqual([r1, r2], _runs(ViewType.ACTIVE_ONLY))
        self.assertEqual(0, len(_runs(ViewType.DELETED_ONLY)))

        # After deleting r1: it moves from the active view to the deleted view.
        self.store.delete_run(r1)
        self.assertSequenceEqual([r1, r2], _runs(ViewType.ALL))
        self.assertSequenceEqual([r2], _runs(ViewType.ACTIVE_ONLY))
        self.assertSequenceEqual([r1], _runs(ViewType.DELETED_ONLY))

    def test_rename_experiment(self):
        """rename_experiment changes the name reported by get_experiment."""
        experiment_id = self._experiment_factory('test name')
        target_name = 'new name'

        self.store.rename_experiment(experiment_id, target_name)

        self.assertEqual(self.store.get_experiment(experiment_id).name,
                         target_name)

    def test_update_run_info(self):
        """update_run_info returns run info carrying the new status and end time."""
        run = self._run_factory()
        finished = entities.RunStatus.FINISHED
        end_time = int(time.time())

        updated = self.store.update_run_info(run.run_uuid, finished, end_time)

        self.assertEqual(updated.status, finished)
        self.assertEqual(updated.end_time, end_time)

    def test_restore_experiment(self):
        """An experiment can be deleted and then restored to the ACTIVE stage."""
        exp = self.store.get_experiment(self._experiment_factory('helloexp'))
        self.assertEqual(exp.lifecycle_stage, entities.LifecycleStage.ACTIVE)
        experiment_id = exp.experiment_id

        self.store.delete_experiment(experiment_id)
        deleted = self.store.get_experiment(experiment_id)
        self.assertEqual(deleted.experiment_id, experiment_id)
        self.assertEqual(deleted.lifecycle_stage,
                         entities.LifecycleStage.DELETED)

        self.store.restore_experiment(experiment_id)
        restored = self.store.get_experiment(experiment_id)
        self.assertEqual(restored.experiment_id, experiment_id)
        self.assertEqual(restored.lifecycle_stage,
                         entities.LifecycleStage.ACTIVE)

    def test_delete_restore_run(self):
        """Runs move active -> deleted -> active; invalid transitions raise."""
        run = self._run_factory()
        self.assertEqual(run.lifecycle_stage, entities.LifecycleStage.ACTIVE)

        run_uuid = run.run_uuid

        # Restoring a run that is still active is rejected.
        with self.assertRaises(MlflowException) as e:
            self.store.restore_run(run_uuid)
        self.assertIn("must be in 'deleted' state", e.exception.message)

        # Deleting a run twice is rejected.
        self.store.delete_run(run_uuid)
        with self.assertRaises(MlflowException) as e:
            self.store.delete_run(run_uuid)
        self.assertIn("must be in 'active' state", e.exception.message)

        deleted = self.store.get_run(run_uuid)
        self.assertEqual(deleted.info.run_uuid, run_uuid)
        self.assertEqual(deleted.info.lifecycle_stage,
                         entities.LifecycleStage.DELETED)

        # Restoring a run twice is rejected.
        self.store.restore_run(run_uuid)
        with self.assertRaises(MlflowException) as e:
            self.store.restore_run(run_uuid)
        # This check must sit outside the ``with`` block: statements after
        # the raising call inside the block never execute.
        self.assertIn("must be in 'deleted' state", e.exception.message)
        restored = self.store.get_run(run_uuid)
        self.assertEqual(restored.info.run_uuid, run_uuid)
        self.assertEqual(restored.info.lifecycle_stage,
                         entities.LifecycleStage.ACTIVE)

    def test_error_logging_to_deleted_run(self):
        """Writes to a deleted run fail; after restore they succeed again."""
        exp = self._experiment_factory('error_logging')
        run_uuid = self._run_factory(
            self._get_run_configs(experiment_id=exp)).run_uuid

        def assert_stage(stage):
            self.assertEqual(
                self.store.get_run(run_uuid).info.lifecycle_stage, stage)

        self.store.delete_run(run_uuid)
        assert_stage(entities.LifecycleStage.DELETED)

        # Each write is wrapped in a lambda so the entity is built inside the
        # assertRaises context, exactly where the write is attempted.
        rejected_writes = [
            lambda: self.store.log_param(run_uuid,
                                         entities.Param("p1345", "v1")),
            lambda: self.store.log_metric(run_uuid,
                                          entities.Metric("m1345", 1.0, 123)),
            lambda: self.store.set_tag(run_uuid,
                                       entities.RunTag("t1345", "tv1")),
        ]
        for write in rejected_writes:
            with self.assertRaises(MlflowException) as ctx:
                write()
            self.assertIn("must be in 'active' state", ctx.exception.message)

        # Restore the run and retry the writes with new values.
        self.store.restore_run(run_uuid)
        assert_stage(entities.LifecycleStage.ACTIVE)
        self.store.log_param(run_uuid, entities.Param("p1345", "v22"))
        # Timestamp 85 is earlier than the 123 used by the rejected attempt.
        self.store.log_metric(run_uuid, entities.Metric("m1345", 34.0, 85))
        self.store.set_tag(run_uuid, entities.RunTag("t1345", "tv44"))

        param = self.store.get_param(run_uuid, "p1345")
        self.assertEqual(param.key, "p1345")
        self.assertEqual(param.value, "v22")
        metric = self.store.get_metric(run_uuid, "m1345")
        self.assertEqual(metric.key, "m1345")
        self.assertEqual(metric.value, 34.0)

        data = self.store.get_run(run_uuid).data
        self.assertEqual([("p1345", "v22")],
                         [(p.key, p.value)
                          for p in data.params if p.key == "p1345"])
        self.assertEqual([("m1345", 34.0, 85)],
                         [(m.key, m.value, m.timestamp)
                          for m in data.metrics if m.key == "m1345"])
        self.assertEqual([("t1345", "tv44")],
                         [(t.key, t.value)
                          for t in data.tags if t.key == "t1345"])

    # Tests for Search API
    def _search(self,
                experiment_id,
                metrics_expressions=None,
                param_expressions=None,
                run_view_type=ViewType.ALL):
        """Search one experiment and return the UUIDs of the matching runs.

        Metric expressions come first, then param expressions, in the single
        condition list handed to ``search_runs``.
        """
        metric_conditions = metrics_expressions or []
        param_conditions = param_expressions or []
        matches = self.store.search_runs(
            [experiment_id], metric_conditions + param_conditions,
            run_view_type)
        return [match.info.run_uuid for match in matches]

    def _param_expression(self, key, comparator, val):
        """Build a SearchExpression comparing string param ``key`` with ``val``."""
        expression = SearchExpression()
        param_clause = expression.parameter
        param_clause.key = key
        param_clause.string.comparator = comparator
        param_clause.string.value = val
        return expression

    def _metric_expression(self, key, comparator, val):
        """Build a SearchExpression comparing numeric metric ``key`` with ``val``."""
        expression = SearchExpression()
        metric_clause = expression.metric
        metric_clause.key = key
        metric_clause.double.comparator = comparator
        metric_clause.double.value = val
        return expression

    def test_search_vanilla(self):
        """Unconditioned search honours ViewType across delete and restore."""
        exp = self._experiment_factory('search_vanilla')
        runs = [
            self._run_factory(self._get_run_configs('r_%d' % r, exp)).run_uuid
            for r in range(3)
        ]
        first = runs[0]

        def check_views(active, deleted):
            # ALL always contains every run, regardless of lifecycle stage.
            self.assertSequenceEqual(
                runs, self._search(exp, run_view_type=ViewType.ALL))
            self.assertSequenceEqual(
                active, self._search(exp, run_view_type=ViewType.ACTIVE_ONLY))
            self.assertSequenceEqual(
                deleted,
                self._search(exp, run_view_type=ViewType.DELETED_ONLY))

        check_views(active=runs, deleted=[])

        self.store.delete_run(first)
        check_views(active=runs[1:], deleted=[first])

        self.store.restore_run(first)
        check_views(active=runs, deleted=[])

    def test_search_params(self):
        """Param search supports = and != on shared and run-specific keys."""
        experiment_id = self._experiment_factory('search_params')
        r1 = self._run_factory(self._get_run_configs('r1',
                                                     experiment_id)).run_uuid
        r2 = self._run_factory(self._get_run_configs('r2',
                                                     experiment_id)).run_uuid

        logged_params = [
            (r1, 'generic_param', 'p_val'),
            (r2, 'generic_param', 'p_val'),
            (r1, 'generic_2', 'some value'),
            (r2, 'generic_2', 'another value'),
            (r1, 'p_a', 'abc'),
            (r2, 'p_b', 'ABC'),
        ]
        for run_id, key, value in logged_params:
            self.store.log_param(run_id, entities.Param(key, value))

        # (key, comparator, value, expected matching runs)
        cases = [
            # shared key and value: both runs match
            ("generic_param", "=", "p_val", [r1, r2]),
            # same key, different value per run
            ("generic_2", "=", "some value", [r1]),
            ("generic_2", "=", "another value", [r2]),
            ("generic_param", "=", "wrong_val", []),
            ("generic_param", "!=", "p_val", []),
            ("generic_param", "!=", "wrong_val", [r1, r2]),
            ("generic_2", "!=", "wrong_val", [r1, r2]),
            # keys present on only one run
            ("p_a", "=", "abc", [r1]),
            ("p_b", "=", "ABC", [r2]),
        ]
        for key, comparator, value, expected in cases:
            expr = self._param_expression(key, comparator, value)
            self.assertSequenceEqual(
                expected,
                self._search(experiment_id, param_expressions=[expr]))

    def test_search_metrics(self):
        """Metric search compares against each run's latest-timestamp value.

        Metric expressions are now passed as ``metrics_expressions``. They
        used to be passed as ``param_expressions``, which only worked because
        ``_search`` concatenates both lists.
        """
        experiment_id = self._experiment_factory('search_params')
        r1 = self._run_factory(self._get_run_configs('r1',
                                                     experiment_id)).run_uuid
        r2 = self._run_factory(self._get_run_configs('r2',
                                                     experiment_id)).run_uuid

        # (run, key, value, timestamp), logged in this exact order.
        logged_metrics = [
            (r1, "common", 1.0, 1),
            (r2, "common", 1.0, 1),
            (r1, "measure_a", 1.0, 1),
            (r2, "measure_a", 200.0, 2),
            (r2, "measure_a", 400.0, 3),
            (r1, "m_a", 2.0, 2),
            (r2, "m_b", 3.0, 2),
            (r2, "m_b", 4.0, 8),  # this is the last timestamp for m_b
            (r2, "m_b", 8.0, 3),  # logged later, but at an older timestamp
        ]
        for run_id, key, value, timestamp in logged_metrics:
            self.store.log_metric(run_id,
                                  entities.Metric(key, value, timestamp))

        # (key, comparator, value, expected matching runs)
        cases = [
            ("common", "=", 1.0, [r1, r2]),
            ("common", ">", 0.0, [r1, r2]),
            ("common", ">=", 0.0, [r1, r2]),
            ("common", "<", 4.0, [r1, r2]),
            ("common", "<=", 4.0, [r1, r2]),
            ("common", "!=", 1.0, []),
            ("common", ">=", 3.0, []),
            ("common", "<=", 0.75, []),
            # same metric name across runs with different values/timestamps
            ("measure_a", ">", 0.0, [r1, r2]),
            ("measure_a", "<", 50.0, [r1]),
            ("measure_a", "<", 1000.0, [r1, r2]),
            ("measure_a", "!=", -12.0, [r1, r2]),
            ("measure_a", ">", 50.0, [r2]),
            ("measure_a", "=", 1.0, [r1]),
            ("measure_a", "=", 400.0, [r2]),
            # metric keys unique to a single run
            ("m_a", ">", 1.0, [r1]),
            ("m_b", ">", 1.0, [r2]),
            # 8.0 exceeds the threshold but is not at the last timestamp
            ("m_b", ">", 5.0, []),
            # matches the value at the last reported timestamp for m_b
            ("m_b", "=", 4.0, [r2]),
        ]
        for key, comparator, value, expected in cases:
            expr = self._metric_expression(key, comparator, value)
            self.assertSequenceEqual(
                expected,
                self._search(experiment_id, metrics_expressions=[expr]),
                msg="%s %s %s" % (key, comparator, value))

    def test_search_full(self):
        """Param and metric conditions combine: every condition must match."""
        experiment_id = self._experiment_factory('search_params')
        r1, r2 = [
            self._run_factory(self._get_run_configs(name,
                                                    experiment_id)).run_uuid
            for name in ('r1', 'r2')
        ]

        for run_id, key, value in [(r1, 'generic_param', 'p_val'),
                                   (r2, 'generic_param', 'p_val'),
                                   (r1, 'p_a', 'abc'),
                                   (r2, 'p_b', 'ABC')]:
            self.store.log_param(run_id, entities.Param(key, value))

        for run_id, key, value, ts in [(r1, "common", 1.0, 1),
                                       (r2, "common", 1.0, 1),
                                       (r1, "m_a", 2.0, 2),
                                       (r2, "m_b", 3.0, 2),
                                       (r2, "m_b", 4.0, 8),
                                       (r2, "m_b", 8.0, 3)]:
            self.store.log_metric(run_id, entities.Metric(key, value, ts))

        param = self._param_expression
        metric = self._metric_expression

        def search(param_expr, *metric_exprs):
            return self._search(experiment_id,
                                param_expressions=[param_expr],
                                metrics_expressions=list(metric_exprs))

        self.assertSequenceEqual(
            [r1, r2],
            search(param("generic_param", "=", "p_val"),
                   metric("common", "=", 1.0)))

        # all params and metrics match
        self.assertSequenceEqual(
            [r1],
            search(param("generic_param", "=", "p_val"),
                   metric("common", "=", 1.0),
                   metric("m_a", ">", 1.0)))

        # mismatching param
        self.assertSequenceEqual(
            [],
            search(param("random_bad_name", "=", "p_val"),
                   metric("common", "=", 1.0),
                   metric("m_a", ">", 1.0)))

        # mismatching metric
        self.assertSequenceEqual(
            [],
            search(param("generic_param", "=", "p_val"),
                   metric("common", "=", 1.0),
                   metric("m_a", ">", 100.0)))
예제 #17
0
 def _setup_database(self, filename=''):
     """Create a SqlAlchemyStore at DB_URI + filename and expose its session.

     Passing a static file name initializes SQLite on disk, so tests can
     check that data is retained across store instances.
     """
     store = SqlAlchemyStore(DB_URI + filename, ARTIFACT_URI)
     self.store = store
     self.session = store.session
예제 #18
0
def _get_sqlalchemy_store(store_uri, artifact_uri):
    """Return a SqlAlchemyStore for ``store_uri``.

    When ``artifact_uri`` is None, DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH is
    used as the artifact location instead.
    """
    # Function-scope import, presumably so SQLAlchemy is only needed when
    # this store is actually requested; verify before moving it.
    from mlflow.store.sqlalchemy_store import SqlAlchemyStore
    resolved_artifact_uri = (DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
                             if artifact_uri is None else artifact_uri)
    return SqlAlchemyStore(store_uri, resolved_artifact_uri)