# Example #3
# File: recorder.py  Project: microsoft/qlib
class MLflowRecorder(Recorder):
    """
    Use mlflow to implement a Recorder.

    Due to the fact that mlflow will only log artifacts from a file or directory, we decide to
    use a file manager to help maintain the objects in the project.
    """

    # Format used for every human-readable timestamp stored on the recorder.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
        """
        Args:
            experiment_id: id of the experiment this recorder belongs to (forwarded to ``Recorder``).
            uri: MLflow tracking URI used to build ``self.client``.
            name: recorder name; also the fallback name when ``mlflow_run`` has no
                ``mlflow.runName`` tag.
            mlflow_run: optional ``mlflow.entities.run.Run``. When given, the recorder's name, id,
                status, start/end time and artifact URI are populated from it.
        """
        super(MLflowRecorder, self).__init__(experiment_id, name)
        self._uri = uri
        self._artifact_uri = None
        self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
        # construct from mlflow run
        if mlflow_run is not None:
            assert isinstance(mlflow_run, mlflow.entities.run.Run), "Please input with a MLflow Run object."
            # `.get` so a run without this tag does not raise KeyError; fall back to `name`.
            self.name = mlflow_run.data.tags.get("mlflow.runName", name)
            self.id = mlflow_run.info.run_id
            self.status = mlflow_run.info.status
            self.start_time = self._format_ms_timestamp(mlflow_run.info.start_time)
            self.end_time = self._format_ms_timestamp(mlflow_run.info.end_time)
            # Mirror `start_run`, so a recorder rebuilt from an existing run can locate its artifacts.
            self._artifact_uri = mlflow_run.info.artifact_uri
        # Created by `start_run`; drained and reset by `end_run`.
        self.async_log = None

    @classmethod
    def _format_ms_timestamp(cls, timestamp_ms):
        """Format a millisecond epoch timestamp as local time with ``_TIME_FORMAT``; ``None`` stays ``None``."""
        if timestamp_ms is None:
            return None
        return datetime.fromtimestamp(float(timestamp_ms) / 1000.0).strftime(cls._TIME_FORMAT)

    def __repr__(self):
        name = self.__class__.__name__
        space_length = len(name) + 1
        return "{name}(info={info},\n{space}uri={uri},\n{space}artifact_uri={artifact_uri},\n{space}client={client})".format(
            name=name,
            space=" " * space_length,
            info=self.info,
            uri=self.uri,
            artifact_uri=self.artifact_uri,
            client=self.client,
        )

    def __hash__(self) -> int:
        return hash(self.info["id"])

    def __eq__(self, o: object) -> bool:
        if isinstance(o, MLflowRecorder):
            return self.info["id"] == o.info["id"]
        return False

    @property
    def uri(self):
        """The MLflow tracking URI this recorder's client was created with."""
        return self._uri

    @property
    def artifact_uri(self):
        """The run's artifact URI, or ``None`` before the run was started or loaded."""
        return self._artifact_uri

    def get_local_dir(self):
        """
        Return the directory path of this recorder, i.e. the resolved parent of its artifact directory.

        Raises:
            ValueError: if the artifact URI is not known yet (run not created/started).
            RuntimeError: if the resolved path is not an existing local directory.
        """
        artifact_uri = self.get_artifact_uri()
        # Remove a literal "file:" scheme prefix. `str.lstrip("file:")` would instead strip any
        # leading run of the characters f, i, l, e and ':', mangling some paths.
        scheme = "file:"
        if artifact_uri.startswith(scheme):
            artifact_uri = artifact_uri[len(scheme):]
        local_dir_path = str((Path(artifact_uri) / "..").resolve())
        if not os.path.isdir(local_dir_path):
            raise RuntimeError("This recorder is not saved in the local file system.")
        return local_dir_path

    def start_run(self):
        """
        Start this recorder's run through the fluent ``mlflow.start_run`` API.

        The current ``id``, ``experiment_id`` and ``name`` are passed to ``mlflow.start_run``. The
        resulting run id and artifact URI are then stored, along with the start time and RUNNING status.

        Returns:
            the object returned by ``mlflow.start_run``.
        """
        # set the tracking uri
        mlflow.set_tracking_uri(self.uri)
        # start the run
        run = mlflow.start_run(self.id, self.experiment_id, self.name)
        # save the run id and artifact_uri
        self.id = run.info.run_id
        self._artifact_uri = run.info.artifact_uri
        self.start_time = datetime.now().strftime(self._TIME_FORMAT)
        self.status = Recorder.STATUS_R
        # NOTE(review): assumes the module-level `logger` defined next to this file's imports.
        logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")

        # NOTE: making logging async.
        # - This may cause delay when uploading results
        # - The logging time may not be accurate
        # NOTE(review): no logging method in this class visibly dispatches through `async_log`;
        # confirm they are wrapped (e.g. by an AsyncCaller decorator) if async logging is intended.
        self.async_log = AsyncCaller()
        return run

    def end_run(self, status: str = Recorder.STATUS_S):
        """
        Finish this recorder's run.

        Args:
            status: one of the ``Recorder`` status constants. It is passed to ``mlflow.end_run`` and
                stored on the recorder, unless the recorder's current status is ``STATUS_S``.
        """
        assert status in [
            Recorder.STATUS_S,
            Recorder.STATUS_R,
            Recorder.STATUS_FI,
            Recorder.STATUS_FA,
        ], f"The status type {status} is not supported."
        self.end_time = datetime.now().strftime(self._TIME_FORMAT)
        if self.status != Recorder.STATUS_S:
            self.status = status
        if self.async_log is not None:
            # Drain queued logging work before closing the MLflow run, so nothing is logged
            # against an already-terminated run.
            with TimeInspector.logt("waiting `async_log`"):
                self.async_log.wait()
        self.async_log = None
        # Close the run opened by `mlflow.start_run`; otherwise a later `start_run` would find
        # an active run still open.
        mlflow.end_run(status)

    def save_objects(self, local_path=None, artifact_path=None, **kwargs):
        """
        Log objects as artifacts of this run.

        Args:
            local_path: if given, this file (``log_artifact``) or directory (``log_artifacts``) is
                logged and ``kwargs`` are ignored.
            artifact_path: destination directory inside the run's artifact store.
            **kwargs: when ``local_path`` is None, each value is serialized with
                ``Serializable.general_dump`` into a temporary file named after its keyword and
                logged. The temporary directory is removed afterwards, even on failure.
        """
        import shutil

        assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
        if local_path is not None:
            path = Path(local_path)
            if path.is_dir():
                self.client.log_artifacts(self.id, local_path, artifact_path)
            else:
                self.client.log_artifact(self.id, local_path, artifact_path)
        else:
            temp_dir = Path(tempfile.mkdtemp()).resolve()
            try:
                for name, data in kwargs.items():
                    path = temp_dir / name
                    Serializable.general_dump(data, path)
                    self.client.log_artifact(self.id, path, artifact_path)
            finally:
                # The artifacts are already uploaded; the local copies are just scratch files.
                shutil.rmtree(temp_dir, ignore_errors=True)

    def load_object(self, name, unpickler=pickle.Unpickler):
        """
        Load an object, such as a prediction file or model checkpoint, from mlflow.

        Args:
            name (str): the object name

            unpickler: Supporting using custom unpickler

        Raises:
            LoadObjectError: if any exception is raised while downloading or unpickling the object

        Returns:
            object: the saved object in mlflow.
        """
        import shutil

        assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."

        try:
            path = self.client.download_artifacts(self.id, name)
            try:
                with Path(path).open("rb") as f:
                    data = unpickler(f).load()
            finally:
                # For saving disk space (also when unpickling fails).
                # For safety, only remove the redundant downloaded copy for a specific ArtifactRepository.
                ar = self.client._tracking_client._get_artifact_repo(self.id)
                if isinstance(ar, AzureBlobArtifactRepository):
                    # ignore_errors: a failed cleanup must not discard an object that loaded fine.
                    shutil.rmtree(Path(path).absolute().parent, ignore_errors=True)
            return data
        except Exception as e:
            raise LoadObjectError(str(e)) from e

    def log_params(self, **kwargs):
        """Log each keyword argument as a run parameter."""
        for name, data in kwargs.items():
            self.client.log_param(self.id, name, data)

    def log_metrics(self, step=None, **kwargs):
        """Log each keyword argument as a run metric at the given ``step``."""
        for name, data in kwargs.items():
            self.client.log_metric(self.id, name, data, step=step)

    def set_tags(self, **kwargs):
        """Set each keyword argument as a run tag."""
        for name, data in kwargs.items():
            self.client.set_tag(self.id, name, data)

    def delete_tags(self, *keys):
        """Delete the given tag keys from the run."""
        for key in keys:
            self.client.delete_tag(self.id, key)

    def get_artifact_uri(self):
        """
        Return the run's artifact URI.

        Raises:
            ValueError: if the run has not been created/started yet.
        """
        if self.artifact_uri is None:
            raise ValueError(
                "Please make sure the recorder has been created and started properly before getting artifact uri."
            )
        return self.artifact_uri

    def list_artifacts(self, artifact_path=None):
        """Return the paths of the artifacts directly under ``artifact_path`` (the root if None)."""
        assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
        artifacts = self.client.list_artifacts(self.id, artifact_path)
        return [art.path for art in artifacts]

    def list_metrics(self):
        """Return the run's metrics as reported by ``client.get_run``."""
        run = self.client.get_run(self.id)
        return run.data.metrics

    def list_params(self):
        """Return the run's parameters as reported by ``client.get_run``."""
        run = self.client.get_run(self.id)
        return run.data.params

    def list_tags(self):
        """Return the run's tags as reported by ``client.get_run``."""
        run = self.client.get_run(self.id)
        return run.data.tags