def to_directory(self, path: Optional[str] = None) -> str:
    """Write checkpoint data to a directory.

    Depending on how this checkpoint is backed, the data is unpacked or
    pickled from its dict/object-ref form, copied from ``self._local_path``,
    or downloaded from the external location derived from ``self._uri``.

    Args:
        path: Target directory to write the checkpoint data into. If
            ``None``, the directory returned by
            ``_temporary_checkpoint_dir()`` is used.

    Returns:
        str: Directory containing checkpoint data (``path`` as given, or
        the temporary directory).

    Raises:
        RuntimeError: If the checkpoint holds no dict/object-ref data and
            has neither a local path nor an external path.
    """
    path = path if path is not None else _temporary_checkpoint_dir()

    os.makedirs(path, exist_ok=True)
    # Drop marker
    # NOTE(review): the copy branch below wipes ``path`` before copying, so
    # this marker only survives there if the source directory has one too.
    open(os.path.join(path, ".is_checkpoint"), "a").close()

    if self._data_dict or self._obj_ref:
        # This is an object ref or dict
        data_dict = self.to_dict()

        if _FS_CHECKPOINT_KEY in data_dict:
            # This used to be a true fs checkpoint, so restore
            _unpack(data_dict[_FS_CHECKPOINT_KEY], path)
        else:
            # This is a dict checkpoint. Dump data into checkpoint.pkl
            checkpoint_data_path = os.path.join(path, _DICT_CHECKPOINT_FILE_NAME)
            with open(checkpoint_data_path, "wb") as f:
                pickle.dump(data_dict, f)
        return path

    # This is either a local fs, remote node fs, or external fs
    local_path = self._local_path
    if local_path:
        # Compare resolved locations rather than raw strings: a different
        # spelling of the same directory (trailing separator, relative
        # path, symlink) would otherwise make us delete the source
        # directory right before trying to copy from it.
        source_real = os.path.normcase(os.path.realpath(local_path))
        target_real = os.path.normcase(os.path.realpath(path))
        if source_real != target_real:
            # If this exists on the local path, just copy over.
            # copytree requires that the destination does not exist yet.
            if path and os.path.exists(path):
                shutil.rmtree(path)
            shutil.copytree(local_path, path)
        return path

    # Only resolve the external location when there is no local copy.
    external_path = _get_external_path(self._uri)
    if external_path:
        # If this exists on external storage (e.g. cloud), download
        download_from_bucket(bucket=external_path, local_path=path)
        return path

    raise RuntimeError(
        f"No valid location found for checkpoint {self}: {self._uri}"
    )
def download(
    self,
    cloud_path: Optional[str] = None,
    local_path: Optional[str] = None,
    overwrite: bool = False,
) -> str:
    """Fetch the checkpoint directory from cloud storage.

    The checkpoint is downloaded from the cloud location into the local
    target directory. When ``self.local_path`` is unset, it is set to the
    local path used for this call.

    Args:
        cloud_path: Cloud path to load the checkpoint from. Falls back to
            ``self.cloud_path``.
        local_path: Local path to save the checkpoint at. Falls back to
            ``self.local_path``.
        overwrite: If True, any existing local directory is deleted and
            re-downloaded. If False and the local directory already
            exists and contains entries, nothing is downloaded.

    Returns:
        str: The local path of the checkpoint directory.

    Raises:
        RuntimeError: If no cloud path or no local path can be determined.
    """
    source = cloud_path or self.cloud_path
    if not source:
        raise RuntimeError(
            "Could not download trial checkpoint: No cloud path is set. "
            "Fix this by either passing a `cloud_path` to your call to "
            "`download()` or by passing a `cloud_path` into the "
            "constructor. The latter should automatically be done if you "
            "pass the correct `tune.SyncConfig`."
        )

    target = local_path or self.local_path
    if not target:
        raise RuntimeError(
            "Could not download trial checkpoint: No local path is set. "
            "Fix this by either passing a `local_path` to your call to "
            "`download()` or by passing a `local_path` into the "
            "constructor."
        )

    # Remember the target on the instance, but never clobber an existing one.
    if not self.local_path:
        self.local_path = target

    if not overwrite:
        # A populated directory counts as already downloaded: keep it.
        if os.path.exists(target) and os.listdir(target):
            return target

    # Start from a clean, freshly created directory before downloading.
    shutil.rmtree(target, ignore_errors=True)
    os.makedirs(target, 0o755, exist_ok=True)

    # Here the actual download is triggered.
    download_from_bucket(source, target)

    return target