Example #1
0
def test_list_experiments(view_type, tmpdir):
    """
    Check that `mlflow.list_experiments` honors `view_type` and `max_results` when
    talking to a live tracking server backed by a fresh SQLite store.

    The store is seeded with `SEARCH_MAX_RESULTS_DEFAULT` experiments whose lifecycle
    stages all match `view_type`. Together with the pre-existing experiment "0"
    (deleted first when `view_type` is `ViewType.DELETED_ONLY`), the listing is
    expected to contain `SEARCH_MAX_RESULTS_DEFAULT + 1` experiments.
    """
    sqlite_uri = "sqlite:///" + os.path.join(tmpdir.strpath, "test.db")
    store = SqlAlchemyStore(sqlite_uri, default_artifact_root=tmpdir.strpath)

    # One more than the default page size, so the unbounded listing below must return
    # more than `SEARCH_MAX_RESULTS_DEFAULT` experiments.
    num_experiments = SEARCH_MAX_RESULTS_DEFAULT + 1

    if view_type == ViewType.DELETED_ONLY:
        # Delete the default experiment so it is counted by a DELETED_ONLY listing.
        mlflow.tracking.MlflowClient(sqlite_uri).delete_experiment("0")

    # This is a bit hacky but much faster than creating experiments one by one with
    # `mlflow.create_experiment`. Only `num_experiments - 1` rows are inserted because
    # experiment "0" (presumably the store's default experiment) supplies the last one.
    with store.ManagedSessionMaker() as session:
        lifecycle_stages = LifecycleStage.view_type_to_stages(view_type)
        experiments = [
            SqlExperiment(
                name=f"exp_{i + 1}",
                lifecycle_stage=random.choice(lifecycle_stages),
                artifact_location=tmpdir.strpath,
            )
            for i in range(num_experiments - 1)
        ]
        session.add_all(experiments)

    # Start the server outside the `try`: if startup fails there is no process to
    # terminate, and referencing an unbound `process` in `finally` would raise
    # UnboundLocalError and mask the real startup error.
    url, process = _init_server(sqlite_uri, root_artifact_uri=tmpdir.strpath)
    try:
        print(f"In process {process}")
        mlflow.set_tracking_uri(url)
        # `max_results` is unspecified
        assert len(mlflow.list_experiments(view_type)) == num_experiments
        # `max_results` is larger than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments + 1)) == num_experiments
        # `max_results` is equal to the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments)) == num_experiments
        # `max_results` is smaller than the number of experiments in the database
        assert (
            len(mlflow.list_experiments(view_type, num_experiments - 1)) == num_experiments - 1
        )
    finally:
        process.terminate()
Example #2
0
    # whatever reason. Windows needs uri like 'sqlite:///C:/path/to/my/file' whereas posix expects
    # sqlite://///path/to/my/file
    prefix = "sqlite://" if sys.platform == "win32" else "sqlite:////"
    return prefix + path


# Every backend store the suite is parametrized over: a SQLite-backed SQLAlchemy store
# and a FileStore rooted under the suite's working directory.
BACKEND_URIS = [
    _get_sqlite_uri(),  # SqlAlchemy
    path_to_local_file_uri(os.path.join(SUITE_ROOT_DIR, "file_store_root")),  # FileStore
]

# A tracking server is started for each backend URI when this module is imported.
# Keys are backend URIs; values are whatever `_init_server` returns, which callers
# treat as a (server URL, Process) pair.
BACKEND_URI_TO_SERVER_URL_AND_PROC = {
    backend_uri: _init_server(
        backend_uri=backend_uri,
        root_artifact_uri=SUITE_ARTIFACT_ROOT_DIR,
    )
    for backend_uri in BACKEND_URIS
}


def pytest_generate_tests(metafunc):
    """
    Parametrize every test or fixture that requests ``backend_store_uri``.

    pytest invokes this hook for each collected test function. If the test depends on
    ``backend_store_uri`` (directly or through a fixture), it is run once per entry in
    ``BACKEND_URIS``. Otherwise the hook leaves the test untouched.
    """
    requested_fixtures = metafunc.fixturenames
    if "backend_store_uri" not in requested_fixtures:
        return
    metafunc.parametrize("backend_store_uri", BACKEND_URIS)


@pytest.fixture(scope="module", autouse=True)
def server_urls():