Ejemplo n.º 1
0
    def test_download_file_no_collision(self, experiment_run, dir_and_files,
                                        in_tempdir):
        """Downloading onto an existing path with ``overwrite_ok=False`` must
        produce a differently named file whose contents match the original."""
        artifact_key = "artifact"
        src_dir, _unused_files = dir_and_files

        # Build the archive, then relocate it into the cwd so that teardown
        # removes it. The handle returned by `zip_dir` is kept bound to a local
        # so the temporary file is still on disk when it gets renamed.
        archive_path = os.path.abspath("archive.zip")
        zipped = _artifact_utils.zip_dir(src_dir)
        os.rename(zipped.name, archive_path)

        # Round trip: log the archive, then fetch it back via its GET URL.
        experiment_run.log_artifact(artifact_key, archive_path)
        url_info = experiment_run._get_url_for_artifact(artifact_key, "GET")
        response = requests.get(url_info.url)

        # Ask for the same path that already holds the uploaded archive.
        result_path = os.path.abspath(
            _request_utils.download_file(
                response,
                archive_path,
                overwrite_ok=False,
            )
        )

        # The download must not have landed on the pre-existing file...
        assert archive_path != result_path
        # ...and its contents must match what was uploaded.
        assert filecmp.cmp(archive_path, result_path)
Ejemplo n.º 2
0
    def test_download_arbitrary_zip(self, experiment_run, dir_and_files, strs, in_tempdir):
        """A model logged as a ZIP file comes back as that same ZIP, not unpacked."""
        model_dir, _unused_files = dir_and_files
        upload_path, download_path = strs[:2]

        # Write a ZIP archive of `model_dir` to `upload_path`.
        with open(upload_path, "wb") as zip_dest:
            zipped_model = _artifact_utils.zip_dir(model_dir)
            shutil.copyfileobj(zipped_model, zip_dest)

        # Round trip through the experiment run.
        experiment_run.log_model(upload_path)
        experiment_run.download_model(download_path)

        # The download is still a ZIP archive, identical to the upload.
        still_zipped = zipfile.is_zipfile(download_path)
        assert still_zipped
        same_contents = filecmp.cmp(upload_path, download_path)
        assert same_contents
Ejemplo n.º 3
0
    def test_download_arbitrary_zip(self, model_version, dir_and_files, strs,
                                    in_tempdir):
        """A model logged as a ZIP file comes back as that same ZIP, not unpacked."""
        model_dir, _unused_files = dir_and_files
        upload_path, download_path = strs[:2]

        # Write a ZIP archive of `model_dir` to `upload_path`.
        with open(upload_path, "wb") as zip_dest:
            zipped_model = _artifact_utils.zip_dir(model_dir)
            shutil.copyfileobj(zipped_model, zip_dest)

        # Round trip through the model version; the call reports where the
        # file was written, which must be the absolute form of the request.
        model_version.log_model(upload_path)
        reported_path = model_version.download_model(download_path)
        expected_path = os.path.abspath(download_path)
        assert reported_path == expected_path

        # The download is still a ZIP archive, identical to the upload.
        still_zipped = zipfile.is_zipfile(download_path)
        assert still_zipped
        same_contents = filecmp.cmp(upload_path, download_path)
        assert same_contents
Ejemplo n.º 4
0
    def log_artifact(self, key, artifact, overwrite=False, _extension=None):
        """
        Logs an artifact to this Model Version.

        .. note::

            The following artifact keys are reserved for internal use within the
            Verta system:

            - ``"custom_modules"``
            - ``"model"``
            - ``"model.pkl"``
            - ``"model_api.json"``
            - ``"requirements.txt"``
            - ``"train_data"``
            - ``"tf_saved_model"``
            - ``"setup_script"``

        Parameters
        ----------
        key : str
            Name of the artifact.
        artifact : str or file-like or object
            Artifact or some representation thereof.
                - If str, then it will be interpreted as a filesystem path, its contents read as bytes,
                  and uploaded as an artifact. If it is a directory path, its contents will be zipped.
                - If file-like, then the contents will be read as bytes and uploaded as an artifact.
                - Otherwise, the object will be serialized and uploaded as an artifact.
        overwrite : bool, default False
            Whether to allow overwriting an existing artifact with key `key`.
        _extension : str, optional
            Forwarded unchanged as the ``extension`` argument when the artifact
            message is created. Underscore-prefixed, so presumably meant for
            internal callers only.

        Raises
        ------
        ValueError
            If `key` is the key reserved for the model, or if an artifact with
            key `key` already exists and `overwrite` is False.

        """
        # TODO: should validate keys, but can't here because this public
        #       method is also used to log internal artifacts
        # _artifact_utils.validate_key(key)
        if key == self._MODEL_KEY:
            raise ValueError('the key "{}" is reserved for model;'
                             " consider using log_model() instead".format(
                                 self._MODEL_KEY))

        # refresh the local message so the duplicate-key check sees current data
        self._fetch_with_no_cache()

        # index of an existing artifact with this key, or -1 if there is none
        same_key_ind = -1
        for i, existing_artifact in enumerate(self._msg.artifacts):
            if existing_artifact.key == key:
                if not overwrite:
                    raise ValueError(
                        "The key has been set; consider setting overwrite=True"
                    )
                same_key_ind = i
                break

        artifact_type = _CommonCommonService.ArtifactTypeEnum.BLOB

        # Only a handle opened by this method is closed by this method;
        # caller-supplied file-likes are left open for the caller to manage.
        opened_file = None
        if isinstance(artifact, six.string_types):
            if os.path.isdir(artifact):  # zip dirpath
                artifact = _artifact_utils.zip_dir(artifact)
            else:  # open filepath
                artifact = opened_file = open(artifact, "rb")

        try:
            artifact_stream, method = _artifact_utils.ensure_bytestream(artifact)

            artifact_msg = self._create_artifact_msg(
                key,
                artifact_stream,
                artifact_type=artifact_type,
                method=method,
                extension=_extension,
            )
            if same_key_ind == -1:
                self._msg.artifacts.append(artifact_msg)
            else:
                self._msg.artifacts[same_key_ind].CopyFrom(artifact_msg)

            # register the artifact metadata first, then send the bytes
            self._update(self._msg, method="PUT")
            self._upload_artifact(key,
                                  artifact_stream,
                                  artifact_type=artifact_type)
        finally:
            # release the file descriptor whether or not the upload succeeded
            if opened_file is not None:
                opened_file.close()