def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
    """Fetch a checkpoint directory from remote storage into ``self.logdir``.

    The checkpoint's directory relative to ``self.logdir`` is resolved with
    ``TrainableUtil.find_rel_checkpoint_dir`` and then joined onto both
    ``self.remote_checkpoint_dir`` (the download source) and ``self.logdir``
    (the download destination).

    Args:
        checkpoint_path: Path of the checkpoint to restore; passed to
            ``TrainableUtil.find_rel_checkpoint_dir`` together with
            ``self.logdir``.

    Returns:
        ``False`` if ``self.uses_cloud_checkpointing`` is falsy (nothing is
        attempted), otherwise ``True`` once the download call has returned.
    """
    if not self.uses_cloud_checkpointing:
        return False

    rel_dir = TrainableUtil.find_rel_checkpoint_dir(self.logdir, checkpoint_path)
    remote_uri = os.path.join(self.remote_checkpoint_dir, rel_dir)
    target_dir = os.path.join(self.logdir, rel_dir)

    if self.custom_syncer:
        # Legacy path, retained only for backwards compatibility.
        self.custom_syncer.sync_down(remote_dir=remote_uri, local_dir=target_dir)
        self.custom_syncer.wait_or_retry()
    else:
        remote_checkpoint = Checkpoint.from_uri(remote_uri)

        def _download_to_target():
            return remote_checkpoint.to_directory(target_dir)

        # NOTE(review): the result of retry_fn is not inspected, so this
        # method reports True regardless of what retry_fn returns -- confirm
        # that this is intended.
        retry_fn(
            _download_to_target,
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
        )

    return True
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
    """Check that the relative checkpoint dir resolves to ``checkpoint0``.

    ``checkpoint_path`` and ``logdir`` are supplied by the test harness
    (presumably fixtures or parametrization defined elsewhere -- TODO confirm).
    """
    expected_rel_dir = "checkpoint0"
    resolved_rel_dir = TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path)
    assert resolved_rel_dir == expected_rel_dir