Example 1
0
def test_list_experiments(view_type, tmpdir):
    """
    Check that ``mlflow.list_experiments`` honours ``view_type`` and ``max_results``
    when talking to a tracking server backed by a SQLite store.

    The store receives ``num_experiments - 1`` experiments whose lifecycle stages all
    match ``view_type``. The assertions expect the default experiment "0" to be
    counted as well, so ``num_experiments`` experiments should match in total. For
    ``ViewType.DELETED_ONLY`` the default experiment is deleted first so that it
    still matches the view.

    :param view_type: ``ViewType`` value under test (presumably supplied by a
        pytest parametrization).
    :param tmpdir: pytest ``tmpdir`` fixture; holds the SQLite file and artifacts.
    """
    sqlite_uri = "sqlite:///" + os.path.join(tmpdir.strpath, "test.db")
    store = SqlAlchemyStore(sqlite_uri, default_artifact_root=tmpdir.strpath)

    # One more than the default page size. Presumably this checks that a call
    # without `max_results` is not truncated to a single default-sized page.
    num_experiments = SEARCH_MAX_RESULTS_DEFAULT + 1

    if view_type == ViewType.DELETED_ONLY:
        # Delete the default experiment so it matches the DELETED_ONLY view.
        mlflow.tracking.MlflowClient(sqlite_uri).delete_experiment("0")

    # This is a bit hacky but much faster than creating experiments one by one with
    # `mlflow.create_experiment`
    with store.ManagedSessionMaker() as session:
        lifecycle_stages = LifecycleStage.view_type_to_stages(view_type)
        experiments = [
            SqlExperiment(
                name=f"exp_{i + 1}",
                lifecycle_stage=random.choice(lifecycle_stages),
                artifact_location=tmpdir.strpath,
            ) for i in range(num_experiments - 1)
        ]
        session.add_all(experiments)

    # Start the server *before* entering `try`: if startup fails, `process` would
    # otherwise be unbound in `finally`, and the resulting NameError would hide
    # the real error.
    url, process = _init_server(sqlite_uri, root_artifact_uri=tmpdir.strpath)
    try:
        print(f"In process {process}")
        mlflow.set_tracking_uri(url)
        # `max_results` is unspecified
        assert len(mlflow.list_experiments(view_type)) == num_experiments
        # `max_results` is larger than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments +
                                           1)) == num_experiments
        # `max_results` is equal to the number of experiments in the database
        assert len(mlflow.list_experiments(view_type,
                                           num_experiments)) == num_experiments
        # `max_results` is smaller than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments -
                                           1)) == num_experiments - 1
    finally:
        process.terminate()
Example 2
0
def list_experiments(view):
    """
    Print a table of the experiments known to the configured tracking server.

    :param view: Optional view-type string passed to ``ViewType.from_string``;
        when empty or ``None``, only active experiments are listed.

    Rows contain the experiment id, name and artifact location, and are printed
    sorted, under the headers "Experiment Id", "Name" and "Artifact Location".
    """
    if view:
        view_type = ViewType.from_string(view)
    else:
        view_type = ViewType.ACTIVE_ONLY

    def _display_location(location):
        # URIs are shown as-is; anything else is treated as a local path and
        # made absolute.
        if is_uri(location):
            return location
        return os.path.abspath(location)

    rows = []
    for experiment in mlflow.list_experiments(view_type):
        rows.append([
            experiment.experiment_id,
            experiment.name,
            _display_location(experiment.artifact_location),
        ])
    rows.sort()

    column_headers = ["Experiment Id", "Name", "Artifact Location"]
    print(tabulate(rows, headers=column_headers))
def test_list_experiments():
    """
    Smoke-test ``mlflow.list_experiments`` with ``ViewType.ALL`` and a result limit.

    After ``start_run_and_log_data()`` the call must return at least one
    experiment and no more than ``max_results`` of them. The upper bound is
    checked so the test actually exercises the limit it passes.
    """
    max_results = 5
    start_run_and_log_data()
    experiments = mlflow.list_experiments(view_type=ViewType.ALL,
                                          max_results=max_results)
    assert 0 < len(experiments) <= max_results