def open_dataset(host, path, dataset_factory, args=None, kwargs=None):
    """
    Downloads file using SFTP and opens dataset using a temporary directory
    for file transfer.

    Args:
        host: IP address of the host.
        path: Path of the file on the remote host.
        dataset_factory: Callable used to read in the downloaded file. It is
            called with the downloaded file as its first argument.
        args: Iterable of positional arguments passed to ``dataset_factory``
            after the downloaded file. Defaults to no extra arguments.
        kwargs: Mapping of keyword arguments passed to ``dataset_factory``.
            Defaults to no keyword arguments.

    Returns:
        The object created by calling ``dataset_factory`` with the downloaded
        file as first argument, followed by ``args`` as positional and
        ``kwargs`` as keyword arguments.

    Raises:
        ValueError: If ``args`` is not iterable or ``kwargs`` is not a
            mapping.
    """
    if args is None:
        args = []
    if not isinstance(args, Iterable):
        raise ValueError("Provided positional arguments 'args' must be "
                         "iterable.")
    if kwargs is None:
        kwargs = {}
    if not isinstance(kwargs, Mapping):
        raise ValueError("Provided keyword arguments 'kwargs' must be "
                         "a mapping.")

    with sftp.download_file(host, path) as file:
        # NOTE(review): the downloaded file appears to be removed when this
        # context exits, so a dataset_factory that reads the file lazily
        # (after this function returns) may fail — confirm for the factories
        # used with this function.
        dataset = dataset_factory(file, *args, **kwargs)
    return dataset
def read_file(path, *args, **kwargs):
    """
    Generic function to open files.

    Currently supports opening files on the local system as well as on a
    remote machine via SFTP. This is a generator that yields exactly one open
    file object. It is presumably used as a context manager (e.g. wrapped
    with ``contextlib.contextmanager``) — confirm against its decorator and
    callers.

    Args:
        path: A ``PurePath``, a local path string, or a URL of the form
            ``sftp://<host>/<path>``.
        *args: Positional arguments forwarded to ``open``.
        **kwargs: Keyword arguments forwarded to ``open``.

    Yields:
        The opened file object. It is closed when the generator is resumed
        or closed. For SFTP, the handle is closed before the temporary
        download is cleaned up.

    Raises:
        InvalidURL: If an SFTP URL contains no host, or if the URL scheme is
            not supported.
    """
    if isinstance(path, PurePath):
        with open(path, *args, **kwargs) as file:
            yield file
        return

    url = urlparse(path)

    # Check for the SFTP scheme before the "no network location" shortcut
    # below. Otherwise 'sftp:///path' would be treated as a local path and
    # the missing-host check would be unreachable.
    if url.scheme == "sftp":
        host = url.netloc
        if host == "":
            raise InvalidURL(
                "No host in SFTP URL. "
                "To load a file using SFTP, the URL must be of the form "
                "'sftp://<host>/<path>'.")
        with sftp.download_file(host, url.path) as downloaded:
            # Close our handle before the download context cleans up the
            # temporary file.
            with open(downloaded, *args, **kwargs) as file:
                yield file
        return

    # No network location: treat the input as a plain local path.
    if url.netloc == "":
        with open(path, *args, **kwargs) as file:
            yield file
        return

    raise InvalidURL(f"The provided protocol '{url.scheme}' is not supported.")
def test_download_file():
    """
    Ensure that downloading of files works and the data is cleaned up after
    usage.
    """
    # NOTE(review): requires network access to this specific host and an
    # existing remote test file.
    host = "129.16.35.202"
    path = "/mnt/array1/share/Datasets/test/data_0.npz"

    with sftp.download_file(host, path) as file:
        tmp_file = file
        # np.load on an .npz returns an NpzFile that keeps the file open and
        # reads members lazily. Read the array and close the NpzFile before
        # the download context exits and removes the temporary file.
        with np.load(file) as data:
            x = data["x"]

    assert np.all(np.isclose(x, 0.0))
    assert not tmp_file.exists()