Example #1
    def log_setup_script(self, script, overwrite=False):
        """
        Associate a model deployment setup script with this Experiment Run.

        .. versionadded:: 0.13.8

        Parameters
        ----------
        script : str
            String composed of valid Python code for executing setup steps at the beginning of model
            deployment. An on-disk file can be passed in using ``open("path/to/file.py", 'r').read()``.
        overwrite : bool, default False
            Whether to allow overwriting an existing setup script.

        Raises
        ------
        SyntaxError
            If `script` contains invalid Python.

        """
        # validate `script`'s syntax
        try:
            ast.parse(script)
        except SyntaxError as e:
            # clarify that the syntax error comes from `script`, and propagate details
            reason = e.args[0]
            line_no = e.args[1][1]
            line = script.splitlines()[line_no - 1]
            six.raise_from(
                SyntaxError("{} in provided script on line {}:\n{}".format(
                    reason, line_no, line)), e)

        # convert into bytes for upload
        script = six.ensure_binary(script)

        # convert to file-like for `_log_artifact()`
        script = six.BytesIO(script)

        self.log_artifact(
            "setup_script",
            script,
            overwrite,
            "py",
        )
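
A minimal usage sketch for this method, assuming the standard verta Client workflow (the host and entity names below are illustrative):

    from verta import Client

    client = Client("https://app.verta.ai")
    client.set_project("My Project")
    client.set_experiment("My Experiment")
    run = client.set_experiment_run()

    # pass the script contents as a string; syntax is validated before upload
    run.log_setup_script(open("path/to/setup.py", 'r').read())
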
Example #2
    def get_artifact(self, key):
        """
        Gets the artifact with name `key` from this Model Version.

        If the artifact was originally logged as just a filesystem path, that path will be returned.
        Otherwise, the artifact will be deserialized into an object if possible, or else returned
        as a bytestream.

        Parameters
        ----------
        key : str
            Name of the artifact.

        Returns
        -------
        str or object or file-like
            Path of the artifact, the deserialized artifact object, or a bytestream
            representing the artifact.

        """
        artifact = self._get_artifact(
            key, _CommonCommonService.ArtifactTypeEnum.BLOB)
        artifact_stream = six.BytesIO(artifact)

        torch = importer.maybe_dependency("torch")
        if torch is not None:
            try:
                obj = torch.load(artifact_stream)
            except Exception:  # not something torch can deserialize
                artifact_stream.seek(0)
            else:
                artifact_stream.close()
                return obj

        try:
            obj = pickle.load(artifact_stream)
        except Exception:  # not something pickle can deserialize
            artifact_stream.seek(0)
        else:
            artifact_stream.close()
            return obj

        return artifact_stream
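
A short sketch of handling this method's three possible return types, assuming a model version object ``model_version`` and an artifact key ``"weights"`` (both illustrative):

    import six

    artifact = model_version.get_artifact("weights")

    if isinstance(artifact, six.string_types):
        print("artifact was logged as path:", artifact)  # filesystem path
    elif hasattr(artifact, "read"):
        raw_bytes = artifact.read()                      # bytestream fallback
    else:
        obj = artifact                                   # torch- or pickle-deserialized object
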
Example #3
    def _custom_modules_as_artifact(self, paths=None):
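        """
        Zips Python source files into an in-memory ZIP archive for upload as an artifact.

        Parameters
        ----------
        paths : str or list of str, optional
            Paths to files or directories, or names of importable modules, to include.
            If not provided, Python files on a filtered ``sys.path`` are collected
            instead.

        Returns
        -------
        file-like
            Bytestream of the ZIP archive, which mirrors the deepest common directory
            of the collected files and includes a generated ``_verta_config.py``.

        """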
        if isinstance(paths, six.string_types):
            paths = [paths]

        # If we include a path that is actually a module, then we _must_ add its parent to the
        # adjusted sys.path in the end so that we can re-import with the same name.
        forced_local_sys_paths = []
        if paths is not None:
            new_paths = []
            for p in paths:
                abspath = os.path.abspath(os.path.expanduser(p))
                if os.path.exists(abspath):
                    new_paths.append(abspath)
                else:
                    try:
                        mod = importlib.import_module(p)
                        new_paths.extend(mod.__path__)
                        forced_local_sys_paths.extend(map(os.path.dirname, mod.__path__))
                    except ImportError:
                        raise ValueError("custom module {} does not correspond to an existing folder or module".format(p))

            paths = new_paths

        forced_local_sys_paths = sorted(list(set(forced_local_sys_paths)))

        # collect local sys paths
        local_sys_paths = copy.copy(sys.path)
        ## replace empty first element with cwd
        ##     https://docs.python.org/3/library/sys.html#sys.path
        if local_sys_paths[0] == "":
            local_sys_paths[0] = os.getcwd()
        ## convert to absolute paths
        local_sys_paths = list(map(os.path.abspath, local_sys_paths))
        ## remove paths that don't exist
        local_sys_paths = list(filter(os.path.exists, local_sys_paths))
        ## remove .ipython
        local_sys_paths = list(filter(lambda path: not path.endswith(".ipython"), local_sys_paths))
        ## remove virtual (and real) environments
        local_sys_paths = list(filter(lambda path: not _utils.is_in_venv(path), local_sys_paths))

        # get paths to files within
        if paths is None:
            # Python files within filtered sys.path dirs
            paths = local_sys_paths
            extensions = ['py', 'pyc', 'pyo']
        else:
            # all user-specified files
            extensions = None
        local_filepaths = _utils.find_filepaths(
            paths, extensions=extensions,
            include_hidden=True,
            include_venv=False,  # ignore virtual environments nested within
        )
        ## remove .git
        local_filepaths = set(filter(lambda path: not path.endswith(".git") and ".git/" not in path,
                                      local_filepaths))

        # obtain deepest common directory
        #     This directory on the local system will be mirrored in `_CUSTOM_MODULES_DIR` in
        #     deployment.
        curr_dir = os.path.join(os.getcwd(), "")
        paths_plus = list(local_filepaths) + [curr_dir]
        common_prefix = os.path.commonprefix(paths_plus)
        common_dir = os.path.dirname(common_prefix)

        # replace `common_dir` with `_CUSTOM_MODULES_DIR` for deployment sys.path
        depl_sys_paths = list(map(lambda path: os.path.relpath(path, common_dir), local_sys_paths + forced_local_sys_paths))
        depl_sys_paths = list(map(lambda path: os.path.join(_CUSTOM_MODULES_DIR, path), depl_sys_paths))

        bytestream = six.BytesIO()
        with zipfile.ZipFile(bytestream, 'w') as zipf:
            for filepath in local_filepaths:
                arcname = os.path.relpath(filepath, common_dir)  # filepath relative to archive root
                try:
                    zipf.write(filepath, arcname)
                except Exception:
                    # maybe file has corrupt metadata; try reading then writing contents
                    with open(filepath, 'rb') as f:
                        zipf.writestr(
                            _artifact_utils.global_read_zipinfo(arcname),
                            f.read(),
                        )

            # add verta config file for sys.path and chdir
            working_dir = os.path.join(_CUSTOM_MODULES_DIR, os.path.relpath(curr_dir, common_dir))
            zipf.writestr(
                _artifact_utils.global_read_zipinfo("_verta_config.py"),
                six.ensure_binary('\n'.join([
                    "import os, sys",
                    "",
                    "",
                    "sys.path = sys.path[:1] + {} + sys.path[1:]".format(depl_sys_paths),
                    "",
                    "try:",
                    "    os.makedirs(\"{}\")".format(working_dir),
                    "except OSError:  # already exists",
                    "    pass",
                    "os.chdir(\"{}\")".format(working_dir),
                ]))
            )

            # add __init__.py
            init_filename = "__init__.py"
            if init_filename not in zipf.namelist():
                zipf.writestr(
                    _artifact_utils.global_read_zipinfo(init_filename),
                    b"",
                )

        bytestream.seek(0)

        return bytestream
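
The "deepest common directory" step above relies on ``os.path.commonprefix`` working character by character rather than path component by path component, which is why the result is passed through ``os.path.dirname``; a standalone sketch of the idiom:

    import os

    filepaths = [
        "/home/user/project/pkg/model.py",
        "/home/user/project/pkg/utils.py",
        "/home/user/project_backup/old.py",
    ]

    prefix = os.path.commonprefix(filepaths)  # "/home/user/project" -- not a real directory
    common_dir = os.path.dirname(prefix)      # "/home/user" -- deepest directory containing all files
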
Example #4
    def get_code(self):
        """
        Gets the code version.

        Returns
        -------
        dict or zipfile.ZipFile
            Either:
                - a dictionary containing Git snapshot information with at most the following items:
                    - **filepaths** (*list of str*)
                    - **repo_url** (*str*) – Remote repository URL
                    - **commit_hash** (*str*) – Commit hash
                    - **is_dirty** (*bool*)
                - a `ZipFile <https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile>`_
                  containing Python source code files

        """
        # TODO: remove this circular dependency
        from ._project import Project
        from ._experiment import Experiment
        from ._experimentrun import ExperimentRun
        if isinstance(self, Project):  # TODO: not this
            Message = self._service.GetProjectCodeVersion
            endpoint = "getProjectCodeVersion"
        elif isinstance(self, Experiment):
            Message = self._service.GetExperimentCodeVersion
            endpoint = "getExperimentCodeVersion"
        elif isinstance(self, ExperimentRun):
            Message = self._service.GetExperimentRunCodeVersion
            endpoint = "getExperimentRunCodeVersion"
        else:  # otherwise `Message` would be unbound below
            raise TypeError(
                "this method cannot be called on type {}".format(type(self)))
        msg = Message(id=self.id)
        data = _utils.proto_to_json(msg)
        response = _utils.make_request("GET",
                                       self._request_url.format(endpoint),
                                       self._conn,
                                       params=data)
        _utils.raise_for_http_error(response)

        response_msg = _utils.json_to_proto(_utils.body_to_json(response),
                                            Message.Response)
        code_ver_msg = response_msg.code_version
        which_code = code_ver_msg.WhichOneof('code')
        if which_code == 'git_snapshot':
            git_snapshot_msg = code_ver_msg.git_snapshot
            git_snapshot = {}
            if git_snapshot_msg.filepaths:
                git_snapshot['filepaths'] = git_snapshot_msg.filepaths
            if git_snapshot_msg.repo:
                git_snapshot['repo_url'] = git_snapshot_msg.repo
            if git_snapshot_msg.hash:
                git_snapshot['commit_hash'] = git_snapshot_msg.hash
            if git_snapshot_msg.is_dirty != _CommonCommonService.TernaryEnum.UNKNOWN:
                git_snapshot['is_dirty'] = (
                    git_snapshot_msg.is_dirty == _CommonCommonService.TernaryEnum.TRUE)
            return git_snapshot
        elif which_code == 'code_archive':
            # download artifact from artifact store
            # pylint: disable=no-member
            # this method should only be called on ExperimentRun, which does have _get_url_for_artifact()
            url = self._get_url_for_artifact(
                "verta_code_archive", "GET",
                code_ver_msg.code_archive.artifact_type).url

            response = _utils.make_request("GET", url, self._conn)
            _utils.raise_for_http_error(response)

            code_archive = six.BytesIO(response.content)
            return zipfile.ZipFile(
                code_archive, 'r')  # TODO: return a util class instead, maybe
        else:
            raise RuntimeError("unable find code in response")
Example #5
    def log_code(self,
                 exec_path=None,
                 repo_url=None,
                 commit_hash=None,
                 overwrite=False,
                 is_dirty=None,
                 autocapture=True):
        """
        Logs the code version.

        A code version is either information about a Git snapshot or a bundle of Python source code files.

        `repo_url` and `commit_hash` can only be set if `use_git` was set to ``True`` in the Client.

        Parameters
        ----------
        exec_path : str, optional
            Filepath to the executable Python script or Jupyter notebook. If no filepath is provided,
            the Client will make its best effort to find the currently running script/notebook file.
        repo_url : str, optional
            URL for a remote Git repository containing `commit_hash`. If no URL is provided, the Client
            will make its best effort to find it.
        commit_hash : str, optional
            Git commit hash associated with this code version. If no hash is provided, the Client will
            make its best effort to find it.
        overwrite : bool, default False
            Whether to allow overwriting a code version.
        is_dirty : bool, optional
            Whether git status is dirty relative to `commit_hash`. If not provided, the Client will
            make its best effort to find it.
        autocapture : bool, default True
            Whether to automatically capture the parameters above when the Client is in Git mode.

        Examples
        --------
        With ``Client(use_git=True)`` (default):

            Log Git snapshot information, plus the location of the currently executing notebook/script
            relative to the repository root:

            .. code-block:: python

                run.log_code()
                run.get_code()
                # {'exec_path': 'comparison/outcomes/classification.ipynb',
                #  'repo_url': '[email protected]:VertaAI/experiments.git',
                #  'commit_hash': 'f99abcfae6c3ce6d22597f95ad6ef260d31527a6',
                #  'is_dirty': False}

            Log Git snapshot information, plus the location of a specific source code file relative
            to the repository root:

            .. code-block:: python

                run.log_code("../trainer/training_pipeline.py")
                run.get_code()
                # {'exec_path': 'comparison/trainer/training_pipeline.py',
                #  'repo_url': '[email protected]:VertaAI/experiments.git',
                #  'commit_hash': 'f99abcfae6c3ce6d22597f95ad6ef260d31527a6',
                #  'is_dirty': False}

        With ``Client(use_git=False)``:

            Find and upload the currently executing notebook/script:

            .. code-block:: python

                run.log_code()
                zip_file = run.get_code()
                zip_file.printdir()
                # File Name                          Modified             Size
                # classification.ipynb        2019-07-10 17:18:24        10287

            Upload a specific source code file:

            .. code-block:: python

                run.log_code("../trainer/training_pipeline.py")
                zip_file = run.get_code()
                zip_file.printdir()
                # File Name                          Modified             Size
                # training_pipeline.py        2019-05-31 10:34:44          964

        """
        if self._conf.use_git and autocapture:
            # verify Git
            try:
                repo_root_dir = _git_utils.get_git_repo_root_dir()
            except OSError:
                # don't halt execution
                print(
                    "unable to locate git repository; you may be in an unsupported environment;\n"
                    "    consider using `autocapture=False` to manually enter values"
                )
                return
                # six.raise_from(OSError("failed to locate git repository; please check your working directory"),
                #                None)
            print("Git repository successfully located at {}".format(
                repo_root_dir))
        if not self._conf.use_git:
            if repo_url is not None or commit_hash is not None:
                raise ValueError(
                    "`repo_url` and `commit_hash` can only be set if `use_git` was set to True in the Client"
                )
            if not autocapture:  # user passed `False`
                raise ValueError(
                    "`autocapture` is only applicable if `use_git` was set to True in the Client"
                )

        if autocapture:
            if exec_path is None:
                # find dynamically
                try:
                    exec_path = _utils.get_notebook_filepath()
                except (ImportError, OSError):  # notebook not found
                    try:
                        exec_path = _utils.get_script_filepath()
                    except OSError:  # script not found
                        print("unable to find code file; skipping")
            else:
                exec_path = os.path.expanduser(exec_path)
                if not os.path.isfile(exec_path):
                    raise ValueError(
                        "`exec_path` \"{}\" must be a valid filepath".format(
                            exec_path))

        # TODO: remove this circular dependency
        from ._project import Project
        from ._experiment import Experiment
        from ._experimentrun import ExperimentRun
        if isinstance(self, Project):  # TODO: not this
            Message = self._service.LogProjectCodeVersion
            endpoint = "logProjectCodeVersion"
        elif isinstance(self, Experiment):
            Message = self._service.LogExperimentCodeVersion
            endpoint = "logExperimentCodeVersion"
        elif isinstance(self, ExperimentRun):
            Message = self._service.LogExperimentRunCodeVersion
            endpoint = "logExperimentRunCodeVersion"
        else:  # otherwise `Message` would be unbound below
            raise TypeError(
                "this method cannot be called on type {}".format(type(self)))
        msg = Message(id=self.id)

        if overwrite:
            if isinstance(self, ExperimentRun):
                msg.overwrite = True
            else:
                raise ValueError(
                    "`overwrite=True` is currently only supported for ExperimentRun"
                )

        if self._conf.use_git:
            if autocapture and exec_path:  # `exec_path` may be None if no code file was found
                try:
                    # adjust `exec_path` to be relative to repo root
                    exec_path = os.path.relpath(
                        exec_path, _git_utils.get_git_repo_root_dir())
                except OSError as e:
                    print(
                        "{}; logging absolute path to file instead".format(e))
                    exec_path = os.path.abspath(exec_path)
            if exec_path:
                msg.code_version.git_snapshot.filepaths.append(exec_path)

            from verta.code import _git  # avoid Python 2 top-level circular import
            code_ver = _git.Git(
                repo_url=repo_url,
                commit_hash=commit_hash,
                is_dirty=is_dirty,
                autocapture=autocapture,
            )
            msg.code_version.git_snapshot.repo = code_ver.repo_url or ""
            msg.code_version.git_snapshot.hash = code_ver.commit_hash or ""
            if not autocapture and is_dirty is None:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.UNKNOWN
            elif code_ver.is_dirty:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.TRUE
            else:
                msg.code_version.git_snapshot.is_dirty = _CommonCommonService.TernaryEnum.FALSE
        else:  # log code as Artifact
            if exec_path is None:
                # don't halt execution
                print(
                    "unable to find code file; you may be in an unsupported environment"
                )
                return
                # raise RuntimeError("unable to find code file; you may be in an unsupported environment")

            # write ZIP archive
            zipstream = six.BytesIO()
            with zipfile.ZipFile(zipstream, 'w') as zipf:
                filename = os.path.basename(exec_path)
                if exec_path.endswith(".ipynb"):
                    try:
                        saved_notebook = _utils.save_notebook(exec_path)
                    except Exception:  # failed to save
                        print("unable to automatically save Notebook;"
                              " logging latest checkpoint from disk")
                        zipf.write(exec_path, filename)
                    else:
                        zipf.writestr(
                            _artifact_utils.global_read_zipinfo(filename),
                            six.ensure_binary(saved_notebook.read()),
                        )
                else:
                    zipf.write(exec_path, filename)
            zipstream.seek(0)

            key = 'code'
            extension = 'zip'

            artifact_hash = _artifact_utils.calc_sha256(zipstream)
            basename = key + os.extsep + extension
            artifact_path = os.path.join(artifact_hash, basename)

            msg.code_version.code_archive.path = artifact_path
            msg.code_version.code_archive.path_only = False
            msg.code_version.code_archive.artifact_type = _CommonCommonService.ArtifactTypeEnum.CODE
            msg.code_version.code_archive.filename_extension = extension
        # TODO: check if we actually have any loggable information
        msg.code_version.date_logged = _utils.now()

        data = _utils.proto_to_json(msg)
        response = _utils.make_request("POST",
                                       self._request_url.format(endpoint),
                                       self._conn,
                                       json=data)
        if not response.ok:
            if response.status_code == 409:
                raise ValueError("a code version has already been logged")
            else:
                _utils.raise_for_http_error(response)

        if msg.code_version.WhichOneof("code") == 'code_archive':
            # upload artifact to artifact store
            # pylint: disable=no-member
            # this method should only be called on ExperimentRun, which does have _get_url_for_artifact()
            url = self._get_url_for_artifact(
                "verta_code_archive", "PUT",
                msg.code_version.code_archive.artifact_type).url

            response = _utils.make_request("PUT",
                                           url,
                                           self._conn,
                                           data=zipstream)
            _utils.raise_for_http_error(response)
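
The artifact path built above is content-addressed: the archive's SHA-256 digest becomes the directory name, so identical code bundles map to the same storage path. A minimal sketch of the scheme, with ``hashlib`` standing in for ``_artifact_utils.calc_sha256()`` (assumed to hash the full stream contents):

    import hashlib
    import io
    import os

    zipstream = io.BytesIO(b"example archive bytes")  # stands in for the ZIP written above

    zipstream.seek(0)
    artifact_hash = hashlib.sha256(zipstream.read()).hexdigest()

    basename = "code" + os.extsep + "zip"
    artifact_path = os.path.join(artifact_hash, basename)  # "<sha256>/code.zip"
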
Example #6
    def _upload_artifact(self,
                         key,
                         file_handle,
                         artifact_type,
                         part_size=_artifact_utils._64MB):
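        """
        Uploads `file_handle` to the Verta Registry artifact store.

        Parameters
        ----------
        key : str
            Name of the artifact.
        file_handle : file-like
            Artifact to be uploaded.
        artifact_type : int
            Variant of ``_CommonCommonService.ArtifactTypeEnum``.
        part_size : int, default 64 MB
            If using multipart upload, number of bytes to upload per part.

        """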
        file_handle.seek(0)

        # check if multipart upload ok
        url_for_artifact = self._get_url_for_artifact(key,
                                                      "PUT",
                                                      artifact_type,
                                                      part_num=1)

        print("uploading {} to Registry".format(key))
        if url_for_artifact.multipart_upload_ok:
            # TODO: parallelize this
            file_parts = iter(lambda: file_handle.read(part_size), b"")
            for part_num, file_part in enumerate(file_parts, start=1):
                print("uploading part {}".format(part_num), end="\r")

                # get presigned URL
                url = self._get_url_for_artifact(key,
                                                 "PUT",
                                                 artifact_type,
                                                 part_num=part_num).url

                # wrap file part into bytestream to avoid OverflowError
                #     Passing a bytestring >2 GB (num bytes > max val of int32) directly to
                #     ``requests`` will overwhelm CPython's SSL lib when it tries to sign the
                #     payload. But passing a buffered bytestream instead of the raw bytestring
                #     indicates to ``requests`` that it should perform a streaming upload via
                #     HTTP/1.1 chunked transfer encoding and avoid this issue.
                #     https://github.com/psf/requests/issues/2717
                part_stream = six.BytesIO(file_part)

                # upload part
                response = _utils.make_request("PUT",
                                               url,
                                               self._conn,
                                               data=part_stream)
                _utils.raise_for_http_error(response)

                # commit part
                url = "{}://{}/api/v1/registry/model_versions/{}/commitArtifactPart".format(
                    self._conn.scheme, self._conn.socket, self.id)
                msg = _RegistryService.CommitArtifactPart(
                    model_version_id=self.id, key=key)
                msg.artifact_part.part_number = part_num
                msg.artifact_part.etag = response.headers["ETag"]
                data = _utils.proto_to_json(msg)
                response = _utils.make_request("POST",
                                               url,
                                               self._conn,
                                               json=data)
                _utils.raise_for_http_error(response)
            print()

            # complete upload
            url = "{}://{}/api/v1/registry/model_versions/{}/commitMultipartArtifact".format(
                self._conn.scheme, self._conn.socket, self.id)
            msg = _RegistryService.CommitMultipartArtifact(
                model_version_id=self.id, key=key)
            data = _utils.proto_to_json(msg)
            response = _utils.make_request("POST", url, self._conn, json=data)
            _utils.raise_for_http_error(response)
        else:
            # upload full artifact
            if url_for_artifact.fields:
                # if fields were returned by backend, make a POST request and supply them as form fields
                response = _utils.make_request(
                    "POST",
                    url_for_artifact.url,
                    self._conn,
                    # requests uses the `files` parameter for sending multipart/form-data POSTs.
                    #     https://stackoverflow.com/a/12385661/8651995
                    # the file contents must be the final form field
                    #     https://docs.aws.amazon.com/AmazonS3/latest/dev/HTTPPOSTForms.html#HTTPPOSTFormFields
                    files=list(url_for_artifact.fields.items()) +
                    [("file", file_handle)],
                )
            else:
                response = _utils.make_request("PUT",
                                               url_for_artifact.url,
                                               self._conn,
                                               data=file_handle)
            _utils.raise_for_http_error(response)

        print("upload complete")
Example #7
    def _upload_artifact(self,
                         dataset_component_path,
                         file_handle,
                         part_size=_artifact_utils._64MB):
        """
        Uploads `file_handle` to ModelDB artifact store.

        Parameters
        ----------
        dataset_component_path : str
            Filepath in dataset component blob.
        file_handle : file-like
            Artifact to be uploaded.
        part_size : int, default 64 MB
            If using multipart upload, number of bytes to upload per part.

        """
        file_handle.seek(0)

        # check if multipart upload ok
        url_for_artifact = self._get_url_for_artifact(dataset_component_path,
                                                      "PUT",
                                                      part_num=1)

        print("uploading {} to ModelDB".format(dataset_component_path))
        if url_for_artifact.multipart_upload_ok:
            # TODO: parallelize this
            file_parts = iter(lambda: file_handle.read(part_size), b'')
            for part_num, file_part in enumerate(file_parts, start=1):
                print("uploading part {}".format(part_num), end='\r')

                # get presigned URL
                url = self._get_url_for_artifact(dataset_component_path,
                                                 "PUT",
                                                 part_num=part_num).url

                # wrap file part into bytestream to avoid OverflowError
                #     Passing a bytestring >2 GB (num bytes > max val of int32) directly to
                #     ``requests`` will overwhelm CPython's SSL lib when it tries to sign the
                #     payload. But passing a buffered bytestream instead of the raw bytestring
                #     indicates to ``requests`` that it should perform a streaming upload via
                #     HTTP/1.1 chunked transfer encoding and avoid this issue.
                #     https://github.com/psf/requests/issues/2717
                part_stream = six.BytesIO(file_part)

                # upload part
                response = _utils.make_request("PUT",
                                               url,
                                               self._conn,
                                               data=part_stream)
                _utils.raise_for_http_error(response)

                # commit part
                url = "{}://{}/api/v1/modeldb/dataset-version/commitVersionedDatasetBlobArtifactPart".format(
                    self._conn.scheme,
                    self._conn.socket,
                )
                msg = _DatasetVersionService.CommitVersionedDatasetBlobArtifactPart(
                    dataset_version_id=self.id,
                    path_dataset_component_blob_path=dataset_component_path,
                )
                msg.artifact_part.part_number = part_num
                msg.artifact_part.etag = response.headers['ETag']
                data = _utils.proto_to_json(msg)
                response = _utils.make_request("POST",
                                               url,
                                               self._conn,
                                               json=data)
                _utils.raise_for_http_error(response)
            print()

            # complete upload
            url = "{}://{}/api/v1/modeldb/dataset-version/commitMultipartVersionedDatasetBlobArtifact".format(
                self._conn.scheme,
                self._conn.socket,
            )
            msg = _DatasetVersionService.CommitMultipartVersionedDatasetBlobArtifact(
                dataset_version_id=self.id,
                path_dataset_component_blob_path=dataset_component_path,
            )
            data = _utils.proto_to_json(msg)
            response = _utils.make_request("POST", url, self._conn, json=data)
            _utils.raise_for_http_error(response)
        else:
            # upload full artifact
            if url_for_artifact.fields:
                # if fields were returned by backend, make a POST request and supply them as form fields
                response = _utils.make_request(
                    "POST",
                    url_for_artifact.url,
                    self._conn,
                    # requests uses the `files` parameter for sending multipart/form-data POSTs.
                    #     https://stackoverflow.com/a/12385661/8651995
                    # the file contents must be the final form field
                    #     https://docs.aws.amazon.com/AmazonS3/latest/dev/HTTPPOSTForms.html#HTTPPOSTFormFields
                    files=list(url_for_artifact.fields.items()) +
                    [('file', file_handle)],
                )
            else:
                response = _utils.make_request("PUT",
                                               url_for_artifact.url,
                                               self._conn,
                                               data=file_handle)
            _utils.raise_for_http_error(response)

        print("upload complete")