def _get_local_path(path: Optional[str]) -> Optional[str]: """Check if path is a local path. Otherwise return None.""" if path is None or is_non_local_path_uri(path): return None if path.startswith("file://"): path = path[7:] if os.path.exists(path): return path return None
def to_uri(self, uri: str) -> str:
    """Write checkpoint data to a location URI (e.g. cloud storage).

    ``file://`` URIs are treated as plain local directories and handed to
    ``self.to_directory``. Any other URI must be accepted by
    ``is_non_local_path_uri``; the checkpoint is then materialized through
    ``self.as_directory()`` and pushed with ``upload_to_uri``.

    Args:
        uri (str): Target location URI to write data to.

    Returns:
        str: For ``file://`` targets, whatever ``to_directory`` returns;
        otherwise ``uri`` itself (the location now holding the data).

    Raises:
        RuntimeError: If ``uri`` is neither a ``file://`` URI nor accepted
            by ``is_non_local_path_uri``.
    """
    file_scheme = "file://"
    if uri.startswith(file_scheme):
        return self.to_directory(uri[len(file_scheme):])

    if is_non_local_path_uri(uri):
        with self.as_directory() as staged_dir:
            upload_to_uri(local_path=staged_dir, uri=uri)
        return uri

    raise RuntimeError(
        f"Cannot upload checkpoint to URI: Provided URI "
        f"does not belong to a registered storage provider: `{uri}`. "
        f"Hint: {fs_hint(uri)}")
def save(self, path: Optional[str] = None, force_download: bool = False):
    """Save trial checkpoint to directory or cloud storage.

    If the ``path`` is a local target and the checkpoint already exists on
    local storage, the local directory is copied. Else, the checkpoint is
    downloaded from cloud storage.

    If the ``path`` is a cloud target and the checkpoint does not already
    exist on local storage, it is downloaded from cloud storage before.
    That way checkpoints can be transferred across cloud storage providers.
    Any temporary directory created for that download is removed before
    returning, including when the download or upload fails.

    Args:
        path: Path to save checkpoint at. If empty, the default cloud
            storage path is saved to the default local directory.
        force_download: If ``True``, forces (re-)download of the
            checkpoint. Defaults to ``False``.

    Returns:
        For a cloud target, the value returned by ``self.upload()``. For a
        local target, ``self.local_path`` if it already is the target,
        ``path`` after a local copy, or the value returned by
        ``self.download()``.

    Raises:
        RuntimeError: If no target path can be derived, no local or cloud
            source is known, the checkpoint directory is missing or empty
            before upload, or downloading to a local target fails.
    """
    # Per default, save cloud checkpoint
    if not path:
        if self.cloud_path and self.local_path:
            path = self.local_path
        elif not self.cloud_path:
            raise RuntimeError(
                "Cannot save trial checkpoint: No cloud path "
                "found. If the checkpoint is already on the node, "
                "you can pass a `path` argument to save it at another "
                "location."
            )
        else:
            # No self.local_path
            raise RuntimeError(
                "Cannot save trial checkpoint: No target path "
                "specified and no default local directory available. "
                "Please pass a `path` argument to `save()`."
            )
    elif not self.local_path and not self.cloud_path:
        raise RuntimeError(
            f"Cannot save trial checkpoint to cloud target "
            f"`{path}`: No existing local or cloud path was "
            f"found. This indicates an error when loading "
            f"the checkpoints. Please report this issue."
        )

    if is_non_local_path_uri(path):
        # Storing on cloud. Temporary directories are tracked so `finally`
        # can remove them: tempfile.mkdtemp never cleans up after itself, so
        # a failing download/check/upload would otherwise leak them.
        temp_dirs = set()
        try:
            if not self.local_path:
                # No local copy, yet. Download to temp dir
                local_path = tempfile.mkdtemp(prefix="tune_checkpoint_")
                temp_dirs.add(local_path)
            else:
                local_path = self.local_path

            if self.cloud_path:
                # Do not update local path as it might be a temp file
                local_path = self.download(
                    local_path=local_path, overwrite=force_download
                )

                # Remove pointer to a temporary directory.
                # NOTE(review): this implies download() may set
                # self.local_path to the directory it wrote into -- confirm.
                if self.local_path in temp_dirs:
                    self.local_path = None

            # We should now have a checkpoint available locally
            if not os.path.exists(local_path) or len(os.listdir(local_path)) == 0:
                raise RuntimeError(
                    f"No checkpoint found in directory `{local_path}` after "
                    f"download - maybe the bucket is empty or downloading "
                    f"failed?"
                )

            return self.upload(
                cloud_path=path, local_path=local_path, clean_before=True
            )
        finally:
            # Never leave self.local_path pointing at a directory that is
            # about to be deleted, on both the success and the error path.
            if self.local_path in temp_dirs:
                self.local_path = None
            for temp_dir in temp_dirs:
                # ignore_errors so cleanup cannot mask an in-flight exception
                shutil.rmtree(temp_dir, ignore_errors=True)

    # Else: path is a local target
    has_local_copy = bool(
        self.local_path
        and os.path.exists(self.local_path)
        and len(os.listdir(self.local_path)) > 0
    )
    if has_local_copy and not force_download:
        # If we have a local copy, use it.
        # Compare by filesystem identity, not only string equality: "dir/"
        # vs "dir", a relative spelling, or a symlink to self.local_path
        # would otherwise reach the rmtree below and delete the source
        # checkpoint before it is copied.
        if path == self.local_path or (
            os.path.exists(path) and os.path.samefile(path, self.local_path)
        ):
            # Nothing to do
            return self.local_path
        # Both local, just copy tree
        if os.path.exists(path):
            shutil.rmtree(path)
        shutil.copytree(self.local_path, path)
        return path

    # Else: Download
    try:
        return self.download(local_path=path, overwrite=force_download)
    except Exception as e:
        raise RuntimeError(
            "Cannot save trial checkpoint to local target as downloading "
            "from cloud failed. Did you pass the correct `SyncConfig`?"
        ) from e
def _get_external_path(path: Optional[str]) -> Optional[str]: """Check if path is an external path. Otherwise return None.""" if not isinstance(path, str) or not is_non_local_path_uri(path): return None return path