def test_store_generated_schema_matches_base(tmpdir, db_url):
    # Create a SQLAlchemyStore against tmpfile, directly verify that tmpfile contains a
    # database with a valid schema
    SqlAlchemyStore(db_url, tmpdir.join("ARTIFACTS").strpath)
    engine = sqlalchemy.create_engine(db_url)
    mc = MigrationContext.configure(engine.connect())
    diff = compare_metadata(mc, Base.metadata)
    assert len(diff) == 0

def test_sqlalchemystore_idempotently_generates_up_to_date_schema(
    tmpdir, db_url, expected_schema_file
):
    generated_schema_file = tmpdir.join("generated-schema.sql").strpath
    # Repeatedly initialize a SQLAlchemyStore against the same DB URL. Initialization should
    # succeed and the schema should be the same.
    for _ in range(3):
        SqlAlchemyStore(db_url, tmpdir.join("ARTIFACTS").strpath)
        dump_db_schema(db_url, dst_file=generated_schema_file)
        _assert_schema_files_equal(generated_schema_file, expected_schema_file)

@pytest.fixture
def sqlite_store():
    fd, temp_dbfile = tempfile.mkstemp()
    # Close handle immediately so that we can remove the file later on in Windows
    os.close(fd)
    db_uri = "sqlite:///%s" % temp_dbfile
    store = SqlAlchemyStore(db_uri, "artifact_folder")
    yield (store, db_uri)
    os.remove(temp_dbfile)
    shutil.rmtree("artifact_folder")

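# A minimal sketch of how a test might consume the `sqlite_store` fixture above.
# `test_fixture_yields_usable_store` is a hypothetical test name, not part of the
# original suite; it assumes pytest collects `sqlite_store` as a fixture and that the
# store creates the default experiment (ID "0", name "Default") on initialization.
def test_fixture_yields_usable_store(sqlite_store):
    store, db_uri = sqlite_store
    assert db_uri.startswith("sqlite:///")
    assert store.get_experiment("0").name == "Default"
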
def dump_sqlalchemy_store_schema(dst_file):
    db_tmpdir = tempfile.mkdtemp()
    try:
        path = os.path.join(db_tmpdir, "db_file")
        db_url = "sqlite:///%s" % path
        SqlAlchemyStore(db_url, db_tmpdir)
        dump_db_schema(db_url, dst_file)
    finally:
        shutil.rmtree(db_tmpdir)

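# A hedged usage sketch for `dump_sqlalchemy_store_schema`: regenerating the reference
# schema file that tests such as
# `test_sqlalchemystore_idempotently_generates_up_to_date_schema` compare against.
# The destination path below is illustrative, not the suite's actual expected-schema
# location.
if __name__ == "__main__":
    dump_sqlalchemy_store_schema(dst_file="tests/resources/db/latest_schema.sql")
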
def test_list_experiments(view_type, tmpdir):
    sqlite_uri = "sqlite:///" + os.path.join(tmpdir.strpath, "test.db")
    store = SqlAlchemyStore(sqlite_uri, default_artifact_root=tmpdir.strpath)

    num_experiments = SEARCH_MAX_RESULTS_DEFAULT + 1
    if view_type == ViewType.DELETED_ONLY:
        # Delete the default experiment
        mlflow.tracking.MlflowClient(sqlite_uri).delete_experiment("0")

    # This is a bit hacky but much faster than creating experiments one by one with
    # `mlflow.create_experiment`
    with store.ManagedSessionMaker() as session:
        lifecycle_stages = LifecycleStage.view_type_to_stages(view_type)
        experiments = [
            SqlExperiment(
                name=f"exp_{i + 1}",
                lifecycle_stage=random.choice(lifecycle_stages),
                artifact_location=tmpdir.strpath,
            )
            for i in range(num_experiments - 1)
        ]
        session.add_all(experiments)

    try:
        url, process = _init_server(sqlite_uri, root_artifact_uri=tmpdir.strpath)
        print("In process %s" % process)
        mlflow.set_tracking_uri(url)
        # `max_results` is unspecified
        assert len(mlflow.list_experiments(view_type)) == num_experiments
        # `max_results` is larger than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments + 1)) == num_experiments
        # `max_results` is equal to the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments)) == num_experiments
        # `max_results` is smaller than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments - 1)) == num_experiments - 1
    finally:
        process.terminate()

def test_store_generated_schema_matches_base(tmpdir, db_url):
    # Create a SQLAlchemyStore against tmpfile, directly verify that tmpfile contains a
    # database with a valid schema
    SqlAlchemyStore(db_url, tmpdir.join("ARTIFACTS").strpath)
    engine = sqlalchemy.create_engine(db_url)
    mc = MigrationContext.configure(engine.connect())
    diff = compare_metadata(mc, Base.metadata)
    # `diff` contains several `remove_index` operations because `Base.metadata` does not contain
    # index metadata but `mc` does. Note this doesn't mean the MLflow database is missing indexes
    # as tested in `test_create_index_on_run_uuid`.
    diff = [d for d in diff if d[0] != "remove_index"]
    assert len(diff) == 0

def test_create_index_on_run_uuid(tmpdir, db_url):
    # Test for mlflow/store/db_migrations/versions/bd07f7e963c5_create_index_on_run_uuid.py
    SqlAlchemyStore(db_url, tmpdir.join("ARTIFACTS").strpath)
    with sqlite3.connect(db_url[len("sqlite:///"):]) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type = 'index'")
        all_index_names = [r[0] for r in cursor.fetchall()]
        run_uuid_index_names = {
            "index_params_run_uuid",
            "index_metrics_run_uuid",
            "index_latest_metrics_run_uuid",
            "index_tags_run_uuid",
        }
        assert run_uuid_index_names.issubset(all_index_names)

def test_sqlalchemy_store_detects_schema_mismatch(tmpdir, db_url):  # pylint: disable=unused-argument
    def _assert_invalid_schema(engine):
        with pytest.raises(MlflowException) as ex:
            SqlAlchemyStore._verify_schema(engine)
        assert "Detected out-of-date database schema." in str(ex.value)

    # Initialize an empty database & verify that we detect a schema mismatch
    engine = sqlalchemy.create_engine(db_url)
    _assert_invalid_schema(engine)

    # Create legacy tables, verify schema is still out of date
    InitialBase.metadata.create_all(engine)
    _assert_invalid_schema(engine)

    # Run each migration. Until the last one, schema should be out of date
    config = _get_alembic_config(db_url)
    script = ScriptDirectory.from_config(config)
    revisions = list(script.walk_revisions())
    revisions.reverse()
    for rev in revisions[:-1]:
        command.upgrade(config, rev.revision)
        _assert_invalid_schema(engine)

    # Run migrations, schema verification should now pass
    invoke_cli_runner(mlflow.db.commands, ["upgrade", db_url])
    SqlAlchemyStore._verify_schema(engine)

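# For reference, the same migration path the test above drives through
# `invoke_cli_runner` can be run from a shell; a sketch, assuming the `mlflow` CLI is
# installed and the path below is replaced with a real database file:
#
#   mlflow db upgrade sqlite:///path/to/mlflow.db
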
def _get_sqlalchemy_store(cls, store_uri):
    from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore

    return SqlAlchemyStore(store_uri)

def _get_sqlalchemy_store(cls, store_uri, artifact_uri):
    from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

    return SqlAlchemyStore(store_uri, artifact_uri)

def _get_sqlalchemy_store(store_uri, artifact_uri):
    from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

    if artifact_uri is None:
        artifact_uri = DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH
    return SqlAlchemyStore(store_uri, artifact_uri)

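# The `_get_sqlalchemy_store` variants above are factory functions that defer the
# SqlAlchemyStore import until a database-backed URI is actually used. A minimal
# sketch of the dispatch pattern they plug into, assuming a hypothetical
# `_register_stores` helper and a registry object exposing `register(scheme, builder)`:
def _register_stores(registry):
    # Register the factory for each database scheme MLflow supports
    for scheme in ("sqlite", "postgresql", "mysql", "mssql"):
        registry.register(scheme, _get_sqlalchemy_store)
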
def _assert_invalid_schema(engine):
    with pytest.raises(MlflowException) as ex:
        SqlAlchemyStore._verify_schema(engine)
    assert "Detected out-of-date database schema." in str(ex.value)