Пример #1
0
def get_data_trains():
    """
    Fetch the Keras-hosted MNIST archive through StorageManager and load it.

    :return: ``((x_train, y_train), (x_test, y_test))`` - the arrays stored
        under those keys in ``mnist.npz``.
    :raises ValueError: if StorageManager could not provide a local copy of
        the archive.
    """
    data_path = StorageManager.get_local_copy(
        remote_url=
        "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
        name="mnist")
    if not data_path:
        # get_local_copy signals a failed download by returning None; fail
        # with a clear message instead of an obscure error from np.load.
        raise ValueError(
            "Could not retrieve a local copy of the MNIST dataset")

    # The archive only contains plain numeric arrays, so pickle support is not
    # needed. Keeping it disabled prevents code execution from a tampered download.
    with np.load(data_path, allow_pickle=False) as f:
        # Indexing an NpzFile reads each array fully into memory, so the
        # returned arrays remain valid after the file is closed.
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)
Пример #2
0
 def get_local_copy(self, extract_archive=True):
     """
     Retrieve the artifact through StorageManager and return where it lives locally.

     :param bool extract_archive: When True and this artifact's type is 'archive'
         (a compressed folder), the path handed back is a temporary folder
         holding the unpacked archive content
     :return: the local filesystem path of the downloaded artifact copy
     """
     from trains.storage import StorageManager
     source_url = self.url
     # Extraction is only requested for artifacts that really are archives.
     unpack = extract_archive and self.type == 'archive'
     return StorageManager.get_local_copy(
         remote_url=source_url,
         extract_archive=unpack,
         name=self.name,
     )
Пример #3
0
    def _upload_local_file(self,
                           local_file,
                           name,
                           delete_after_upload=False,
                           override_filename=None,
                           override_filename_ext=None,
                           wait_on_upload=False):
        # type: (str, str, bool, Optional[str], Optional[str], Optional[bool]) -> str
        """
        Upload a local file as an artifact and return the URI of the uploaded file.

        By default the upload is queued on the task's reporter as an UploadEvent
        (background upload). With ``wait_on_upload`` it is performed synchronously
        through ``StorageManager.upload_file``.

        :param local_file: path of the file to upload (``str`` or ``Path``)
        :param name: artifact name, used as the upload event variant
        :param delete_after_upload: remove the local file once it has been uploaded
        :param override_filename: optional replacement for the target file name
        :param override_filename_ext: optional replacement for the target file extension
        :param wait_on_upload: if True, upload synchronously instead of queueing
            the event for background upload
        :return: the quoted URI of the uploaded file
        """
        import logging
        from trains.storage import StorageManager

        # Prefer the task's output destination; otherwise use the logger's default.
        upload_uri = self._task.output_uri or self._task.get_logger(
        ).get_default_upload_destination()
        if not isinstance(local_file, Path):
            local_file = Path(local_file)
        ev = UploadEvent(
            metric='artifacts',
            variant=name,
            image_data=None,
            upload_uri=upload_uri,
            local_image_path=local_file.as_posix(),
            delete_after_upload=delete_after_upload,
            override_filename=override_filename,
            override_filename_ext=override_filename_ext,
            override_storage_key_prefix=self._get_storage_uri_prefix())
        # Unquoted target URI, used for the actual storage upload.
        _, uri = ev.get_target_full_upload_uri(upload_uri, quote_uri=False)

        if wait_on_upload:
            StorageManager.upload_file(local_file.as_posix(), uri)
            # The synchronous path never reports the UploadEvent (which is the
            # only place delete_after_upload was passed), so honour it here.
            if delete_after_upload:
                try:
                    local_file.unlink()
                except OSError as ex:
                    # Cleanup is best-effort: the upload itself already succeeded.
                    logging.getLogger(__name__).warning(
                        'Failed removing temporary file {}: {}'.format(local_file, ex))
        else:
            # send for background upload
            # noinspection PyProtectedMember
            self._task._reporter._report(ev)

        # Callers receive the quoted form of the same target URI.
        _, quoted_uri = ev.get_target_full_upload_uri(upload_uri)

        return quoted_uri
Пример #4
0
    def get_local_copy(self, extract_archive=True, raise_on_error=False):
        # type: (bool, bool) -> str
        """
        Retrieve the artifact through StorageManager and return its local path.

        :param bool extract_archive: When True and this artifact's type is 'archive'
            (a compressed folder), the path handed back is a temporary folder
            holding the unpacked archive content
        :param bool raise_on_error: When True, a failed download (StorageManager
            returning None) raises ValueError; otherwise None is returned
            (any warning output is left to StorageManager)
        :return: the local filesystem path of the downloaded artifact copy, or None
        :raises ValueError: if ``raise_on_error`` is set and the download failed
        """
        from trains.storage import StorageManager
        source_url = self.url
        # Extraction is only requested for artifacts that really are archives.
        unpack = extract_archive and self.type == 'archive'
        downloaded = StorageManager.get_local_copy(
            remote_url=source_url,
            extract_archive=unpack,
            name=self.name)

        if not raise_on_error or downloaded is not None:
            return downloaded

        raise ValueError(
            "Could not retrieve a local copy of artifact {}, failed downloading {}"
            .format(self.name, self.url))
Пример #5
0
    def get_local_copy(self, extract_archive=True):
        """
        Download the artifact and return a path to a local copy of it.

        :param bool extract_archive: If True and the artifact type is 'archive'
            (a zip-compressed folder), the downloaded file is extracted into a new
            temporary folder and that folder's path is returned. If extraction
            fails, the temporary folder is removed and the path of the downloaded
            file itself is returned instead.
        :return: a local path to a downloaded copy of the artifact (whatever
            StorageManager returned, possibly falsy, when nothing is extracted)
        """
        import shutil
        from trains.storage import StorageManager
        local_path = StorageManager.get_local_copy(self.url)
        if not (local_path and extract_archive and self.type == 'archive'):
            return local_path

        temp_folder = None
        try:
            temp_folder = mkdtemp(prefix='artifact_',
                                  suffix='.archive_' + self.name)
            # The context manager closes the archive handle even if extraction fails.
            with ZipFile(local_path) as archive:
                archive.extractall(path=temp_folder)
        except Exception:
            # Best-effort fallback: remove the (possibly partially filled) temp
            # folder - rmdir() would fail on a non-empty one - and hand back the
            # raw downloaded file instead.
            if temp_folder:
                shutil.rmtree(temp_folder, ignore_errors=True)
            return local_path
        return temp_folder