コード例 #1
0
ファイル: mlflow.py プロジェクト: yinxiEquinor/gordo
def get_run_id(client: MlflowClient, experiment_name: str,
               model_key: str) -> str:
    """
    Create a new run tagged with ``model_key`` in the named experiment.

    The experiment is looked up by name and created if it does not exist yet.
    A *new* run is created on every call (an existing run with the same
    ``model_key`` tag is not looked up or reused). The run must be manually
    stopped using the `mlflow.tracking.MlflowClient.set_terminated` method.

    Parameters
    ----------
    client: mlflow.tracking.MlflowClient
        Client with tracking uri set to AzureML if configured.
    experiment_name: str
        Name of experiment to log to; created if missing.
    model_key: str
        Unique ID of model configuration, stored as the run's ``model_key`` tag.

    Returns
    -------
    run_id: str
        Unique ID of MLflow run to log to.
    """
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is not None:
        experiment_id = experiment.experiment_id
    else:
        experiment_id = client.create_experiment(experiment_name)

    run = client.create_run(experiment_id, tags={"model_key": model_key})
    return run.info.run_id
コード例 #2
0
class MlflowWriter:
    """Log params, metrics, artifacts and models to a single MLflow run.

    On construction the experiment ``experiment_name`` is looked up (and
    created if missing) and a new run is created in it; every ``log_*`` method
    then writes to that run via ``self.client``.
    """

    def __init__(self, experiment_name, **kwargs):
        # kwargs are forwarded verbatim to MlflowClient (e.g. tracking_uri).
        self.client = MlflowClient(**kwargs)
        self.experiment_id = self._get_or_create_experiment(
            self.client, experiment_name)

        self.run_id = self.client.create_run(self.experiment_id).info.run_id

    @staticmethod
    def _get_or_create_experiment(client, experiment_name):
        """Return the id of ``experiment_name``, creating the experiment if needed."""
        experiment = client.get_experiment_by_name(experiment_name)
        if experiment is not None:
            return experiment.experiment_id
        try:
            return client.create_experiment(experiment_name)
        except Exception:
            # Creation can fail because another process created the experiment
            # concurrently; only propagate the error if it still does not exist
            # (instead of hiding every failure behind a blanket except).
            experiment = client.get_experiment_by_name(experiment_name)
            if experiment is None:
                raise
            return experiment.experiment_id

    def log_params_from_omegaconf_dict(self, params):
        """Log every leaf of an OmegaConf config as a dotted-name param."""
        for param_name, element in params.items():
            self._explore_recursive(param_name, element)

    def _explore_recursive(self, parent_name, element):
        # Flatten nested DictConfig / ListConfig values into dotted names
        # ("model.layers.0.units") and log each leaf value as a param.
        if isinstance(element, DictConfig):
            for k, v in element.items():
                self._explore_recursive(f"{parent_name}.{k}", v)
        elif isinstance(element, ListConfig):
            for i, v in enumerate(element):
                self._explore_recursive(f"{parent_name}.{i}", v)
        else:
            # Leaf value. Previously top-level scalars (e.g. ``lr: 0.1`` at the
            # config root) fell through both branches and were never logged.
            self.client.log_param(self.run_id, parent_name, element)

    def log_torch_model(self, model):
        # NOTE(review): mlflow.start_run uses the *global* tracking URI, which may
        # differ from the one passed to MlflowClient(**kwargs) — confirm.
        with mlflow.start_run(self.run_id):
            pytorch.log_model(model, "models")

    def log_param(self, key, value):
        self.client.log_param(self.run_id, key, value)

    def log_metric(self, key, value, timestamp=None, step=None):
        self.client.log_metric(self.run_id, key, value, timestamp, step)

    def log_artifact(self, local_path, artifact_path=None):
        self.client.log_artifact(self.run_id, local_path, artifact_path)

    def set_terminated(self):
        self.client.set_terminated(self.run_id)

    def move_mlruns(self):
        """Copy the run written under the current directory's ``mlruns/1`` into
        ``mlruns/1`` under Hydra's original working directory.

        The experiment directory name ``"1"`` is hard-coded. If several run
        directories exist, whichever ``glob`` lists first is copied.
        """
        # Copy the run directory.
        hydra_cwd = os.getcwd()
        src_mlrun_dir = os.path.join(hydra_cwd, "mlruns", "1")
        src_mlrun_path = [
            file_folder for file_folder in glob.glob(f"{src_mlrun_dir}/*")
            if os.path.isdir(file_folder)
        ][0]  # IndexError if no run directory exists
        run_hash = os.path.basename(src_mlrun_path)
        dst_exp_path = os.path.join(hydra.utils.get_original_cwd(), "mlruns",
                                    "1")
        dst_mlrun_path = os.path.join(dst_exp_path, run_hash)
        shutil.copytree(src_mlrun_path, dst_mlrun_path)
        # Presumably rewrites paths in the copied run's meta.yaml — see helper.
        overwrite_meta_yaml(dst_mlrun_path, run_hash)
        # Copy the experiment metadata.
        copy_exp_meta_yaml(src_mlrun_dir, dst_exp_path)
コード例 #3
0
ファイル: logger.py プロジェクト: xuman2019/ray
 def _init(self):
     """Create an MLflow run and record every config entry as a param.

     The run is created in the experiment whose id is stored under
     ``"mlflow_experiment_id"`` in ``self.config`` (``None`` if the key is absent).
     """
     from mlflow.tracking import MlflowClient
     mlflow_client = MlflowClient()
     new_run = mlflow_client.create_run(self.config.get("mlflow_experiment_id"))
     self._run_id = new_run.info.run_id
     for param_name, param_value in self.config.items():
         mlflow_client.log_param(self._run_id, param_name, param_value)
     self.client = mlflow_client
コード例 #4
0
ファイル: mlflow_utils.py プロジェクト: Carlos-UR/Info-HCVAE
 def __init__(self, config):
     """Point MLflow at ``config["mlflow_tracking_uri"]``, open a run and log ``config``.

     The run is created in ``config.get("mlflow_experiment_id")``; each config
     entry is logged through ``try_mlflow_log`` (presumably a best-effort
     wrapper — confirm).
     """
     mlflow.set_tracking_uri(config["mlflow_tracking_uri"])
     mlflow_client = MlflowClient()
     experiment_id = config.get("mlflow_experiment_id")
     self._run_id = mlflow_client.create_run(experiment_id).info.run_id
     for name, value in config.items():
         try_mlflow_log(mlflow_client.log_param, self._run_id, name, value)
     self.client = mlflow_client
コード例 #5
0
ファイル: test_logging.py プロジェクト: vishalbelsare/skorch
 def test_fit_with_real_run_and_client(self, tmp_path, logger_cls,
                                       net_builder_cls):
     """Drive the logger with a real MlflowClient backed by ``tmp_path``.

     ``net_builder_cls`` is assumed to build (and presumably fit) a net with the
     given callbacks — confirm against the fixture.
     """
     from mlflow.tracking import MlflowClient
     tracking_client = MlflowClient(tracking_uri=tmp_path.as_uri())
     exp_id = tracking_client.create_experiment('foo')
     mlflow_run = tracking_client.create_run(exp_id)
     mlflow_logger = logger_cls(mlflow_run, tracking_client,
                                create_artifact=False)
     net_builder_cls(callbacks=[mlflow_logger], max_epochs=3)
     # Something must have been written into the temporary tracking store.
     assert os.listdir(tmp_path)
コード例 #6
0
 def _init(self):
     """Create an MLflow run configured from ``self.config["logger_config"]``.

     Tracking/registry URIs and the experiment id come from the optional
     ``logger_config`` sub-dict; every top-level config entry is logged as a param.
     """
     from mlflow.tracking import MlflowClient
     logger_config = self.config.get("logger_config", {})
     tracking_uri = logger_config.get("mlflow_tracking_uri")
     registry_uri = logger_config.get("mlflow_registry_uri")
     mlflow_client = MlflowClient(tracking_uri=tracking_uri,
                                  registry_uri=registry_uri)
     new_run = mlflow_client.create_run(
         logger_config.get("mlflow_experiment_id"))
     self._run_id = new_run.info.run_id
     for param_name, param_value in self.config.items():
         mlflow_client.log_param(self._run_id, param_name, param_value)
     self.client = mlflow_client
コード例 #7
0
def preprocess_data(catalog,
                    ml_stage,
                    parameters='default',
                    config_path=None,
                    requires=None,
                    provides=None,
                    load_run_id=None,
                    enable_tracking=None,
                    experiment=None,
                    tracking_uri=None,
                    save_dir=None,
                    tags=None):
    """Run the preprocessing step for ``ml_stage``, optionally tracked in MLflow.

    When ``enable_tracking`` is truthy, a new MLflow run is created in the
    experiment named ``experiment`` (created if missing) and its run id is used
    as the data catalog's save version and passed to the step.
    """
    cfg = load_configuration(catalog=catalog,
                             parameters=parameters,
                             config_path=config_path,
                             tracking_uri=tracking_uri,
                             save_dir=save_dir)

    run_id = None
    if enable_tracking:
        client = MlflowClient(cfg.tracking_uri)
        found = client.get_experiment_by_name(experiment)
        if found is not None:
            experiment_id = found.experiment_id
        else:
            try:
                experiment_id = client.create_experiment(experiment)
            except Exception:
                # Possibly created concurrently by another process; re-check
                # instead of the old bare ``except:`` that hid every error
                # (including KeyboardInterrupt).
                found = client.get_experiment_by_name(experiment)
                if found is None:
                    raise
                experiment_id = found.experiment_id
        mlrun = client.create_run(experiment_id=experiment_id)
        run_id = mlrun.info.run_id

    data_catalog = DataCatalog.from_config(catalog=cfg.catalog,
                                           data_dir=cfg.data_dir,
                                           load_versions=load_run_id,
                                           save_version=run_id,
                                           ml_stages=[ml_stage])
    # Separate name so the ``parameters`` argument is not shadowed.
    preprocess_params = cfg.parameters.preprocess
    rebind = cfg.catalog.rebind_names
    requires = {} if requires is None else dict(requires)
    provides = {} if provides is None else dict(provides)

    preprocess = PreprocessStep(data_catalog,
                                preprocess_params,
                                ml_stage=ml_stage,
                                rebind=rebind,
                                requires=requires,
                                provides=provides,
                                run_id=run_id,
                                tracking_uri=cfg.tracking_uri,
                                experiment_name=experiment,
                                tags=tags)
    preprocess()
コード例 #8
0
ファイル: misc.py プロジェクト: CuriousCat-7/NAS
    def mlflow(mlflow_uri=None, tags=None, show_args=None):
        """Create an MLflow run tagged with git/host metadata and log params.

        Args:
            mlflow_uri: Tracking URI. When ``None`` nothing is logged and
                ``(None, None)`` is returned.
            tags: Extra run tags. Updated in place with ``git_branch``,
                ``git_hash``, ``timestamp`` and, when present in the
                environment, ``my_host_ip`` / ``my_host_name``. The tag
                ``exp_name`` selects the experiment (default ``"Ocr-PSEOnly"``).
            show_args: Params to log. Updated in place with one
                ``tags-<key>`` entry per tag.

        Returns:
            ``(client, run_id)``, or ``(None, None)`` when ``mlflow_uri`` is
            ``None`` or mlflow cannot be imported.
        """
        if mlflow_uri is None:
            return None, None

        try:
            from mlflow.tracking import MlflowClient
        except ImportError:
            logger.error('ERROR: Failed importing mlflow')
            # Previously MlflowClient was set to NotImplemented and the call
            # below crashed with a TypeError; degrade gracefully instead.
            return None, None

        # Fresh dicts per call: the old ``{}`` defaults were shared across
        # calls and mutated below. Caller-supplied dicts are still updated.
        tags = {} if tags is None else tags
        show_args = {} if show_args is None else show_args

        tags["git_branch"] = subprocess.check_output(
            ["git", "describe", "--all"]).decode("utf8").strip()
        tags["git_hash"] = subprocess.check_output(
            ["git", "describe", "--always"]).decode("utf8").strip()
        tags["timestamp"] = datetime.now().strftime("_%Y-%m-%d-%H-%M-%S")
        for env_key, tag_key in (("MY_HOSTIP", "my_host_ip"),
                                 ("MY_HOSTNAME", "my_host_name")):
            if env_key in os.environ:
                tags[tag_key] = os.environ[env_key]

        for key, value in tags.items():
            show_args[f'tags-{key}'] = value

        exp_name = tags.get("exp_name", "Ocr-PSEOnly")

        ml_client = MlflowClient(mlflow_uri)
        ml_exp = ml_client.get_experiment_by_name(exp_name)
        if ml_exp is None:
            ml_exp_id = ml_client.create_experiment(exp_name)
        else:
            ml_exp_id = ml_exp.experiment_id
        ml_run = ml_client.create_run(ml_exp_id)
        logger.info("ml_run: {}", ml_run)
        ml_run_id = ml_run.info.run_id
        for key, value in tags.items():
            logger.info(f'tag --- {key}: {value}')
            ml_client.set_tag(ml_run_id, key, value)
        logger.info("------------")
        for key, value in show_args.items():
            logger.info(f'{key}: {value}')
            ml_client.log_param(ml_run_id, key, value)

        return ml_client, ml_run_id
コード例 #9
0
class MlflowLogger:
    """Log params, metrics, artifacts and PyTorch models to a single MLflow run.

    The experiment ``experiment_name`` is created (or reused if creation raises
    ``MlflowException``) and a new run is created in it at construction time.
    """

    def __init__(self,
                 experiment_name,
                 tracking_uri=None,
                 registry_uri=None,
                 **kwargs):
        # NOTE(review): no explicit base class is listed, so non-empty kwargs
        # reach object.__init__ unless this is used as a cooperative mixin —
        # confirm the intended MRO.
        super().__init__(**kwargs)

        self.client = MlflowClient(tracking_uri, registry_uri)

        try:
            self.experiment_id = self.client.create_experiment(experiment_name)
        except MlflowException:
            # Presumably "experiment already exists": reuse it.
            self.experiment_id = self.client.get_experiment_by_name(
                experiment_name).experiment_id

        self.run_id = self.client.create_run(self.experiment_id).info.run_id

    def log_params_from_omegaconf_dict(self, params):
        """Log every leaf of an OmegaConf config as a dotted-name param."""
        for param_name, element in params.items():
            self._explore_recursive(param_name, element)

    def _explore_recursive(self, parent_name, element):
        # Flatten nested DictConfig / ListConfig values into dotted names
        # ("model.layers.0.units") and log each leaf value as a param.
        if isinstance(element, DictConfig):
            for k, v in element.items():
                self._explore_recursive(f'{parent_name}.{k}', v)
        elif isinstance(element, ListConfig):
            for i, v in enumerate(element):
                self._explore_recursive(f'{parent_name}.{i}', v)
        else:
            # Leaf value. Previously top-level scalars (e.g. ``lr: 0.1`` at the
            # config root) fell through both branches and were never logged.
            self.client.log_param(self.run_id, parent_name, element)

    def log_metric(self, key, value):
        self.client.log_metric(self.run_id, key, value)

    def log_torch_model(self, model, model_name):
        # NOTE(review): mlflow.start_run uses the *global* tracking URI, which
        # may differ from ``tracking_uri`` given to this logger — confirm.
        with mlflow.start_run(self.run_id):
            pytorch.log_model(model, model_name)

    def log_param(self, key, value):
        self.client.log_param(self.run_id, key, value)

    def log_artifact(self, local_path):
        self.client.log_artifact(self.run_id, local_path)

    def set_terminated(self):
        self.client.set_terminated(self.run_id)
コード例 #10
0
def manage_and_run():
    """Print every experiment with its run infos, then log a demo param.

    A new run is created in the first experiment returned by
    ``list_experiments``, given the param ``hello=world`` and terminated.
    """
    client = MlflowClient()
    # list of mlflow.entities.Experiment
    all_experiments = client.list_experiments()
    for exp in all_experiments:
        exp_id = exp.experiment_id
        print('[manage_and_run] experiment: ', exp_id, exp.name)
        infos = client.list_run_infos(exp_id)
        print('[manage_and_run] run_infos: ', infos)

    # mlflow.entities.Run
    new_run = client.create_run(all_experiments[0].experiment_id)
    new_run_id = new_run.info.run_id
    client.log_param(new_run_id, "hello", "world")
    client.set_terminated(new_run_id)
コード例 #11
0
class MlflowWriter():
    """Log params, metrics, artifacts and models to a single named MLflow run.

    On construction the experiment ``experiment_name`` is looked up (and
    created if missing), a new run is created in it and the run is named
    ``task_name`` via the ``MLFLOW_RUN_NAME`` tag.
    """

    def __init__(self, experiment_name, task_name, **kwargs):
        # kwargs are forwarded verbatim to MlflowClient (e.g. tracking_uri).
        self.client = MlflowClient(**kwargs)
        self.experiment_name = experiment_name
        self.task_name = task_name
        self.experiment_id = self._get_or_create_experiment(
            self.client, experiment_name)

        self.run_id = self.client.create_run(self.experiment_id).info.run_id

        # Tag *this* run through the client. The fluent ``mlflow.set_tag`` used
        # before acts on the active fluent run and starts a separate run when
        # none is active, so the name never reached ``self.run_id``.
        self.client.set_tag(self.run_id, MLFLOW_RUN_NAME, task_name)

        print(f"Exp ID:{self.experiment_id}, Run ID:{self.run_id}")

    @staticmethod
    def _get_or_create_experiment(client, experiment_name):
        """Return the id of ``experiment_name``, creating the experiment if needed."""
        experiment = client.get_experiment_by_name(experiment_name)
        if experiment is not None:
            return experiment.experiment_id
        try:
            return client.create_experiment(experiment_name)
        except Exception:
            # Possibly created concurrently by another process; only propagate
            # if it still does not exist (the old bare ``except:`` hid all errors).
            experiment = client.get_experiment_by_name(experiment_name)
            if experiment is None:
                raise
            return experiment.experiment_id

    def log_params_from_omegaconf_dict(self, params):
        """Log every leaf of an OmegaConf config as a dotted-name param."""
        for param_name, element in params.items():
            self._explore_recursive(param_name, element)

    def _explore_recursive(self, parent_name, element):
        # Flatten nested DictConfig / ListConfig values into dotted names and
        # log each leaf value as a param.
        if isinstance(element, DictConfig):
            for k, v in element.items():
                self._explore_recursive(f'{parent_name}.{k}', v)
        elif isinstance(element, ListConfig):
            for i, v in enumerate(element):
                self._explore_recursive(f'{parent_name}.{i}', v)
        else:
            # Leaf value; top-level scalars used to be silently skipped.
            self.client.log_param(self.run_id, parent_name, element)

    def log_torch_model(self, model):
        # NOTE(review): mlflow.start_run uses the *global* tracking URI, which
        # may differ from the one passed to MlflowClient(**kwargs) — confirm.
        with mlflow.start_run(self.run_id):
            pytorch.log_model(model, 'models')

    def log_param(self, key, value):
        self.client.log_param(self.run_id, key, value)

    def log_metric(self, key, value):
        self.client.log_metric(self.run_id, key, value)

    def log_artifact(self, local_path):
        self.client.log_artifact(self.run_id, local_path)

    def set_terminated(self):
        self.client.set_terminated(self.run_id)
コード例 #12
0
    def _init(self):
        """Create an MLflow run for this trial and log its hyperparameters.

        Runs are stored in an ``mlruns`` directory next to ``self.logdir``. The
        experiment name comes from ``self.config["mlflow_experiment"]``
        (default ``"test"``) and is created if it is not yet listed. The run
        is named after ``self.trial.trial_id``.
        """
        from mlflow.tracking import MlflowClient
        uri = osp.join(osp.dirname(self.logdir), 'mlruns')
        client = MlflowClient(tracking_uri=uri)
        experiments = [e.name for e in client.list_experiments()]
        exp_name = self.config.get("mlflow_experiment", "test")
        if exp_name in experiments:
            experiment_id = client.get_experiment_by_name(
                exp_name).experiment_id
        else:
            # create_experiment returns the id itself, not an Experiment; the
            # old code called ``.experiment_id`` on it and crashed here.
            experiment_id = client.create_experiment(exp_name)
        run = client.create_run(experiment_id,
                                tags={'mlflow.runName': self.trial.trial_id})
        self._run_id = run.info.run_id

        self.client = client
        self._log_hparams()
コード例 #13
0
    def testLogArtifact(self):
        """Post three local artifact files for a fresh run to the upload service.

        The run is created in experiment ``"foo"`` on the tracking server at
        http://localhost:5000 (reusing the experiment when creation raises
        ``RestException``). The files are then sent to
        http://localhost:5001/log_artifact and the response body is printed.
        """
        with open('artifacts/foo.txt', 'rb') as f1, \
                open('artifacts/image.png', 'rb') as f2, \
                open('artifacts/animation.gif', 'rb') as f3:
            client = MlflowClient(tracking_uri="http://localhost:5000")
            try:
                experiment_id = client.create_experiment("foo")
            except RestException:
                # Presumably the experiment already exists; reuse it.
                experiment_id = client.get_experiment_by_name(
                    "foo").experiment_id
            run_id = client.create_run(experiment_id).info.run_id
            print(experiment_id + ":" + run_id)

            response = requests.post('http://localhost:5001/log_artifact',
                                     files={'file1': f1, 'file2': f2,
                                            'file3': f3},
                                     data={'run_id': run_id})
            print(response.text)
コード例 #14
0
def run_experiment(h_params: Dict[str, Any], params_file: str, mlflow_client: MlflowClient,
                   experiment_id: int, device: Optional[torch.device] = None,
                   tags: Optional[Dict[str, str]] = None, verbose: bool = False):
    """Train once with ``h_params`` overriding ``params_file`` and track it in MLflow.

    A new run is created in ``experiment_id`` with ``tags``. ``h_params`` are
    logged, training metrics are streamed through ``log_metrics``, and the
    contents of the temporary training directory are uploaded as artifacts.

    An ``Exception`` raised during training is printed and recorded as run
    status ``FAILED``; it is not re-raised. Interrupts such as
    KeyboardInterrupt mark the run ``KILLED`` and propagate.
    """
    params = get_overriden_params(h_params, params_file=params_file)
    # Creating run under the specified experiment
    run: Run = mlflow_client.create_run(experiment_id=experiment_id, tags=tags)
    # ``RunInfo.run_uuid`` is deprecated in MLflow in favour of ``run_id``.
    run_id = run.info.run_id
    log_params(mlflow_client, run, h_params)
    status = "FINISHED"
    try:
        with tempfile.TemporaryDirectory() as train_dir:
            train(train_dir=train_dir,
                  config=params,
                  force=True,
                  metric_logger=partial(log_metrics, mlflow_client, run),
                  device=device,
                  verbose=verbose)
            mlflow_client.log_artifacts(run_id, train_dir)
    except Exception as e:
        print(f"Run failed! Exception occurred: {e}.")
        status = 'FAILED'
    except BaseException:
        # e.g. KeyboardInterrupt: don't leave the run stuck in RUNNING.
        mlflow_client.set_terminated(run_id, status="KILLED")
        raise
    mlflow_client.set_terminated(run_id, status=status)
コード例 #15
0
ファイル: writer.py プロジェクト: sarrrrry/ImageClassifier
class MlflowWriter(BaseMlWriter):
    """MLflow writer that stores runs under ``<log_dir>/mlflow/mlruns``.

    A run is created in experiment ``experiment_name`` (created if missing) on
    construction and marked terminated when the writer is garbage-collected.
    """

    def __init__(self, log_dir, experiment_name):
        super().__init__()

        mlflow_dir = log_dir / "mlflow" / "mlruns"
        self.client = MlflowClient(tracking_uri=str(mlflow_dir))
        self.experiment_id = self._get_or_create_experiment(experiment_name)

        self.run_id = self.client.create_run(self.experiment_id).info.run_id

    def _get_or_create_experiment(self, experiment_name):
        """Return the id of ``experiment_name``, creating the experiment if needed."""
        experiment = self.client.get_experiment_by_name(experiment_name)
        if experiment is not None:
            return experiment.experiment_id
        try:
            return self.client.create_experiment(experiment_name)
        except Exception:
            # Possibly created concurrently; only propagate if it still does not
            # exist (the old bare ``except:`` hid every error).
            experiment = self.client.get_experiment_by_name(experiment_name)
            if experiment is None:
                raise
            return experiment.experiment_id

    def __del__(self):
        # __del__ also runs when __init__ failed part-way (no client/run_id set);
        # guard so it does not raise AttributeError in that case.
        client = getattr(self, "client", None)
        run_id = getattr(self, "run_id", None)
        if client is not None and run_id is not None:
            client.set_terminated(run_id)

    def log_params(self, params: Dict):
        """Flatten a (possibly nested) dict and log each entry as a param."""
        # from flatten_dict import flatten
        flatten_params = flatten(params)

        for key, value in flatten_params.items():
            self.log_param(key, value)

    def log_artifact(self, local_path: Path):
        """Upload ``local_path`` if it exists; otherwise just log a message."""
        if local_path.exists():
            self.client.log_artifact(self.run_id, str(local_path))
        else:
            logger.info(f"NOT Exists: {local_path}")

    def log_torch_model(self, model):
        # NOTE(review): mlflow.start_run uses the *global* tracking URI, not the
        # one given to self.client — confirm the model lands in this store.
        with mlflow.start_run(self.run_id):
            mlflow.pytorch.log_model(model, 'models')

    def log_param(self, key, value):
        self.client.log_param(self.run_id, key, value)

    def log_metric(self, key, value):
        self.client.log_metric(self.run_id, key, value)
コード例 #16
0
class Tracker():
    """Log params, metrics and artifacts to an MLflow run opened by ``start_run``.

    ``run_id`` stays ``None`` until ``start_run`` is called, so the ``log_*``
    methods must only be used between ``start_run`` and ``end_run``.
    """

    def __init__(self, experiment: str, tracking_uri: str):
        self.experiment_name = experiment
        self.experiment_id = None
        self.run_id = None
        self.client = MlflowClient(tracking_uri=tracking_uri)

    def log_param(self, key: str, value: Any):
        self.client.log_param(self.run_id, key, value)

    def log_params(self, params: Dict):
        for k, v in params.items():
            self.log_param(k, v)

    def log_artifacts(self, artifacts: List[str]):
        for artifact in artifacts:
            self.client.log_artifact(self.run_id, artifact)

    def log_metrics(self, metrics: Dict[str, Any]):
        """Log metrics; a list value is logged as a series with step = index."""
        for key, value in metrics.items():
            if isinstance(value, list):
                for step, item in enumerate(value):
                    self.client.log_metric(self.run_id, key, item, step=step)
            else:
                self.client.log_metric(self.run_id, key, value)

    def _get_or_create_experiment(self):
        """Return the id of ``self.experiment_name``, creating it if needed."""
        experiment = self.client.get_experiment_by_name(self.experiment_name)
        if experiment is not None:
            return experiment.experiment_id
        try:
            return self.client.create_experiment(self.experiment_name)
        except Exception:
            # Possibly created concurrently; only propagate if it still does not
            # exist (previously every error was swallowed here).
            experiment = self.client.get_experiment_by_name(
                self.experiment_name)
            if experiment is None:
                raise
            return experiment.experiment_id

    def start_run(self):
        """Ensure the experiment exists and open a new run in it."""
        self.experiment_id = self._get_or_create_experiment()
        run = self.client.create_run(self.experiment_id)
        self.run_id = run.info.run_id

    def end_run(self):
        self.client.set_terminated(self.run_id)
コード例 #17
0
class RemoteTracking:
    """Convenience wrapper around an ``MlflowClient`` stored as ``self.server``."""

    def __init__(self, tracking_uri=None, registry_uri=None):
        self.server = MlflowClient(tracking_uri, registry_uri)

    def get_experiment_id(self, name, artifact_location=None):
        """Return the id of experiment *name*, creating it when it is missing."""
        existing = self.server.get_experiment_by_name(name)
        if not existing:
            print("Experiment not exist")
            print("Creating new experiment on tracking")
            return self.server.create_experiment(name, artifact_location)
        return existing.experiment_id

    def get_run_id(self, experiment_id):
        """Create a new run in *experiment_id* and return its run id."""
        return self.server.create_run(experiment_id).info.run_id

    def _log_each(self, log_fn, run_id, mapping, done_message):
        # Shared loop: call log_fn(run_id, key, value) per entry, then report.
        for entry_key, entry_value in mapping.items():
            log_fn(run_id, entry_key, entry_value)
        print(done_message)

    def log_params(self, run_id, params):
        self._log_each(self.server.log_param, run_id, params,
                       "Parameters successful logged")

    def set_tags(self, run_id, params):
        self._log_each(self.server.set_tag, run_id, params,
                       "Tags successful logged")

    def log_metrics(self, run_id, params):
        self._log_each(self.server.log_metric, run_id, params,
                       "Metrics successful logged")

    def log_artifacts(self, run_id, local_dir, artifact_path=None):
        self.server.log_artifacts(run_id, local_dir, artifact_path)
コード例 #18
0
class MLFlowLogger(LightningLoggerBase):
    """
    Log using `MLflow <https://mlflow.org>`_.

    Install it with pip:

    .. code-block:: bash

        pip install mlflow

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.loggers import MLFlowLogger

        mlf_logger = MLFlowLogger(experiment_name="default", tracking_uri="file:./ml-runs")
        trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows:

    .. code-block:: python

        from pytorch_lightning import LightningModule


        class LitModel(LightningModule):
            def training_step(self, batch, batch_idx):
                # example
                self.logger.experiment.whatever_ml_flow_supports(...)

            def any_lightning_module_function_or_hook(self):
                self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment
        run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
            If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
        tracking_uri: Address of local or remote tracking server.
            If not provided, defaults to the `MLFLOW_TRACKING_URI` environment variable as it was set
            when this module was imported, otherwise it falls back to `file:<save_dir>`.
        tags: A dictionary of tags to attach to the MLflow run.
        save_dir: A path to a local directory where the MLflow runs get saved.
            Defaults to `./mlruns` if `tracking_uri` is not provided.
            Has no effect if `tracking_uri` is provided.
        prefix: A string to put at the beginning of metric keys.
        artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
            default.

    Raises:
        ImportError:
            If required MLFlow package is not installed on the device.
    """

    LOGGER_JOIN_CHAR = "-"

    def __init__(
        self,
        experiment_name: str = "default",
        run_name: Optional[str] = None,
        # NOTE: this default is evaluated once, at import time.
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        prefix: str = "",
        artifact_location: Optional[str] = None,
    ):
        if mlflow is None:
            raise ImportError(
                "You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`."
            )
        super().__init__()
        if not tracking_uri:
            tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"

        self._experiment_name = experiment_name
        self._experiment_id = None
        self._tracking_uri = tracking_uri
        self._run_name = run_name
        self._run_id = None
        self.tags = tags
        self._prefix = prefix
        self._artifact_location = artifact_location

        self._mlflow_client = MlflowClient(tracking_uri)

    @property
    @rank_zero_experiment
    def experiment(self) -> MlflowClient:
        r"""
        Actual MLflow object. To use MLflow features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_mlflow_function()

        The experiment and the run are created lazily on first access.
        """
        if self._experiment_id is None:
            expt = self._mlflow_client.get_experiment_by_name(
                self._experiment_name)
            if expt is not None:
                self._experiment_id = expt.experiment_id
            else:
                log.warning(
                    f"Experiment with name {self._experiment_name} not found. Creating it."
                )
                self._experiment_id = self._mlflow_client.create_experiment(
                    name=self._experiment_name,
                    artifact_location=self._artifact_location)

        if self._run_id is None:
            if self._run_name is not None:
                # Copy so the caller's ``tags`` dict is not mutated.
                self.tags = dict(self.tags or {})
                if MLFLOW_RUN_NAME in self.tags:
                    log.warning(
                        f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
                    )
                self.tags[MLFLOW_RUN_NAME] = self._run_name
            run = self._mlflow_client.create_run(
                experiment_id=self._experiment_id,
                tags=resolve_tags(self.tags))
            self._run_id = run.info.run_id
        return self._mlflow_client

    @property
    def run_id(self) -> str:
        """
        Create the experiment if it does not exist to get the run id.

        Returns:
            The run id.
        """
        _ = self.experiment
        return self._run_id

    @property
    def experiment_id(self) -> str:
        """
        Create the experiment if it does not exist to get the experiment id.

        Returns:
            The experiment id.
        """
        _ = self.experiment
        return self._experiment_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        params = self._convert_params(params)
        params = self._flatten_dict(params)
        for k, v in params.items():
            # Values longer than 250 characters are skipped with a warning.
            if len(str(v)) > 250:
                rank_zero_warn(
                    f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
                    RuntimeWarning)
                continue

            self.experiment.log_param(self.run_id, k, v)

    @rank_zero_only
    def log_metrics(self,
                    metrics: Dict[str, float],
                    step: Optional[int] = None) -> None:
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

        metrics = self._add_prefix(metrics)

        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            if isinstance(v, str):
                log.warning(f"Discarding metric with string value {k}={v}.")
                continue

            # Strip characters outside the allowed set from the metric name.
            new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
            if k != new_k:
                rank_zero_warn(
                    "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
                    f" Replacing {k} with {new_k}.",
                    RuntimeWarning,
                )
                k = new_k

            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

    @rank_zero_only
    def finalize(self, status: str = "FINISHED") -> None:
        super().finalize(status)
        # Map Lightning's lowercase statuses onto MLflow RunStatus names;
        # set_terminated does not accept "failed".
        if status == "success":
            status = "FINISHED"
        elif status == "failed":
            status = "FAILED"
        if self.experiment.get_run(self.run_id):
            self.experiment.set_terminated(self.run_id, status)

    @property
    def save_dir(self) -> Optional[str]:
        """
        The root file directory in which MLflow experiments are saved.

        Return:
            Local path to the root experiment directory if the tracking uri is local.
            Otherwise returns `None`.
        """
        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
            # Slice the prefix off. ``str.lstrip`` removes a *set of characters*,
            # so it also ate leading path characters that occur in the prefix.
            return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
        return None

    @property
    def name(self) -> str:
        """
        Get the experiment id.

        Returns:
            The experiment id.
        """
        return self.experiment_id

    @property
    def version(self) -> str:
        """
        Get the run id.

        Returns:
            The run id.
        """
        return self.run_id
コード例 #19
0
ファイル: mlflow.py プロジェクト: xqk/ray
class MLFlowLoggerCallback(LoggerCallback):
    """Ray Tune callback that mirrors every trial into its own MLflow run.

    MLflow (https://mlflow.org) Tracking is an open source library for
    recording and querying experiments. For each Tune trial this callback
    creates a separate MLflow run tagged with the trial's name. It logs the
    trial config as run parameters and every numeric entry of each reported
    result as a metric at the current iteration. When the trial ends, it
    terminates the run as FINISHED or FAILED.

    Args:
        tracking_uri (str): Where experiments and runs are managed; either a
            local file path or a remote server. Passed straight to
            ``mlflow.tracking.MlflowClient``. When using Tune in a multi-node
            setting, point this at a remote server rather than a local path.
        registry_uri (str): Registry URI, also passed straight to
            ``mlflow.tracking.MlflowClient``.
        experiment_name (str): MLflow experiment to log into. If None, the
            ``MLFLOW_EXPERIMENT_NAME`` environment variable is used, and if
            that is unset too, ``MLFLOW_EXPERIMENT_ID`` (which must refer to
            an existing experiment). An existing experiment with the given
            name is reused; otherwise one is created.
        save_artifact (bool): If True, upload the contents of each trial's
            log directory as artifacts of that trial's run when it ends.

    Raises:
        RuntimeError: If MLflow is not installed.
        ValueError: If no experiment name is available and
            ``MLFLOW_EXPERIMENT_ID`` is unset or refers to a missing experiment.

    Example:

    .. code-block:: python

        from ray.tune.integration.mlflow import MLFlowLoggerCallback
        tune.run(
            train_fn,
            config={
                # define search space here
                "parameter_1": tune.choice([1, 2, 3]),
                "parameter_2": tune.choice([4, 5, 6]),
            },
            callbacks=[MLFlowLoggerCallback(
                experiment_name="experiment1",
                save_artifact=True)])

    """
    def __init__(self,
                 tracking_uri: Optional[str] = None,
                 registry_uri: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 save_artifact: bool = False):

        if _import_mlflow() is None:
            raise RuntimeError("MLFlow has not been installed. Please `pip "
                               "install mlflow` to use the MLFlowLogger.")

        from mlflow.tracking import MlflowClient
        self.client = MlflowClient(tracking_uri=tracking_uri,
                                   registry_uri=registry_uri)

        # An explicit name wins; otherwise fall back to the name env var.
        if experiment_name is None:
            experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME")

        if experiment_name is None:
            # Last resort: an id from the environment that must already exist.
            experiment_id = os.environ.get("MLFLOW_EXPERIMENT_ID")
            id_is_usable = (experiment_id is not None and
                            self.client.get_experiment(experiment_id) is not None)
            if not id_is_usable:
                raise ValueError("No experiment_name passed, "
                                 "MLFLOW_EXPERIMENT_NAME env var is not "
                                 "set, and MLFLOW_EXPERIMENT_ID either "
                                 "is not set or does not exist. Please "
                                 "set one of these to use the "
                                 "MLFlowLoggerCallback.")
        else:
            # Reuse an experiment with this name, or create it.
            existing = self.client.get_experiment_by_name(experiment_name)
            experiment_id = (existing.experiment_id if existing is not None
                             else self.client.create_experiment(
                                 name=experiment_name))

        self.experiment_id = experiment_id
        self.save_artifact = save_artifact
        # Maps each Tune trial to the id of its MLflow run.
        self._trial_runs = {}

    def log_trial_start(self, trial: "Trial"):
        """Ensure the trial has an MLflow run and log its config as params."""
        if trial not in self._trial_runs:
            new_run = self.client.create_run(experiment_id=self.experiment_id,
                                             tags={"trial_name": str(trial)})
            self._trial_runs[trial] = new_run.info.run_id

        run_id = self._trial_runs[trial]
        for param_name, param_value in trial.config.items():
            self.client.log_param(run_id=run_id, key=param_name,
                                  value=param_value)

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Log every float-convertible result entry as a metric at `iteration`."""
        run_id = self._trial_runs[trial]
        for key, raw in result.items():
            try:
                numeric = float(raw)
            except (ValueError, TypeError):
                logger.debug(f"Cannot log key {key} with value {raw} since the "
                             "value cannot be converted to float.")
                continue
            self.client.log_metric(run_id=run_id,
                                   key=key,
                                   value=numeric,
                                   step=iteration)

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Optionally upload the trial's log dir, then terminate its run."""
        run_id = self._trial_runs[trial]

        if self.save_artifact:
            self.client.log_artifacts(run_id, local_dir=trial.logdir)

        self.client.set_terminated(run_id=run_id,
                                   status="FAILED" if failed else "FINISHED")
コード例 #20
0
# make a prediction for an example of an out-of-sample observation
# NOTE(review): `knn` is expected to be fitted in an earlier notebook cell.
knn.predict([[6, 3, 4, 2]])

# COMMAND ----------

# NOTE(review): the experiment created here is not passed to `start_run` below,
# so the run is logged to the currently active/default experiment. The path also
# looks like it was mangled by e-mail obfuscation ("[email protected]") — confirm.
mlflow.create_experiment('/Users/[email protected]/Model 1/Experiments/test3')
# NOTE(review): `test` must be defined in an earlier notebook cell.
test = test + str(1)

# Close any run left active by a previous cell before starting a new one.
mlflow.end_run()
with mlflow.start_run() as run:
    mlflow.log_param("param1", 5)
    mlflow.log_metric("foo", 2, step=1)
    mlflow.log_metric("foo", 4, step=2)
    mlflow.log_metric("foo", 6, step=3)

    # Log the artifact while this run is still active. Outside the `with` block
    # there is no active run, and `mlflow.log_artifact` would start a new one.
    with open("output.parquet", "w") as f:
        f.write("Hello world!")
    mlflow.log_artifact("output.parquet")


# COMMAND ----------

from mlflow.tracking import MlflowClient

# Create a run in the first listed experiment, log one param, and close the run.
client = MlflowClient()
experiments = client.list_experiments()
if not experiments:
    raise RuntimeError("No MLflow experiments found for the current tracking URI.")
run = client.create_run(experiments[0].experiment_id)
client.log_param(run.info.run_id, "hello", "world")
client.set_terminated(run.info.run_id)


コード例 #21
0
import warnings

from mlflow.tracking import MlflowClient

if __name__ == '__main__':

    warnings.filterwarnings("ignore")

    def _report(rid, run_info):
        """Print a run id together with the run's lifecycle stage."""
        print("run_id: {}; lifecycle_stage: {}".format(rid,
                                                       run_info.lifecycle_stage))

    # Create a run under the default experiment (whose ID is "0").
    client = MlflowClient()
    experiment_id = "0"
    created = client.create_run(experiment_id)
    run_id = created.info.run_id
    _report(run_id, created.info)
    print("--")

    # Delete the run, then fetch it again to show its lifecycle stage afterwards.
    client.delete_run(run_id)
    _report(run_id, client.get_run(run_id).info)
コード例 #22
0
class MLFlowLogger(LightningLoggerBase):
    """
    Log using `MLflow <https://mlflow.org>`_. Install it with pip:

    .. code-block:: bash

        pip install mlflow

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.loggers import MLFlowLogger
        >>> mlf_logger = MLFlowLogger(
        ...     experiment_name="default",
        ...     tracking_uri="file:./ml-runs"
        ... )
        >>> trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows:

    >>> from pytorch_lightning import LightningModule
    >>> class LitModel(LightningModule):
    ...     def training_step(self, batch, batch_idx):
    ...         # example
    ...         self.logger.experiment.whatever_ml_flow_supports(...)
    ...
    ...     def any_lightning_module_function_or_hook(self):
    ...         self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment. It is created on first
            use if it does not exist yet.
        tracking_uri: Address of local or remote tracking server.
            If not provided, defaults to `file:<save_dir>`.
        tags: A dictionary of tags attached to the MLflow run created by this logger.
        save_dir: A path to a local directory where the MLflow runs get saved.
            Defaults to `./mlruns` if `tracking_uri` is not provided.
            Has no effect if `tracking_uri` is provided.

    """
    def __init__(self,
                 experiment_name: str = 'default',
                 tracking_uri: Optional[str] = None,
                 tags: Optional[Dict[str, Any]] = None,
                 save_dir: Optional[str] = './mlruns'):

        if not _MLFLOW_AVAILABLE:
            raise ImportError(
                'You want to use `mlflow` logger which is not installed yet,'
                ' install it with `pip install mlflow`.')
        super().__init__()
        if not tracking_uri:
            tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}'

        self._experiment_name = experiment_name
        # Resolved lazily (and only once) by the `experiment` property.
        self._experiment_id = None
        self._tracking_uri = tracking_uri
        # Created lazily by the `experiment` property.
        self._run_id = None
        self.tags = tags
        self._mlflow_client = MlflowClient(tracking_uri)

    @property
    @rank_zero_experiment
    def experiment(self) -> MlflowClient:
        r"""
        Actual MLflow object. To use MLflow features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_mlflow_function()

        On first access this resolves (or creates) the experiment and creates
        the run; later accesses reuse the cached ids.
        """
        # Look the experiment up only once: this property is hit on every
        # log call (directly and via `run_id` / `experiment_id`).
        if self._experiment_id is None:
            expt = self._mlflow_client.get_experiment_by_name(
                self._experiment_name)

            if expt:
                self._experiment_id = expt.experiment_id
            else:
                log.warning(
                    f'Experiment with name {self._experiment_name} not found. Creating it.'
                )
                self._experiment_id = self._mlflow_client.create_experiment(
                    name=self._experiment_name)

        if not self._run_id:
            run = self._mlflow_client.create_run(
                experiment_id=self._experiment_id, tags=self.tags)
            self._run_id = run.info.run_id
        return self._mlflow_client

    @property
    def run_id(self):
        """Id of this logger's MLflow run (created on first access)."""
        # create the experiment if it does not exist to get the run id
        _ = self.experiment
        return self._run_id

    @property
    def experiment_id(self):
        """Id of this logger's MLflow experiment (created on first access if missing)."""
        # create the experiment if it does not exist to get the experiment id
        _ = self.experiment
        return self._experiment_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        """Log (flattened) hyperparameters as MLflow run params."""
        params = self._convert_params(params)
        params = self._flatten_dict(params)
        for k, v in params.items():
            self.experiment.log_param(self.run_id, k, v)

    @rank_zero_only
    def log_metrics(self,
                    metrics: Dict[str, float],
                    step: Optional[int] = None) -> None:
        """Log metrics to the run; string values are skipped with a warning."""
        assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            if isinstance(v, str):
                log.warning(f'Discarding metric with string value {k}={v}.')
                continue
            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

    @rank_zero_only
    def finalize(self, status: str = 'FINISHED') -> None:
        """Terminate the run; a "success" status is mapped to "FINISHED"."""
        super().finalize(status)
        status = 'FINISHED' if status == 'success' else status
        if self.experiment.get_run(self.run_id):
            self.experiment.set_terminated(self.run_id, status)

    @property
    def save_dir(self) -> Optional[str]:
        """
        The root file directory in which MLflow experiments are saved.

        Return:
            Local path to the root experiment directory if the tracking uri is local.
            Otherwise returns `None`.
        """
        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
            # Remove exactly the prefix. `str.lstrip` treats its argument as a set of
            # characters, so with the documented `file:` prefix it would turn
            # "file:logs" into "ogs".
            return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
        return None

    @property
    def name(self) -> str:
        """The experiment id."""
        return self.experiment_id

    @property
    def version(self) -> str:
        """The run id."""
        return self.run_id
コード例 #23
0
def create_run(experiment_id: str = "0", tags=None, client=None):
    """
    Create a new MLflow run and return its run_id.

    Args:
        experiment_id: Experiment to create the run under. Defaults to "0",
            MLflow's default experiment.
        tags: Optional mapping of tag names to values to attach to the run.
        client: Tracking client to use; a new ``MlflowClient()`` is created
            when omitted.

    Returns:
        The id of the newly created run.
    """
    if client is None:
        # `create_run` is an instance method that needs an experiment id, so it
        # must be called on a client object, not on the MlflowClient class.
        client = MlflowClient()
    return client.create_run(experiment_id, tags=tags).info.run_id
コード例 #24
0
ファイル: runner.py プロジェクト: nvanva/LexSubGen
class Runner:
    """Entry points for running LexSubGen scenarios and tracking them in MLflow."""

    @staticmethod
    def import_additional_modules(additional_modules):
        """Import additional modules (a single entry or a list) via `import_submodules`."""
        # Import additional modules
        logger.info("Importing additional modules...")
        if additional_modules is not None:
            if not isinstance(additional_modules, list):
                additional_modules = [additional_modules]
            for additional_module in additional_modules:
                import_submodules(additional_module)

    def __init__(self,
                 run_dir: str,
                 force: bool = False,
                 auto_create_subdir: bool = False):
        """
        Class that handles command line interaction with the LexSubGen framework.
        Different methods of this class are related to different scenarios of framework usage.
        E.g. evaluate method performs substitute generator evaluation on the dataset specified
        in the configuration.

        Args:
            run_dir: path to the directory where to store experiment data.
            force: whether to rewrite data in the existing directory.
            auto_create_subdir: if true a subdirectory will be created automatically
                and its name will be the current date and time
        """
        self.run_dir = Path(run_dir)
        if auto_create_subdir and not force:
            time_str = datetime.now().isoformat().split('.')[0]
            self.run_dir = self.run_dir / f"{time_str}"
        self.force = force
        self.git_tags = get_git_tags()
        self.lib_versions = get_lib_versions()

        # Create run directory
        logger.info(f"Creating run directory {self.run_dir}...")
        create_run_dir(self.run_dir, force=self.force)
        dump_json(self.run_dir / "lib_versions.json", self.lib_versions)

        self.mlflow_dir = str(ENTRY_DIR / "mlruns")
        mlflow.set_tracking_uri(self.mlflow_dir)
        self.mlflow_client = MlflowClient(tracking_uri=self.mlflow_dir)

    def _terminate_run(self, run_entity, failed: bool = False) -> None:
        """Mark an MLflow run as FINISHED, or FAILED when `failed` is true.

        Runs created with `MlflowClient.create_run` stay RUNNING until they are
        explicitly terminated with `set_terminated`.
        """
        status = "FAILED" if failed else "FINISHED"
        self.mlflow_client.set_terminated(run_entity.info.run_id, status=status)

    def evaluate(
        self,
        config_path: Optional[str] = None,
        config: Optional[Dict] = None,
        additional_modules: Optional[List[str]] = None,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
    ) -> None:
        """
        Evaluates task defined by configuration file.

        Args:
            config_path: path to a configuration file.
            config: configuration of a task.
            additional_modules: path to directories with modules that should be registered in global Registry.
            experiment_name: results of the run will be added to 'experiment_name' experiment in MLflow.
            run_name: this run will be marked as 'run_name' in MLflow.
        """
        # Import additional modules first: they register classes in the global
        # Registry, which building objects from the config presumably relies on.
        self.import_additional_modules(additional_modules)

        # Instantiate objects from config
        task, config = build_from_config_path(config_path, config)

        # Create experiment with given name or get already existing
        if experiment_name is None:
            experiment_name = config["class_name"]
        experiment_id = get_experiment_id(self.mlflow_client, experiment_name)

        # Add Run name in MLFlow tags
        tags = copy(self.git_tags)
        if config_path is not None and run_name is None:
            run_name = Path(config_path).stem
        if run_name is not None:
            tags[MLFLOW_RUN_NAME] = run_name

        # Create Run entity for tracking
        run_entity = self.mlflow_client.create_run(experiment_id=experiment_id,
                                                   tags=tags)
        try:
            saved_params = dict()
            generator_params = config["substitute_generator"]
            log_params(self.mlflow_client, run_entity, generator_params,
                       saved_params)

            dump_json(self.run_dir / "config.json", config)

            logger.info("Evaluating...")
            metrics = task.evaluate(run_dir=self.run_dir)
            metrics = metrics["mean_metrics"]
            log_metrics(self.mlflow_client, run_entity, metrics)
            # `run_id` replaces the deprecated `run_uuid` alias.
            self.mlflow_client.log_artifacts(run_entity.info.run_id,
                                             local_dir=self.run_dir)
        except BaseException:
            # Don't leave a crashed evaluation showing as RUNNING in MLflow.
            self._terminate_run(run_entity, failed=True)
            raise
        self._terminate_run(run_entity)
        logger.info("Evaluation performed.")

    def hyperparam_search(self, config_path: str,
                          experiment_name: str) -> None:
        """
        Run hyperparameters enumeration defined by configuration file.

        Each grid point is tracked as its own MLflow run.

        Args:
            config_path: path to a configuration file.
            experiment_name: results of the run will be added to 'experiment_name' experiment in MLflow.
        """
        config = read_config(config_path, verbose=True)

        run_name = Path(config_path).stem

        dump_json(self.run_dir / "config.json", config)

        # Set MLFlow settings
        experiment_id = get_experiment_id(self.mlflow_client, experiment_name)

        parameter_grid = Grid(config)
        for run_idx, (grid_dot,
                      param_config) in tqdm(enumerate(parameter_grid)):
            tags = copy(self.git_tags)
            tags[MLFLOW_RUN_NAME] = f"{run_name}_run_{run_idx}"
            run_entity = self.mlflow_client.create_run(
                experiment_id=experiment_id, tags=tags)
            try:
                params_to_log = dict(zip(parameter_grid.param_names, grid_dot))
                log_params(self.mlflow_client, run_entity, params_to_log,
                           dict())
                task = build_from_params(param_config)
                metrics = task.evaluate(run_dir=self.run_dir)
                metrics = metrics.get("mean_metrics", None)
                log_metrics(self.mlflow_client, run_entity, metrics)
            except BaseException:
                self._terminate_run(run_entity, failed=True)
                raise
            self._terminate_run(run_entity)

    def augment(self,
                dataset_name: str,
                config_path: Optional[str] = None,
                config: Optional[Dict] = None) -> None:
        """
        Performs dataset augmentation.

        Args:
            dataset_name: name of the dataset to augment
            config_path: path to a configuration file.
            config: configuration of a task
        """
        augmenter, config = build_from_config_path(config_path, config)

        dump_json(self.run_dir / "config.json", config)

        logger.info(f"Augmenting {dataset_name}...")
        augmented_dataset = augmenter.augment_dataset(
            dataset_name=dataset_name)
        augmented_dataset.to_csv(self.run_dir / "augmented_dataset.tsv",
                                 sep="\t",
                                 index=False)
        logger.info(
            f"Augmentation performed. Results was saved in {self.run_dir}")
コード例 #25
0
ファイル: mlflow.py プロジェクト: yulkang/pytorch-lightning
class MLFlowLogger(LightningLoggerBase):
    """
    Log using `MLflow <https://mlflow.org>`_. Install it with pip:

    .. code-block:: bash

        pip install mlflow

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.loggers import MLFlowLogger
        >>> mlf_logger = MLFlowLogger(
        ...     experiment_name="default",
        ...     tracking_uri="file:./ml-runs"
        ... )
        >>> trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows:

    >>> from pytorch_lightning import LightningModule
    >>> class LitModel(LightningModule):
    ...     def training_step(self, batch, batch_idx):
    ...         # example
    ...         self.logger.experiment.whatever_ml_flow_supports(...)
    ...
    ...     def any_lightning_module_function_or_hook(self):
    ...         self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment
        tracking_uri: Address of local or remote tracking server.
            If not provided and ``save_dir`` is given, ``file:<save_dir>`` is used;
            otherwise defaults to the service set by ``mlflow.tracking.set_tracking_uri``.
        tags: A dictionary of tags attached to the MLflow run created by this logger.
        save_dir: Local directory for MLflow runs; only used when ``tracking_uri``
            is not provided.

    """
    def __init__(self,
                 experiment_name: str = 'default',
                 tracking_uri: Optional[str] = None,
                 tags: Optional[Dict[str, Any]] = None,
                 save_dir: Optional[str] = None):
        super().__init__()
        if not tracking_uri and save_dir:
            # `file:<path>` is a valid local file URI for relative and absolute paths.
            # The former f"file:{os.sep * 2}{save_dir}" produced e.g. "file://./mlruns",
            # where "." is parsed as the URI host, and backslashes on Windows.
            tracking_uri = f'file:{save_dir}'
        self._mlflow_client = MlflowClient(tracking_uri)
        self.experiment_name = experiment_name
        # Both ids are resolved lazily by `run_id`.
        self._expt_id = None
        self._run_id = None
        self.tags = tags

    @property
    def experiment(self) -> MlflowClient:
        r"""
        Actual MLflow object. To use mlflow features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_mlflow_function()

        """
        return self._mlflow_client

    @property
    def run_id(self):
        """Id of this logger's run; the experiment and run are created on first access."""
        if self._run_id is not None:
            return self._run_id

        expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)

        if expt:
            self._expt_id = expt.experiment_id
        else:
            log.warning(
                f'Experiment with name {self.experiment_name} not found. Creating it.'
            )
            self._expt_id = self._mlflow_client.create_experiment(
                name=self.experiment_name)

        run = self._mlflow_client.create_run(experiment_id=self._expt_id,
                                             tags=self.tags)
        self._run_id = run.info.run_id
        return self._run_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        """Log (flattened) hyperparameters as MLflow run params."""
        params = self._convert_params(params)
        params = self._flatten_dict(params)
        for k, v in params.items():
            self.experiment.log_param(self.run_id, k, v)

    @rank_zero_only
    def log_metrics(self,
                    metrics: Dict[str, float],
                    step: Optional[int] = None) -> None:
        """Log metrics to the run; string values are skipped with a warning."""
        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            if isinstance(v, str):
                log.warning(f'Discarding metric with string value {k}={v}.')
                continue
            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

    @rank_zero_only
    def finalize(self, status: str = 'FINISHED') -> None:
        """Terminate the run; a "success" status is mapped to "FINISHED"."""
        super().finalize(status)
        if status == 'success':
            status = 'FINISHED'
        self.experiment.set_terminated(self.run_id, status)

    @property
    def name(self) -> str:
        """The experiment name."""
        return self.experiment_name

    @property
    def version(self) -> str:
        """The run id, or None if no run has been created yet."""
        return self._run_id
コード例 #26
0
ファイル: log_artifacts.py プロジェクト: dmatrix/mlflow-tests
import json
import os

from mlflow.tracking import MlflowClient

if __name__ == "__main__":

    # Create some artifacts data to preserve
    features = "rooms, zipcode, median_price, school_rating, transport"
    data = {"state": "TX", "Available": 25, "Type": "Detached"}

    # Create couple of artifact files under the directory "data"
    os.makedirs("data", exist_ok=True)
    with open("data/data.json", 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    with open("data/features.txt", 'w') as f:
        f.write(features)

    # Create a run under the default experiment (whose id is "0"), and log
    # all files in "data" to root artifact_uri/states
    client = MlflowClient()
    experiment_id = "0"
    run = client.create_run(experiment_id)
    client.log_artifacts(run.info.run_id, "data", artifact_path="states")
    artifacts = client.list_artifacts(run.info.run_id)
    for artifact in artifacts:
        print("artifact: {}".format(artifact.path))
        print("is_dir: {}".format(artifact.is_dir))
    client.set_terminated(run.info.run_id)


コード例 #27
0
from mlflow.tracking import MlflowClient
import mlflow
import os

remote_server_uri = "http://10.77.36.45:5000/"  # set to your server URI
# Configure the tracking URI *before* creating the client, and pass it to the
# client explicitly. Otherwise the client (and the experiment listing) would
# use the default tracking store instead of this server.
mlflow.set_tracking_uri(remote_server_uri)
client = MlflowClient(tracking_uri=remote_server_uri)
experiments = client.list_experiments(
)  # returns a list of mlflow.entities.Experiment

run = client.create_run(
    experiments[0].experiment_id)  # returns mlflow.entities.Run
client.log_param(run.info.run_id, "hello", "world")
# Tag the run before marking it terminated.
client.set_tag(run.info.run_id, "36", "new module")
client.set_terminated(run.info.run_id)
コード例 #28
0
ファイル: delete_tag.py プロジェクト: dmatrix/mlflow-tests
import warnings

import mlflow
from mlflow.tracking import MlflowClient

if __name__ == "__main__":

    warnings.filterwarnings("ignore")
    print(mlflow.__version__)

    def show_run(current):
        """Print a run's id followed by its tags."""
        print("run_id: {}".format(current.info.run_id))
        print("Tags: {}".format(current.data.tags))

    # Start from a run carrying two tags, created under the default
    # experiment (whose id is "0").
    client = MlflowClient()
    created = client.create_run("0", tags={"t1": 1, "t2": 2})
    show_run(created)
    print("--")

    # Remove tag "t1", then re-read the run so the printed tags reflect it.
    run_id = created.info.run_id
    client.delete_tag(run_id, "t1")
    show_run(client.get_run(run_id))
コード例 #29
0
client = MlflowClient()  # client
exps = client.list_experiments()  # get all experiments

# COMMAND ----------

exps

# COMMAND ----------

# Build the target experiment path once (it does not depend on the loop
# variable), then take the first experiment whose name contains it.
exp_path = "/Users/{}/exps/MLFlowExp".format(
    spark.conf.get("com.databricks.demo.username"))
exp = next((s for s in exps if exp_path in s.name),
           None)  # get only the exp we want
if exp is None:
    raise ValueError(
        "No MLflow experiment whose name contains {!r}".format(exp_path))
exp_id = exp.experiment_id  # save exp id to variable
artifact_location = exp.artifact_location  # artifact location for storing
run = client.create_run(exp_id)  # create the run
run_id = run.info.run_id  # get the run id

# COMMAND ----------

# start an mlflow run (resumes the run created above)
mlflow.start_run(run_id)

# COMMAND ----------

df = (spark.read.format("csv")
      .option("inferSchema", "True")
      .option("header", "True")
      .load("/databricks-datasets/bikeSharing/data-001/day.csv"))
# split data
train_df, test_df = df.randomSplit([0.7, 0.3])
コード例 #30
0
class MLFlowLogger(LightningLoggerBase):
    """
    Log using `MLflow <https://mlflow.org>`_.

    Install it with pip:

    .. code-block:: bash

        pip install mlflow

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.loggers import MLFlowLogger
        mlf_logger = MLFlowLogger(
            experiment_name="default",
            tracking_uri="file:./ml-runs"
        )
        trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows:

    .. code-block:: python

        from pytorch_lightning import LightningModule
        class LitModel(LightningModule):
            def training_step(self, batch, batch_idx):
                # example
                self.logger.experiment.whatever_ml_flow_supports(...)

            def any_lightning_module_function_or_hook(self):
                self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment. It is looked up on first
            use and created if it does not exist yet.
        tracking_uri: Address of local or remote tracking server.
            If not provided, defaults to `file:<save_dir>`.
        tags: A dictionary of tags attached to the MLflow run created by this logger.
        save_dir: A path to a local directory where the MLflow runs get saved.
            Defaults to `./mlruns` if `tracking_uri` is not provided.
            Has no effect if `tracking_uri` is provided.
        prefix: A string to put at the beginning of metric keys.

    """

    LOGGER_JOIN_CHAR = '-'

    def __init__(
        self,
        experiment_name: str = 'default',
        tracking_uri: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = './mlruns',
        prefix: str = '',
    ):
        if mlflow is None:
            raise ImportError(
                'You want to use `mlflow` logger which is not installed yet,'
                ' install it with `pip install mlflow`.')
        super().__init__()
        # Fall back to a local file store rooted at `save_dir`.
        if not tracking_uri:
            tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}'

        self._experiment_name = experiment_name
        # Experiment and run IDs are resolved lazily by the `experiment` property.
        self._experiment_id = None
        self._tracking_uri = tracking_uri
        self._run_id = None
        self.tags = tags
        self._prefix = prefix
        self._mlflow_client = MlflowClient(tracking_uri)

    @property
    @rank_zero_experiment
    def experiment(self) -> MlflowClient:
        r"""
        Actual MLflow object. To use MLflow features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        On first access this looks up the experiment by name (creating it if
        missing) and creates a run tagged with ``self.tags``.

        Example::

            self.logger.experiment.some_mlflow_function()

        """
        if self._experiment_id is None:
            expt = self._mlflow_client.get_experiment_by_name(
                self._experiment_name)
            if expt is not None:
                self._experiment_id = expt.experiment_id
            else:
                log.warning(
                    f'Experiment with name {self._experiment_name} not found. Creating it.'
                )
                self._experiment_id = self._mlflow_client.create_experiment(
                    name=self._experiment_name)

        if self._run_id is None:
            run = self._mlflow_client.create_run(
                experiment_id=self._experiment_id, tags=self.tags)
            self._run_id = run.info.run_id
        return self._mlflow_client

    @property
    def run_id(self):
        """ID of the MLflow run; creates the experiment/run if needed."""
        # create the experiment if it does not exist to get the run id
        _ = self.experiment
        return self._run_id

    @property
    def experiment_id(self):
        """ID of the MLflow experiment; creates the experiment/run if needed."""
        # create the experiment if it does not exist to get the experiment id
        _ = self.experiment
        return self._experiment_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any],
                                            Namespace]) -> None:
        """Log hyperparameters as MLflow params, skipping values over 250 chars."""
        params = self._convert_params(params)
        params = self._flatten_dict(params)
        for k, v in params.items():
            if len(str(v)) > 250:
                rank_zero_warn(
                    f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}",
                    RuntimeWarning)
                continue

            self.experiment.log_param(self.run_id, k, v)

    @rank_zero_only
    def log_metrics(self,
                    metrics: Dict[str, float],
                    step: Optional[int] = None) -> None:
        """Log numeric metrics, dropping string values and sanitizing key names."""
        assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

        metrics = self._add_prefix(metrics)

        # One timestamp shared by every metric in this call.
        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            if isinstance(v, str):
                log.warning(f'Discarding metric with string value {k}={v}.')
                continue

            # Remove characters outside the set MLflow accepts in metric names.
            new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
            if k != new_k:
                rank_zero_warn(
                    "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
                    f" Replacing {k} with {new_k}.", RuntimeWarning)
                k = new_k

            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

    @rank_zero_only
    def finalize(self, status: str = 'FINISHED') -> None:
        """Terminate the MLflow run.

        Lightning-style lowercase statuses are translated to MLflow run status
        names: ``'success'``/``'finished'`` -> ``'FINISHED'`` and
        ``'failed'`` -> ``'FAILED'``. Any other value is passed through unchanged.
        """
        super().finalize(status)
        # MLflow expects upper-case RunStatus names; lowercase strings would be rejected.
        status = {
            'success': 'FINISHED',
            'finished': 'FINISHED',
            'failed': 'FAILED',
        }.get(status, status)
        # Note: accessing `self.experiment` creates the experiment/run if needed.
        if self.experiment.get_run(self.run_id):
            self.experiment.set_terminated(self.run_id, status)

    @property
    def save_dir(self) -> Optional[str]:
        """
        The root file directory in which MLflow experiments are saved.

        Return:
            Local path to the root experiment directory if the tracking uri is local.
            Otherwise returns `None`.
        """
        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
            # Slice the prefix off: `str.lstrip` would strip a *set of characters*
            # and could eat the start of the path (e.g. "file:logs" -> "ogs").
            return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
        return None

    @property
    def name(self) -> str:
        """Logger name: the MLflow experiment ID."""
        return self.experiment_id

    @property
    def version(self) -> str:
        """Logger version: the MLflow run ID."""
        return self.run_id