def _restore(self, trial, checkpoint=None, block=False) -> Optional[RunningJob]:
    """Restores training state from a given model checkpoint.

    Args:
        trial (Trial): The trial to be restored.
        checkpoint (Checkpoint): The checkpoint to restore from. If None,
            the most recent PERSISTENT checkpoint is used. Defaults to None.
        block (bool): Whether or not to block on restore before returning.

    Returns:
        The RunningJob tracking a non-blocking, persistent-storage restore;
        None otherwise.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    if checkpoint is None or checkpoint.value is None:
        checkpoint = trial.checkpoint
    if checkpoint.value is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial))
    value = checkpoint.value
    if checkpoint.storage == Checkpoint.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        with _change_working_directory(trial):
            trial.runner.restore_from_object.remote(value)
    else:
        logger.debug("Trial %s: Attempting restore from %s", trial, value)
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            with _change_working_directory(trial):
                remote = trial.runner.restore.remote(value)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where a DurableTrainable is not provided.
            logger.warning(
                "Trial %s: Reading checkpoint into memory.", trial)
            data_dict = TrainableUtil.pickle_checkpoint(value)
            with _change_working_directory(trial):
                remote = trial.runner.restore_from_object.remote(data_dict)
        else:
            raise AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` and a Trainable "
                "extending `DurableTrainable` for remote storage-based "
                "restoration.")
        if block:
            ray.get(remote)
        else:
            trial.restoring_from = checkpoint
            running_job = RunningJob(trial, remote)
            self.jobs_running[remote] = running_job
            return running_job
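
# Minimal sketch (not Tune internals) of the launch-then-optionally-wait
# pattern `_restore` uses above: the restore is issued as a remote call, and
# `block` decides whether we wait on it immediately or hand the ObjectRef back
# to the caller for later tracking. `Toy` and `restore_trial` are hypothetical
# names introduced only for this illustration.
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class Toy:
    def restore(self, checkpoint_value):
        # Stand-in for Trainable.restore / restore_from_object.
        return "restored from {}".format(checkpoint_value)


def restore_trial(actor, checkpoint_value, block=False):
    # Launch the restore as a remote task; this returns immediately.
    remote = actor.restore.remote(checkpoint_value)
    if block:
        # Blocking restore: wait for completion before returning.
        ray.get(remote)
        return None
    # Non-blocking restore: return the ObjectRef so the caller can track it
    # (Tune wraps it in a RunningJob keyed by the ObjectRef, as above).
    return remote


actor = Toy.remote()
pending = restore_trial(actor, "/tmp/checkpoint_0", block=False)
print(ray.get(pending))  # resolved later, e.g. in the executor's event loop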
def save_to_object(self):
    """Saves the current model state to a Python object.

    This also writes a checkpoint to disk via `self.save()` and then packs
    it into a serialized byte string.

    Returns:
        A byte string holding the pickled checkpoint data.
    """
    checkpoint_path = self.save()
    data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
    out = io.BytesIO()
    if len(data_dict) > 10e6:  # getting pretty large
        logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
    out.write(data_dict)
    return out.getvalue()
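
# Hedged sketch of the checkpoint-object format that `save_to_object` builds
# on. Based on the assertions in `testPickleCheckpoint` below,
# `pickle_checkpoint` is assumed to produce a pickled dict with a
# "checkpoint_name" key and a "data" mapping of file name -> raw bytes.
# The helpers here are illustrative only, not the actual TrainableUtil
# implementation.
import os
import pickle


def pack_checkpoint_dir(checkpoint_path):
    # Pack every file in the checkpoint's directory into an in-memory dict.
    data = {}
    checkpoint_dir = os.path.dirname(checkpoint_path)
    for name in os.listdir(checkpoint_dir):
        with open(os.path.join(checkpoint_dir, name), "rb") as f:
            data[name] = f.read()
    return pickle.dumps({
        "checkpoint_name": os.path.basename(checkpoint_path),
        "data": data,
    })


def unpack_checkpoint_obj(obj, target_dir):
    # Reverse of pack_checkpoint_dir: write the files back out and return
    # the restored checkpoint path.
    loaded = pickle.loads(obj)
    for name, raw in loaded["data"].items():
        with open(os.path.join(target_dir, name), "wb") as f:
            f.write(raw)
    return os.path.join(target_dir, loaded["checkpoint_name"])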
def testPickleCheckpoint(self):
    for i in range(5):
        path = os.path.join(self.checkpoint_dir, str(i))
        with open(path, "w") as f:
            f.write(str(i))

    checkpoint_path = os.path.join(self.checkpoint_dir, "0")
    data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
    loaded = pickle.loads(data_dict)

    checkpoint_name = os.path.basename(checkpoint_path)
    self.assertEqual(loaded["checkpoint_name"], checkpoint_name)

    for i in range(5):
        path = os.path.join(self.checkpoint_dir, str(i))
        with open(path, "rb") as f:
            self.assertEqual(loaded["data"][str(i)], f.read())
def restore(self, trial, checkpoint=None):
    """Restores training state from a given model checkpoint.

    Raises:
        RuntimeError: This error is raised if no runner is found.
        AbortTrialExecution: This error is raised if the trial is
            ineligible for restoration, given the Tune input arguments.
    """
    if checkpoint is None or checkpoint.value is None:
        checkpoint = trial.checkpoint
    if checkpoint.value is None:
        return
    if trial.runner is None:
        raise RuntimeError(
            "Trial {}: Unable to restore - no runner found.".format(trial))
    value = checkpoint.value
    if checkpoint.storage == Checkpoint.MEMORY:
        logger.debug("Trial %s: Attempting restore from object", trial)
        # Note that we don't store the remote since in-memory checkpoints
        # don't guarantee fault tolerance and don't need to be waited on.
        trial.runner.restore_from_object.remote(value)
    else:
        logger.debug("Trial %s: Attempting restore from %s", trial, value)
        if issubclass(trial.get_trainable_cls(), DurableTrainable):
            remote = trial.runner.restore.remote(value)
        elif trial.sync_on_checkpoint:
            # This provides FT backwards compatibility in the
            # case where a DurableTrainable is not provided.
            logger.warning(
                "Trial %s: Reading checkpoint into memory.", trial)
            data_dict = TrainableUtil.pickle_checkpoint(value)
            remote = trial.runner.restore_from_object.remote(data_dict)
        else:
            raise AbortTrialExecution(
                "Pass in `sync_on_checkpoint=True` for driver-based trial "
                "restoration. Pass in an `upload_dir` and a Trainable "
                "extending `DurableTrainable` for remote storage-based "
                "restoration.")
        self._running[remote] = trial
        trial.restoring_from = checkpoint