class HDFSStorageManager(StorageManager):
    """
    Store and load checkpoints from HDFS.

    Checkpoints are staged in a local base directory (``temp_dir``, or the
    system temp directory when not given) and transferred with an
    ``InsecureClient`` created with ``root=hdfs_path``. Each checkpoint is
    addressed on HDFS by its storage ID.
    """

    def __init__(
        self,
        hdfs_url: str,
        hdfs_path: str,
        user: Optional[str] = None,
        temp_dir: Optional[str] = None,
    ) -> None:
        super().__init__(temp_dir if temp_dir is not None else tempfile.gettempdir())
        self.hdfs_url = hdfs_url
        self.hdfs_path = hdfs_path
        self.user = user
        # root= makes the storage-ID paths used below relative to hdfs_path
        # (per the hdfs library's client documentation).
        self.client = InsecureClient(self.hdfs_url, root=self.hdfs_path, user=self.user)

    def post_store_path(self, storage_id: str, storage_dir: str, metadata: StorageMetadata) -> None:
        """
        Upload the checkpoint in ``storage_dir`` to HDFS, then delete the local files.

        Arguments:
            storage_id: ID of the checkpoint; used as the HDFS destination path.
            storage_dir: Local directory holding the checkpoint files.
            metadata: Checkpoint metadata; ``metadata.storage_id`` identifies the
                local checkpoint directory removed afterwards.

        The local checkpoint directory is removed even if the upload raises.
        """
        try:
            logging.info("Uploading storage %s to HDFS", storage_id)
            # The client's first argument is the HDFS destination path, so it
            # must be the storage ID string, not the StorageMetadata object.
            result = self.client.upload(storage_id, storage_dir)
            logging.info("Uploaded storage %s to HDFS path %s", storage_id, result)
        finally:
            self._remove_checkpoint_directory(metadata.storage_id)

    @contextlib.contextmanager
    def restore_path(self, metadata: StorageMetadata) -> Iterator[str]:
        """
        Download a checkpoint from HDFS and yield its local directory.

        ``metadata.storage_id`` is downloaded into the local base path
        (overwriting existing files), and ``<base path>/<storage_id>`` is
        yielded. The local copy is removed when the context exits.
        """
        logging.info("Downloading storage %s from HDFS", metadata.storage_id)
        # NOTE(review): the download runs outside the try block, so a failed or
        # partial download is not cleaned up by this context manager.
        self.client.download(metadata.storage_id, self._base_path, overwrite=True)
        try:
            yield os.path.join(self._base_path, metadata.storage_id)
        finally:
            self._remove_checkpoint_directory(metadata.storage_id)

    def delete(self, metadata: StorageMetadata) -> None:
        """Recursively delete the checkpoint ``metadata.storage_id`` from HDFS."""
        logging.info("Deleting storage %s from HDFS", metadata.storage_id)
        self.client.delete(metadata.storage_id, recursive=True)
class HDFSStorageManager(storage.CloudStorageManager):
    """
    Checkpoint storage backed by HDFS.

    Transfers go through an ``InsecureClient`` created with ``root=hdfs_path``.
    Local staging uses ``temp_dir``, falling back to the system temp directory.
    """

    def __init__(
        self,
        hdfs_url: str,
        hdfs_path: str,
        user: Optional[str] = None,
        temp_dir: Optional[str] = None,
    ) -> None:
        local_base = tempfile.gettempdir() if temp_dir is None else temp_dir
        super().__init__(local_base)
        self.hdfs_url, self.hdfs_path, self.user = hdfs_url, hdfs_path, user
        self.client = InsecureClient(hdfs_url, root=hdfs_path, user=user)

    @util.preserve_random_state
    def upload(self, src: Union[str, os.PathLike], dst: str) -> None:
        """Upload the local path ``src`` to the HDFS path ``dst``."""
        local_path = os.fspath(src)
        logging.info(f"Uploading to HDFS: {dst}")
        self.client.upload(dst, local_path)

    @util.preserve_random_state
    def download(self, src: str, dst: Union[str, os.PathLike]) -> None:
        """Download the HDFS path ``src`` to the local path ``dst``, overwriting existing files."""
        local_path = os.fspath(dst)
        logging.info(f"Downloading {src} from HDFS")
        self.client.download(src, local_path, overwrite=True)

    @util.preserve_random_state
    def delete(self, tgt: str) -> None:
        """Recursively remove the HDFS path ``tgt``."""
        logging.info(f"Deleting {tgt} from HDFS")
        self.client.delete(tgt, recursive=True)