예제 #1
0
def get_data_trains():
    """
    Fetch the Keras MNIST ``.npz`` archive through ``StorageManager`` and load it.

    :return: ``((x_train, y_train), (x_test, y_test))`` read from the
        ``x_train``/``y_train``/``x_test``/``y_test`` entries of the archive.
    :raises ValueError: if ``StorageManager`` could not provide a local copy.
    """
    data_path = StorageManager.get_local_copy(
        remote_url=
        "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
        name="mnist")
    # get_local_copy yields a falsy value (e.g. None) when the download fails;
    # fail loudly here instead of letting np.load raise an obscure error.
    if not data_path:
        raise ValueError(
            "Could not retrieve a local copy of the MNIST dataset")

    # The archive holds plain numeric arrays, so pickle support is not needed;
    # keeping it disabled avoids unpickling (executing) data from a remote file.
    with np.load(data_path, allow_pickle=False) as f:
        # Indexing an NpzFile reads each array fully into memory, so the
        # arrays remain valid after the file is closed.
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
예제 #2
0
 def get_local_copy(self, extract_archive=True):
     """
     Ask StorageManager for a local copy of this artifact and return its path.

     :param bool extract_archive: If True and artifact is of type 'archive' (compressed folder)
         The returned path will be a temporary folder containing the archive content
     :return: a local path to a downloaded copy of the artifact, exactly as
         returned by StorageManager.get_local_copy
     """
     from trains.storage import StorageManager
     remote_url = self.url
     # Extraction is only requested for artifacts recorded as archives.
     should_extract = extract_archive and self.type == 'archive'
     return StorageManager.get_local_copy(
         remote_url=remote_url,
         extract_archive=should_extract,
         name=self.name,
     )
예제 #3
0
    def get_local_copy(self, extract_archive=True, raise_on_error=False):
        # type: (bool, bool) -> str
        """
        Obtain a local copy of this artifact via StorageManager.

        :param bool extract_archive: If True and artifact is of type 'archive' (compressed folder)
            The returned path will be a temporary folder containing the archive content
        :param bool raise_on_error: If True and no local copy could be obtained,
            raise ValueError; otherwise None is returned on failure.
            NOTE(review): any failure warning is presumably emitted by
            StorageManager, not by this method — confirm.
        :return: a local path to a downloaded copy of the artifact, or None when
            the download failed and raise_on_error is False
        :raises ValueError: when raise_on_error is True and no copy was obtained
        """
        from trains.storage import StorageManager
        remote_url = self.url
        local_copy = StorageManager.get_local_copy(
            remote_url=remote_url,
            extract_archive=extract_archive and self.type == 'archive',
            name=self.name,
        )
        # Success, or failure that the caller asked us to tolerate.
        if not raise_on_error or local_copy is not None:
            return local_copy

        message = "Could not retrieve a local copy of artifact {}, failed downloading {}"
        raise ValueError(message.format(self.name, self.url))
예제 #4
0
    def get_local_copy(self, extract_archive=True):
        """
        Download the artifact and, for archives, optionally extract it.

        :param bool extract_archive: If True and artifact is of type 'archive' (compressed folder)
            The returned path will be a temporary folder containing the archive content
        :return: a local path to a downloaded copy of the artifact. If extracting an
            archive fails, the path of the downloaded (unextracted) file is returned
            instead. If StorageManager returns a falsy value, that value is returned as-is.
        """
        import shutil
        from trains.storage import StorageManager
        local_path = StorageManager.get_local_copy(self.url)
        if not (local_path and extract_archive and self.type == 'archive'):
            return local_path

        temp_folder = None
        try:
            temp_folder = mkdtemp(prefix='artifact_',
                                  suffix='.archive_' + self.name)
            # Context manager guarantees the zip file handle is closed,
            # including when extraction raises part-way through.
            with ZipFile(local_path) as archive:
                archive.extractall(path=temp_folder)
        except Exception:
            # Deliberate best-effort: fall back to the raw downloaded file.
            # rmtree (unlike rmdir) also removes a partially-extracted folder,
            # so a failed extraction does not leave files behind.
            if temp_folder:
                shutil.rmtree(temp_folder, ignore_errors=True)
            return local_path
        return temp_folder