Example #1
    def _update(self, msg, method="PATCH", update_mask=None):
        self._refresh_cache()  # to have `self._msg.registered_model_id` for URL
        if update_mask:
            url = "{}://{}/api/v1/registry/registered_models/{}/model_versions/{}/full_body".format(
                self._conn.scheme,
                self._conn.socket,
                self._msg.registered_model_id,
                self.id,
            )
            # proto converter for update_mask is missing
            data = {
                "model_version": _utils.proto_to_json(msg, False),
                "update_mask": update_mask,
            }
            response = _utils.make_request(method, url, self._conn, json=data)
        else:
            response = self._conn.make_proto_request(
                method,
                "/api/v1/registry/registered_models/{}/model_versions/{}".format(
                    self._msg.registered_model_id, self.id
                ),
                body=msg,
                include_default=False,
            )
        self._conn.must_proto_response(
            response, _RegistryService.SetModelVersion.Response
        )
        self._clear_cache()
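
The `update_mask` branch above follows protobuf's field-mask convention for partial updates: the client names exactly the fields the backend should touch. A minimal hedged sketch of that idea, using only the standard `google.protobuf` FieldMask type (the real registry protos are generated elsewhere in verta, and the field names here are illustrative):

# Hedged sketch of the field-mask idea behind `update_mask`; the paths are
# illustrative, not actual registry fields.
from google.protobuf import field_mask_pb2

mask = field_mask_pb2.FieldMask(paths=["description", "labels"])
# Sent alongside the message body, the mask tells the server to update only
# the listed fields and leave everything else unchanged.
print(list(mask.paths))  # ['description', 'labels']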
Example #2
    def _get_url_for_artifact(self, dataset_component_path, method, part_num=0):
        """
        Obtains a URL to use for accessing stored artifacts.

        Parameters
        ----------
        dataset_component_path : str
            Filepath in dataset component blob.
        method : {'GET', 'PUT'}
            HTTP method to request for the generated URL.
        part_num : int, optional
            If using Multipart Upload, number of part to be uploaded.

        Returns
        -------
        response_msg : `_DatasetVersionService.GetUrlForDatasetBlobVersioned.Response`
            Backend response.

        """
        if method.upper() not in ("GET", "PUT"):
            raise ValueError("`method` must be one of {'GET', 'PUT'}")

        Message = _DatasetVersionService.GetUrlForDatasetBlobVersioned
        msg = Message(
            path_dataset_component_blob_path=dataset_component_path,
            method=method,
            part_number=part_num,
        )
        data = _utils.proto_to_json(msg)
        endpoint = "{}://{}/api/v1/modeldb/dataset-version/dataset/{}/datasetVersion/{}/getUrlForDatasetBlobVersioned".format(
            self._conn.scheme,
            self._conn.socket,
            self.dataset_id,
            self.id,
        )
        response = _utils.make_request("POST", endpoint, self._conn, json=data)
        _utils.raise_for_http_error(response)

        response_msg = _utils.json_to_proto(response.json(), Message.Response)

        url = response_msg.url
        # accommodate port-forwarded NFS store
        if 'https://localhost' in url[:20]:
            url = 'http' + url[5:]
        if 'localhost%3a' in url[:20]:
            url = url.replace('localhost%3a', 'localhost:')
        if 'localhost%3A' in url[:20]:
            url = url.replace('localhost%3A', 'localhost:')
        response_msg.url = url

        return response_msg
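
A hedged usage sketch of the method above: request a presigned GET URL for a dataset component, then download it with plain `requests`. `dataset_version` is a hypothetical stand-in for an object exposing `_get_url_for_artifact`:

# Hypothetical usage; `dataset_version` is assumed to expose the method above.
import requests

url = dataset_version._get_url_for_artifact("data/train.csv", "GET").url
download = requests.get(url)
download.raise_for_status()
contents = download.content  # raw bytes of the stored component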
Example #3
    def _get_url_for_artifact(self, key, method, artifact_type=0, part_num=0):
        if method.upper() not in ("GET", "PUT"):
            raise ValueError("`method` must be one of {'GET', 'PUT'}")

        Message = _RegistryService.GetUrlForArtifact
        msg = Message(
            model_version_id=self.id,
            key=key,
            method=method,
            artifact_type=artifact_type,
            part_number=part_num,
        )
        data = _utils.proto_to_json(msg)
        endpoint = "{}://{}/api/v1/registry/model_versions/{}/getUrlForArtifact".format(
            self._conn.scheme, self._conn.socket, self.id)
        response = _utils.make_request("POST", endpoint, self._conn, json=data)
        _utils.raise_for_http_error(response)
        return _utils.json_to_proto(response.json(), Message.Response)
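
Both variants validate `method` case-insensitively via `.upper()` and thread `part_num` through to the backend, so the same endpoint can presign individual parts of a multipart upload. A hedged sketch, with `model_version` standing in for an object exposing the method above:

# Hypothetical usage: presign each part of a multipart PUT in turn.
for part_num in (1, 2, 3):
    url = model_version._get_url_for_artifact("model.pkl", "PUT", part_num=part_num).url
    # each presigned `url` is only valid for uploading that specific part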
Example #4
    def get_code(self):
        """
        Gets the code version.

        Returns
        -------
        dict or zipfile.ZipFile
            Either:
                - a dictionary containing Git snapshot information with at most the following items:
                    - **filepaths** (*list of str*)
                    - **repo_url** (*str*) – Remote repository URL
                    - **commit_hash** (*str*) – Commit hash
                    - **is_dirty** (*bool*)
                - a `ZipFile <https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile>`_
                  containing Python source code files

        """
        # TODO: remove this circular dependency
        from ._project import Project
        from ._experiment import Experiment
        from ._experimentrun import ExperimentRun
        if isinstance(self, Project):  # TODO: not this
            Message = self._service.GetProjectCodeVersion
            endpoint = "getProjectCodeVersion"
        elif isinstance(self, Experiment):
            Message = self._service.GetExperimentCodeVersion
            endpoint = "getExperimentCodeVersion"
        elif isinstance(self, ExperimentRun):
            Message = self._service.GetExperimentRunCodeVersion
            endpoint = "getExperimentRunCodeVersion"
        msg = Message(id=self.id)
        data = _utils.proto_to_json(msg)
        response = _utils.make_request("GET",
                                       self._request_url.format(endpoint),
                                       self._conn,
                                       params=data)
        _utils.raise_for_http_error(response)

        response_msg = _utils.json_to_proto(_utils.body_to_json(response),
                                            Message.Response)
        code_ver_msg = response_msg.code_version
        which_code = code_ver_msg.WhichOneof('code')
        if which_code == 'git_snapshot':
            git_snapshot_msg = code_ver_msg.git_snapshot
            git_snapshot = {}
            if git_snapshot_msg.filepaths:
                git_snapshot['filepaths'] = git_snapshot_msg.filepaths
            if git_snapshot_msg.repo:
                git_snapshot['repo_url'] = git_snapshot_msg.repo
            if git_snapshot_msg.hash:
                git_snapshot['commit_hash'] = git_snapshot_msg.hash
            if git_snapshot_msg.is_dirty != _CommonCommonService.TernaryEnum.UNKNOWN:
                git_snapshot['is_dirty'] = (
                    git_snapshot_msg.is_dirty == _CommonCommonService.TernaryEnum.TRUE
                )
            return git_snapshot
        elif which_code == 'code_archive':
            # download artifact from artifact store
            # pylint: disable=no-member
            # this method should only be called on ExperimentRun, which does have _get_url_for_artifact()
            url = self._get_url_for_artifact(
                "verta_code_archive", "GET",
                code_ver_msg.code_archive.artifact_type).url

            response = _utils.make_request("GET", url, self._conn)
            _utils.raise_for_http_error(response)

            code_archive = six.BytesIO(response.content)
            return zipfile.ZipFile(code_archive, 'r')  # TODO: return a util class instead, maybe
        else:
            raise RuntimeError("unable to find code in response")
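
Because `get_code` returns either a Git-snapshot dict or a `ZipFile`, callers typically branch on the type. A hedged usage sketch, where `run` is a hypothetical ExperimentRun:

# Hypothetical usage; `run` is assumed to be an ExperimentRun.
import zipfile

code = run.get_code()
if isinstance(code, zipfile.ZipFile):
    code.printdir()                  # bundled source files
else:
    print(code.get('commit_hash'))   # Git snapshot info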
Example #5
    def log_code(self,
                 exec_path=None,
                 repo_url=None,
                 commit_hash=None,
                 overwrite=False,
                 is_dirty=None,
                 autocapture=True):
        """
        Logs the code version.

        A code version is either information about a Git snapshot or a bundle of Python source code files.

        `repo_url` and `commit_hash` can only be set if `use_git` was set to ``True`` in the Client.

        Parameters
        ----------
        exec_path : str, optional
            Filepath to the executable Python script or Jupyter notebook. If no filepath is provided,
            the Client will make its best effort to find the currently running script/notebook file.
        repo_url : str, optional
            URL for a remote Git repository containing `commit_hash`. If no URL is provided, the Client
            will make its best effort to find it.
        commit_hash : str, optional
            Git commit hash associated with this code version. If no hash is provided, the Client will
            make its best effort to find it.
        overwrite : bool, default False
            Whether to allow overwriting a code version.
        is_dirty : bool, optional
            Whether git status is dirty relative to `commit_hash`. If not provided, the Client will
            make its best effort to find it.
        autocapture : bool, default True
            Whether to enable the automatic capturing behavior of parameters above in git mode.

        Examples
        --------
        With ``Client(use_git=True)`` (default):

            Log Git snapshot information, plus the location of the currently executing notebook/script
            relative to the repository root:

            .. code-block:: python

                run.log_code()
                run.get_code()
                # {'exec_path': 'comparison/outcomes/classification.ipynb',
                #  'repo_url': 'git@github.com:VertaAI/experiments.git',
                #  'commit_hash': 'f99abcfae6c3ce6d22597f95ad6ef260d31527a6',
                #  'is_dirty': False}

            Log Git snapshot information, plus the location of a specific source code file relative
            to the repository root:

            .. code-block:: python

                run.log_code("../trainer/training_pipeline.py")
                run.get_code()
                # {'exec_path': 'comparison/trainer/training_pipeline.py',
                #  'repo_url': 'git@github.com:VertaAI/experiments.git',
                #  'commit_hash': 'f99abcfae6c3ce6d22597f95ad6ef260d31527a6',
                #  'is_dirty': False}

        With ``Client(use_git=False)``:

            Find and upload the currently executing notebook/script:

            .. code-block:: python

                run.log_code()
                zip_file = run.get_code()
                zip_file.printdir()
                # File Name                          Modified             Size
                # classification.ipynb        2019-07-10 17:18:24        10287

            Upload a specific source code file:

            .. code-block:: python

                run.log_code("../trainer/training_pipeline.py")
                zip_file = run.get_code()
                zip_file.printdir()
                # File Name                          Modified             Size
                # training_pipeline.py        2019-05-31 10:34:44          964

        """
        if self._conf.use_git and autocapture:
            # verify Git
            try:
                repo_root_dir = _git_utils.get_git_repo_root_dir()
            except OSError:
                # don't halt execution
                print(
                    "unable to locate git repository; you may be in an unsupported environment;\n"
                    "    consider using `autocapture=False` to manually enter values"
                )
                return
                # six.raise_from(OSError("failed to locate git repository; please check your working directory"),
                #                None)
            print("Git repository successfully located at {}".format(
                repo_root_dir))
        if not self._conf.use_git:
            if repo_url is not None or commit_hash is not None:
                raise ValueError(
                    "`repo_url` and `commit_hash` can only be set if `use_git` was set to True in the Client"
                )
            if not autocapture:  # user passed `False`
                raise ValueError(
                    "`autocapture` is only applicable if `use_git` was set to True in the Client"
                )

        if autocapture:
            if exec_path is None:
                # find dynamically
                try:
                    exec_path = _utils.get_notebook_filepath()
                except (ImportError, OSError):  # notebook not found
                    try:
                        exec_path = _utils.get_script_filepath()
                    except OSError:  # script not found
                        print("unable to find code file; skipping")
            else:
                exec_path = os.path.expanduser(exec_path)
                if not os.path.isfile(exec_path):
                    raise ValueError(
                        "`exec_path` \"{}\" must be a valid filepath".format(
                            exec_path))

        # TODO: remove this circular dependency
        from ._project import Project
        from ._experiment import Experiment
        from ._experimentrun import ExperimentRun
        if isinstance(self, Project):  # TODO: not this
            Message = self._service.LogProjectCodeVersion
            endpoint = "logProjectCodeVersion"
        elif isinstance(self, Experiment):
            Message = self._service.LogExperimentCodeVersion
            endpoint = "logExperimentCodeVersion"
        elif isinstance(self, ExperimentRun):
            Message = self._service.LogExperimentRunCodeVersion
            endpoint = "logExperimentRunCodeVersion"
        msg = Message(id=self.id)

        if overwrite:
            if isinstance(self, ExperimentRun):
                msg.overwrite = True
            else:
                raise ValueError(
                    "`overwrite=True` is currently only supported for ExperimentRun"
                )

        if self._conf.use_git:
            if autocapture:
                try:
                    # adjust `exec_path` to be relative to repo root
                    exec_path = os.path.relpath(
                        exec_path, _git_utils.get_git_repo_root_dir())
                except OSError as e:
                    print(
                        "{}; logging absolute path to file instead".format(e))
                    exec_path = os.path.abspath(exec_path)
            if exec_path:
                msg.code_version.git_snapshot.filepaths.append(exec_path)

            from verta.code import _git  # avoid Python 2 top-level circular import
            code_ver = _git.Git(
                repo_url=repo_url,
                commit_hash=commit_hash,
                is_dirty=is_dirty,
                autocapture=autocapture,
            )
            msg.code_version.git_snapshot.repo = code_ver.repo_url or ""
            msg.code_version.git_snapshot.hash = code_ver.commit_hash or ""
            if not autocapture and is_dirty is None:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.UNKNOWN
            elif code_ver.is_dirty:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.TRUE
            else:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.FALSE
        else:  # log code as Artifact
            if exec_path is None:
                # don't halt execution
                print(
                    "unable to find code file; you may be in an unsupported environment"
                )
                return
                # raise RuntimeError("unable to find code file; you may be in an unsupported environment")

            # write ZIP archive
            zipstream = six.BytesIO()
            with zipfile.ZipFile(zipstream, 'w') as zipf:
                filename = os.path.basename(exec_path)
                if exec_path.endswith(".ipynb"):
                    try:
                        saved_notebook = _utils.save_notebook(exec_path)
                    except Exception:  # failed to save
                        print("unable to automatically save Notebook;"
                              " logging latest checkpoint from disk")
                        zipf.write(exec_path, filename)
                    else:
                        zipf.writestr(
                            _artifact_utils.global_read_zipinfo(filename),
                            six.ensure_binary(saved_notebook.read()),
                        )
                else:
                    zipf.write(exec_path, filename)
            zipstream.seek(0)

            key = 'code'
            extension = 'zip'

            artifact_hash = _artifact_utils.calc_sha256(zipstream)
            basename = key + os.extsep + extension
            artifact_path = os.path.join(artifact_hash, basename)

            msg.code_version.code_archive.path = artifact_path
            msg.code_version.code_archive.path_only = False
            msg.code_version.code_archive.artifact_type = _CommonCommonService.ArtifactTypeEnum.CODE
            msg.code_version.code_archive.filename_extension = extension
        # TODO: check if we actually have any loggable information
        msg.code_version.date_logged = _utils.now()

        data = _utils.proto_to_json(msg)
        response = _utils.make_request("POST",
                                       self._request_url.format(endpoint),
                                       self._conn,
                                       json=data)
        if not response.ok:
            if response.status_code == 409:
                raise ValueError("a code version has already been logged")
            else:
                _utils.raise_for_http_error(response)

        if msg.code_version.WhichOneof("code") == 'code_archive':
            # upload artifact to artifact store
            # pylint: disable=no-member
            # this method should only be called on ExperimentRun, which does have _get_url_for_artifact()
            url = self._get_url_for_artifact(
                "verta_code_archive", "PUT",
                msg.code_version.code_archive.artifact_type).url

            response = _utils.make_request("PUT",
                                           url,
                                           self._conn,
                                           data=zipstream)
            _utils.raise_for_http_error(response)
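
The artifact branch above bundles the code file into an in-memory ZIP and derives a content-addressed path from its hash. A standalone sketch of just that packaging step, using only the standard library (verta's `_artifact_utils.calc_sha256` is assumed to compute the equivalent digest):

# Minimal sketch of the "log code as Artifact" packaging step.
import hashlib
import io
import os
import zipfile

exec_path = __file__  # log_code discovers this dynamically when not given
zipstream = io.BytesIO()
with zipfile.ZipFile(zipstream, 'w') as zipf:
    zipf.write(exec_path, os.path.basename(exec_path))
zipstream.seek(0)

artifact_hash = hashlib.sha256(zipstream.read()).hexdigest()
artifact_path = os.path.join(artifact_hash, 'code' + os.extsep + 'zip')
zipstream.seek(0)  # rewind before uploading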
Example #6
    def _upload_artifact(self,
                         key,
                         file_handle,
                         artifact_type,
                         part_size=_artifact_utils._64MB):
        file_handle.seek(0)

        # check if multipart upload ok
        url_for_artifact = self._get_url_for_artifact(key,
                                                      "PUT",
                                                      artifact_type,
                                                      part_num=1)

        print("uploading {} to Registry".format(key))
        if url_for_artifact.multipart_upload_ok:
            # TODO: parallelize this
            file_parts = iter(lambda: file_handle.read(part_size), b"")
            for part_num, file_part in enumerate(file_parts, start=1):
                print("uploading part {}".format(part_num), end="\r")

                # get presigned URL
                url = self._get_url_for_artifact(key,
                                                 "PUT",
                                                 artifact_type,
                                                 part_num=part_num).url

                # wrap file part into bytestream to avoid OverflowError
                #     Passing a bytestring >2 GB (num bytes > max val of int32) directly to
                #     ``requests`` will overwhelm CPython's SSL lib when it tries to sign the
                #     payload. But passing a buffered bytestream instead of the raw bytestring
                #     indicates to ``requests`` that it should perform a streaming upload via
                #     HTTP/1.1 chunked transfer encoding and avoid this issue.
                #     https://github.com/psf/requests/issues/2717
                part_stream = six.BytesIO(file_part)

                # upload part
                response = _utils.make_request("PUT",
                                               url,
                                               self._conn,
                                               data=part_stream)
                _utils.raise_for_http_error(response)

                # commit part
                url = "{}://{}/api/v1/registry/model_versions/{}/commitArtifactPart".format(
                    self._conn.scheme, self._conn.socket, self.id)
                msg = _RegistryService.CommitArtifactPart(
                    model_version_id=self.id, key=key)
                msg.artifact_part.part_number = part_num
                msg.artifact_part.etag = response.headers["ETag"]
                data = _utils.proto_to_json(msg)
                response = _utils.make_request("POST",
                                               url,
                                               self._conn,
                                               json=data)
                _utils.raise_for_http_error(response)
            print()

            # complete upload
            url = "{}://{}/api/v1/registry/model_versions/{}/commitMultipartArtifact".format(
                self._conn.scheme, self._conn.socket, self.id)
            msg = _RegistryService.CommitMultipartArtifact(
                model_version_id=self.id, key=key)
            data = _utils.proto_to_json(msg)
            response = _utils.make_request("POST", url, self._conn, json=data)
            _utils.raise_for_http_error(response)
        else:
            # upload full artifact
            if url_for_artifact.fields:
                # if fields were returned by backend, make a POST request and supply them as form fields
                response = _utils.make_request(
                    "POST",
                    url_for_artifact.url,
                    self._conn,
                    # requests uses the `files` parameter for sending multipart/form-data POSTs.
                    #     https://stackoverflow.com/a/12385661/8651995
                    # the file contents must be the final form field
                    #     https://docs.aws.amazon.com/AmazonS3/latest/dev/HTTPPOSTForms.html#HTTPPOSTFormFields
                    files=list(url_for_artifact.fields.items()) +
                    [("file", file_handle)],
                )
            else:
                response = _utils.make_request("PUT",
                                               url_for_artifact.url,
                                               self._conn,
                                               data=file_handle)
            _utils.raise_for_http_error(response)

        print("upload complete")
Example #7
    def _upload_artifact(self,
                         dataset_component_path,
                         file_handle,
                         part_size=_artifact_utils._64MB):
        """
        Uploads `file_handle` to ModelDB artifact store.

        Parameters
        ----------
        dataset_component_path : str
            Filepath in dataset component blob.
        file_handle : file-like
            Artifact to be uploaded.
        part_size : int, default 64 MB
            If using multipart upload, number of bytes to upload per part.

        """
        file_handle.seek(0)

        # check if multipart upload ok
        url_for_artifact = self._get_url_for_artifact(dataset_component_path,
                                                      "PUT",
                                                      part_num=1)

        print("uploading {} to ModelDB".format(dataset_component_path))
        if url_for_artifact.multipart_upload_ok:
            # TODO: parallelize this
            file_parts = iter(lambda: file_handle.read(part_size), b'')
            for part_num, file_part in enumerate(file_parts, start=1):
                print("uploading part {}".format(part_num), end='\r')

                # get presigned URL
                url = self._get_url_for_artifact(dataset_component_path,
                                                 "PUT",
                                                 part_num=part_num).url

                # wrap file part into bytestream to avoid OverflowError
                #     Passing a bytestring >2 GB (num bytes > max val of int32) directly to
                #     ``requests`` will overwhelm CPython's SSL lib when it tries to sign the
                #     payload. But passing a buffered bytestream instead of the raw bytestring
                #     indicates to ``requests`` that it should perform a streaming upload via
                #     HTTP/1.1 chunked transfer encoding and avoid this issue.
                #     https://github.com/psf/requests/issues/2717
                part_stream = six.BytesIO(file_part)

                # upload part
                response = _utils.make_request("PUT",
                                               url,
                                               self._conn,
                                               data=part_stream)
                _utils.raise_for_http_error(response)

                # commit part
                url = "{}://{}/api/v1/modeldb/dataset-version/commitVersionedDatasetBlobArtifactPart".format(
                    self._conn.scheme,
                    self._conn.socket,
                )
                msg = _DatasetVersionService.CommitVersionedDatasetBlobArtifactPart(
                    dataset_version_id=self.id,
                    path_dataset_component_blob_path=dataset_component_path,
                )
                msg.artifact_part.part_number = part_num
                msg.artifact_part.etag = response.headers['ETag']
                data = _utils.proto_to_json(msg)
                response = _utils.make_request("POST",
                                               url,
                                               self._conn,
                                               json=data)
                _utils.raise_for_http_error(response)
            print()

            # complete upload
            url = "{}://{}/api/v1/modeldb/dataset-version/commitMultipartVersionedDatasetBlobArtifact".format(
                self._conn.scheme,
                self._conn.socket,
            )
            msg = _DatasetVersionService.CommitMultipartVersionedDatasetBlobArtifact(
                dataset_version_id=self.id,
                path_dataset_component_blob_path=dataset_component_path,
            )
            data = _utils.proto_to_json(msg)
            response = _utils.make_request("POST", url, self._conn, json=data)
            _utils.raise_for_http_error(response)
        else:
            # upload full artifact
            if url_for_artifact.fields:
                # if fields were returned by backend, make a POST request and supply them as form fields
                response = _utils.make_request(
                    "POST",
                    url_for_artifact.url,
                    self._conn,
                    # requests uses the `files` parameter for sending multipart/form-data POSTs.
                    #     https://stackoverflow.com/a/12385661/8651995
                    # the file contents must be the final form field
                    #     https://docs.aws.amazon.com/AmazonS3/latest/dev/HTTPPOSTForms.html#HTTPPOSTFormFields
                    files=list(url_for_artifact.fields.items()) +
                    [('file', file_handle)],
                )
            else:
                response = _utils.make_request("PUT",
                                               url_for_artifact.url,
                                               self._conn,
                                               data=file_handle)
            _utils.raise_for_http_error(response)

        print("upload complete")