def create_experiment(self, name, artifact_location=None):
        """Create a new active experiment and return its ID as a string.

        :param name: Name of the experiment; must not be ``None`` or empty.
        :param artifact_location: Optional artifact location to store on the
                                  experiment. When falsy, a location is derived
                                  from the newly generated experiment ID via
                                  ``self._get_artifact_location``.
        :return: The new experiment's ID converted with ``str``.
        :raises MlflowException: with ``INVALID_PARAMETER_VALUE`` when ``name`` is
                                 ``None`` or empty; with ``RESOURCE_ALREADY_EXISTS``
                                 when the database reports an integrity error while
                                 writing the experiment.
        """
        if name is None or name == "":
            raise MlflowException("Invalid experiment name",
                                  INVALID_PARAMETER_VALUE)

        with self.ManagedSessionMaker() as session:
            try:
                experiment = SqlExperiment(
                    name=name,
                    lifecycle_stage=LifecycleStage.ACTIVE,
                    artifact_location=artifact_location,
                )
                session.add(experiment)
                # Flush explicitly, inside the ``try``: this emits the INSERT here
                # so that an integrity violation is translated below even when the
                # caller supplied ``artifact_location``, and it populates the
                # autoincrement-ed ID without relying on the session's autoflush.
                session.flush()
                if not artifact_location:
                    # This requires a double write: the first one generated the
                    # ID, the second one stores the location derived from it.
                    experiment.artifact_location = self._get_artifact_location(
                        experiment.experiment_id)
                    session.flush()
            except sqlalchemy.exc.IntegrityError as e:
                raise MlflowException(
                    "Experiment(name={}) already exists. "
                    "Error: {}".format(name, str(e)),
                    RESOURCE_ALREADY_EXISTS,
                ) from e

            return str(experiment.experiment_id)
예제 #2
0
def test_list_experiments(view_type, tmpdir):
    """Check ``mlflow.list_experiments`` against several ``max_results`` values.

    A SQLite-backed store is created under ``tmpdir``, filled with
    ``SEARCH_MAX_RESULTS_DEFAULT`` experiments whose lifecycle stages match
    ``view_type``, and then queried through a tracking server started with
    ``_init_server``.

    :param view_type: The ``ViewType`` being exercised.
    :param tmpdir: Temporary directory object exposing ``strpath``; holds the
                   database file and serves as the artifact root.
    """
    sqlite_uri = "sqlite:///" + os.path.join(tmpdir.strpath, "test.db")
    store = SqlAlchemyStore(sqlite_uri, default_artifact_root=tmpdir.strpath)

    num_experiments = SEARCH_MAX_RESULTS_DEFAULT + 1

    if view_type == ViewType.DELETED_ONLY:
        # Delete the default experiment
        mlflow.tracking.MlflowClient(sqlite_uri).delete_experiment("0")

    # This is a bit hacky but much faster than creating experiments one by one with
    # `mlflow.create_experiment`
    with store.ManagedSessionMaker() as session:
        lifecycle_stages = LifecycleStage.view_type_to_stages(view_type)
        # Only `num_experiments - 1` rows are inserted here; the assertions below
        # expect one more, presumably the default experiment -- TODO confirm.
        experiments = [
            SqlExperiment(
                name=f"exp_{i + 1}",
                lifecycle_stage=random.choice(lifecycle_stages),
                artifact_location=tmpdir.strpath,
            ) for i in range(num_experiments - 1)
        ]
        session.add_all(experiments)

    # Start the server before entering `try`: if start-up fails there is no
    # process to terminate, and the `finally` clause must not hide the real
    # error behind a NameError on an unbound `process`.
    url, process = _init_server(sqlite_uri,
                                root_artifact_uri=tmpdir.strpath)
    try:
        print(f"In process {process}")
        mlflow.set_tracking_uri(url)
        # `max_results` is unspecified
        assert len(mlflow.list_experiments(view_type)) == num_experiments
        # `max_results` is larger than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments +
                                           1)) == num_experiments
        # `max_results` is equal to the number of experiments in the database
        assert len(mlflow.list_experiments(view_type,
                                           num_experiments)) == num_experiments
        # `max_results` is smaller than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments -
                                           1)) == num_experiments - 1
    finally:
        process.terminate()