def to_uri(self, uri: str) -> str:
    """Write checkpoint data to location URI (e.g. cloud storage).

    A ``file://`` URI is treated as a local directory: the scheme prefix is
    stripped and the data is written via ``to_directory()``. Any other URI
    must be a cloud target and is uploaded with ``upload_to_bucket()``.

    Args:
        uri (str): Target location URI to write data to.

    Returns:
        str: Cloud location containing checkpoint data. For ``file://``
            URIs, whatever ``to_directory()`` returns for the stripped path.

    Raises:
        AssertionError: If ``uri`` is neither a ``file://`` URI nor a cloud
            target.
    """
    if uri.startswith("file://"):
        local_path = uri[len("file://"):]
        return self.to_directory(local_path)

    # NOTE: kept as an assert so callers continue to see AssertionError;
    # be aware that it is stripped when Python runs with -O.
    assert is_cloud_target(uri)

    local_path = self._local_path
    # Without an existing local copy, materialize one into a scratch
    # directory that we own and therefore have to remove again.
    cleanup = not local_path
    if cleanup:
        local_path = self.to_directory()

    try:
        upload_to_bucket(bucket=uri, local_path=local_path)
    finally:
        # Remove the scratch directory even if the upload failed, so that
        # failed uploads do not leak checkpoint copies on disk.
        if cleanup:
            shutil.rmtree(local_path)

    return uri
def _get_local_path(path: Optional[str]) -> Optional[str]:
    """Return ``path`` as an existing local filesystem path, else ``None``.

    ``None`` and cloud targets are never local. A leading ``file://`` scheme
    is stripped before the existence check, and the stripped form is what
    gets returned.
    """
    if path is not None and not is_cloud_target(path):
        candidate = path[7:] if path.startswith("file://") else path
        if os.path.exists(candidate):
            return candidate
    return None
def _get_external_path(path: Optional[str]) -> Optional[str]:
    """Return ``path`` unchanged if it is a cloud target, else ``None``.

    Non-string values (including ``None``) are never external paths.
    """
    if isinstance(path, str) and is_cloud_target(path):
        return path
    return None
def save(self, path: Optional[str] = None, force_download: bool = False):
    """Save trial checkpoint to directory or cloud storage.

    If the ``path`` is a local target and the checkpoint already exists
    on local storage, the local directory is copied. Else, the checkpoint
    is downloaded from cloud storage.

    If the ``path`` is a cloud target and the checkpoint does not already
    exist on local storage, it is downloaded from cloud storage before.
    That way checkpoints can be transferred across cloud storage providers.

    Args:
        path: Path to save checkpoint at. If empty, the default cloud
            storage path is saved to the default local directory.
        force_download: If ``True``, forces (re-)download of the
            checkpoint. Defaults to ``False``.

    Returns:
        For a cloud target, the value returned by ``self.upload()``. For a
        local target, the path the checkpoint was saved at (``path``,
        ``self.local_path`` if both are identical, or the value returned
        by ``self.download()``).

    Raises:
        RuntimeError: If no usable source or target location is known, if
            no checkpoint data is found locally after the download, or if
            downloading to a local target fails.
    """
    if not path:
        # Per default, save the cloud checkpoint to the local directory
        if self.cloud_path and self.local_path:
            path = self.local_path
        elif not self.cloud_path:
            raise RuntimeError(
                "Cannot save trial checkpoint: No cloud path "
                "found. If the checkpoint is already on the node, "
                "you can pass a `path` argument to save it at another "
                "location.")
        else:
            # No self.local_path
            raise RuntimeError(
                "Cannot save trial checkpoint: No target path "
                "specified and no default local directory available. "
                "Please pass a `path` argument to `save()`.")
    elif not self.local_path and not self.cloud_path:
        raise RuntimeError(
            f"Cannot save trial checkpoint to cloud target "
            f"`{path}`: No existing local or cloud path was "
            f"found. This indicates an error when loading "
            f"the checkpoints. Please report this issue.")

    if is_cloud_target(path):
        # Storing on cloud
        temp_dirs = set()
        if not self.local_path:
            # No local copy, yet. Download to temp dir
            local_path = tempfile.mkdtemp(prefix="tune_checkpoint_")
            temp_dirs.add(local_path)
        else:
            local_path = self.local_path

        try:
            if self.cloud_path:
                # Do not update local path as it might be a temp file
                local_path = self.download(
                    local_path=local_path, overwrite=force_download)

                # Remove pointer to a temporary directory
                if self.local_path in temp_dirs:
                    self.local_path = None

            # We should now have a checkpoint available locally
            if not os.path.exists(local_path) or not os.listdir(local_path):
                raise RuntimeError(
                    f"No checkpoint found in directory `{local_path}` after "
                    f"download - maybe the bucket is empty or downloading "
                    f"failed?")

            return self.upload(
                cloud_path=path, local_path=local_path, clean_before=True)
        finally:
            # Runs on success AND on failure, so a failed download/upload
            # no longer leaks the temporary directory.
            # download()/upload() are defined elsewhere; in case one of
            # them left self.local_path pointing at the temp dir, clear it
            # before the directory is deleted.
            if self.local_path in temp_dirs:
                self.local_path = None
            for temp_dir in temp_dirs:
                # Best-effort cleanup: never mask the original error.
                shutil.rmtree(temp_dir, ignore_errors=True)

    # Else: path is a local target
    has_local_copy = bool(
        self.local_path and os.path.exists(self.local_path)
        and os.listdir(self.local_path))

    if has_local_copy and not force_download:
        # If we have a local copy, use it
        if path == self.local_path:
            # Nothing to do
            return self.local_path

        # Both local, just copy tree (replacing any existing target)
        if os.path.exists(path):
            shutil.rmtree(path)

        shutil.copytree(self.local_path, path)
        return path

    # Else: Download
    try:
        return self.download(local_path=path, overwrite=force_download)
    except Exception as e:
        raise RuntimeError(
            "Cannot save trial checkpoint to local target as downloading "
            "from cloud failed. Did you pass the correct `SyncConfig`?"
        ) from e