def create_experiment(self, name, artifact_location=None):
    """Create a new active experiment and return its ID.

    :param name: Name of the experiment. ``None`` and ``""`` are rejected.
    :param artifact_location: Optional artifact location for the experiment.
        When falsy, it is computed from the newly assigned experiment ID via
        ``self._get_artifact_location``.
    :return: The new experiment's ID as a string.
    :raises MlflowException: ``INVALID_PARAMETER_VALUE`` if ``name`` is
        ``None`` or empty; ``RESOURCE_ALREADY_EXISTS`` if inserting the
        experiment raises ``sqlalchemy.exc.IntegrityError`` (reported as the
        experiment already existing).
    """
    if name is None or name == "":
        raise MlflowException("Invalid experiment name", INVALID_PARAMETER_VALUE)

    with self.ManagedSessionMaker() as session:
        experiment = SqlExperiment(
            name=name,
            lifecycle_stage=LifecycleStage.ACTIVE,
            artifact_location=artifact_location,
        )
        session.add(experiment)
        try:
            # Flush explicitly so the INSERT -- and any IntegrityError it
            # raises -- happens inside this try regardless of whether an
            # artifact location was supplied. Previously, with a supplied
            # location, the INSERT only ran at a flush outside the try and the
            # error escaped unconverted. The flush also populates the
            # autoincremented ``experiment_id`` on the instance.
            session.flush()
        except sqlalchemy.exc.IntegrityError as e:
            raise MlflowException(
                "Experiment(name={}) already exists. "
                "Error: {}".format(name, str(e)),
                RESOURCE_ALREADY_EXISTS,
            ) from e

        if not artifact_location:
            # Second write: the default location depends on the ID generated
            # by the INSERT above.
            experiment.artifact_location = self._get_artifact_location(
                experiment.experiment_id
            )
            session.flush()
        return str(experiment.experiment_id)
def test_list_experiments(view_type, tmpdir):
    """Check ``mlflow.list_experiments`` honors ``max_results`` for ``view_type``.

    The tracking server is backed by a SQLite store holding one more
    experiment than ``SEARCH_MAX_RESULTS_DEFAULT``. Every experiment is
    visible under ``view_type``.
    """
    sqlite_uri = "sqlite:///" + os.path.join(tmpdir.strpath, "test.db")
    store = SqlAlchemyStore(sqlite_uri, default_artifact_root=tmpdir.strpath)
    # Only ``num_experiments - 1`` experiments are inserted below; the
    # pre-existing experiment "0" accounts for the remaining one.
    num_experiments = SEARCH_MAX_RESULTS_DEFAULT + 1

    if view_type == ViewType.DELETED_ONLY:
        # Delete the default experiment so it is included in this view.
        mlflow.tracking.MlflowClient(sqlite_uri).delete_experiment("0")

    # This is a bit hacky but much faster than creating experiments one by one with
    # `mlflow.create_experiment`
    with store.ManagedSessionMaker() as session:
        # Each stage is drawn from those matching ``view_type``, so every
        # inserted experiment is returned by the queries below.
        lifecycle_stages = LifecycleStage.view_type_to_stages(view_type)
        experiments = [
            SqlExperiment(
                name=f"exp_{i + 1}",
                lifecycle_stage=random.choice(lifecycle_stages),
                artifact_location=tmpdir.strpath,
            )
            for i in range(num_experiments - 1)
        ]
        session.add_all(experiments)

    # Start the server before entering ``try``: if startup fails, ``process``
    # is never bound and ``process.terminate()`` in ``finally`` would raise
    # UnboundLocalError, hiding the real error.
    url, process = _init_server(sqlite_uri, root_artifact_uri=tmpdir.strpath)
    try:
        print(f"In process {process}")
        mlflow.set_tracking_uri(url)
        # `max_results` is unspecified
        assert len(mlflow.list_experiments(view_type)) == num_experiments
        # `max_results` is larger than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments + 1)) == num_experiments
        # `max_results` is equal to the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments)) == num_experiments
        # `max_results` is smaller than the number of experiments in the database
        assert len(mlflow.list_experiments(view_type, num_experiments - 1)) == num_experiments - 1
    finally:
        process.terminate()