def save_all_states_remote(self, trial_state):
    """Save all of AdaptDL's job state and return it as an
    in-memory object."""
    checkpoint = save_all_states()
    parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint, parent_dir, trial_state)
    checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
    # Done with the directory, remove it
    shutil.rmtree(checkpoint_path)
    return checkpoint_obj
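
# For context, the round trip performed by checkpoint_to_object above can
# be sketched with the standard library alone: pack a checkpoint directory
# into an in-memory bytes object so it can be shipped between processes,
# then unpack it on the other side. This is a minimal illustration of the
# idea, assuming a tar-based encoding; pack_checkpoint and unpack_checkpoint
# are hypothetical names, not Ray Tune's actual API.
import io
import os
import tarfile

def pack_checkpoint(checkpoint_dir):
    """Serialize a checkpoint directory into an in-memory bytes blob."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        tar.add(checkpoint_dir, arcname=os.path.basename(checkpoint_dir))
    return buf.getvalue()

def unpack_checkpoint(blob, dest_dir):
    """Recreate the checkpoint directory from a packed blob."""
    with tarfile.open(fileobj=io.BytesIO(blob), mode="r:gz") as tar:
        tar.extractall(dest_dir)
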

def save(self, checkpoint_path=None) -> str:
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")

    checkpoint = self._status_reporter.get_checkpoint()
    state = self.get_state()

    if not checkpoint:
        state.update(iteration=0, timesteps_total=0, episodes_total=0)
        # We drop a marker here to indicate that the checkpoint is empty
        checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
        parent_dir = checkpoint
    elif isinstance(checkpoint, dict):
        parent_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    elif isinstance(checkpoint, str):
        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        # When the trainable is restored, a temporary checkpoint
        # is created. However, when saved, it should become permanent.
        # Ideally, there are no save calls upon a temporary
        # checkpoint, but certain schedulers might.
        if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
            relative_path = os.path.relpath(checkpoint, parent_dir)
            parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                checkpoint_dir=parent_dir,
                logdir=self.logdir,
                step=self.training_iteration)
            checkpoint = os.path.abspath(
                os.path.join(parent_dir, relative_path))
    else:
        raise ValueError(
            "Provided checkpoint was expected to have "
            "type (str, dict). Got {}.".format(type(checkpoint)))

    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint, parent_dir, state)

    self._postprocess_checkpoint(checkpoint_path)
    self._maybe_save_to_cloud(parent_dir)

    return checkpoint_path
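
# The str branch above handles one subtle case: a trainable restored from
# an in-memory checkpoint first lands in a temporary directory, and a
# subsequent save() must promote that directory to a permanent one. A
# minimal sketch of the promotion, assuming temporary directories are
# named "checkpoint_tmp*" and permanent ones "checkpoint_<step>"; the
# naming scheme and helper are hypothetical, not Ray Tune's actual code.
import os
import shutil

def promote_temp_checkpoint(checkpoint, logdir, step):
    parent_dir = os.path.dirname(checkpoint)
    if os.path.basename(parent_dir).startswith("checkpoint_tmp"):
        # Remember where the checkpoint file sits inside the directory,
        # move the whole directory, then recompute the absolute path.
        # Assumes the permanent directory does not already exist.
        relative_path = os.path.relpath(checkpoint, parent_dir)
        perm_dir = os.path.join(logdir, "checkpoint_{}".format(step))
        shutil.move(parent_dir, perm_dir)
        checkpoint = os.path.abspath(os.path.join(perm_dir, relative_path))
    return checkpoint
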

def save(self, checkpoint_path=None):
    if checkpoint_path:
        raise ValueError(
            "Checkpoint path should not be used with function API.")
    checkpoint = self._status_reporter.get_checkpoint()
    state = self.get_state()
    if not checkpoint:
        # No checkpoint was reported: record zeroed counters and
        # write an empty default checkpoint directory.
        state.update(iteration=0, timesteps_total=0, episodes_total=0)
        parent_dir = self.create_default_checkpoint_dir()
    elif isinstance(checkpoint, dict):
        parent_dir = TrainableUtil.make_checkpoint_dir(
            self.logdir, index=self.training_iteration)
    else:
        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
    checkpoint_path = TrainableUtil.process_checkpoint(
        checkpoint, parent_dir, state)
    return checkpoint_path
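
# Both save() variants funnel every checkpoint through
# TrainableUtil.process_checkpoint. A minimal sketch of that idea,
# assuming a dict checkpoint is pickled into parent_dir and the
# trainable state is written alongside it as metadata; this is a
# hypothetical illustration, not Ray Tune's actual implementation.
import os
import pickle

def process_checkpoint_sketch(checkpoint, parent_dir, state):
    if isinstance(checkpoint, dict):
        checkpoint_path = os.path.join(parent_dir, "checkpoint.pkl")
        with open(checkpoint_path, "wb") as f:
            pickle.dump(checkpoint, f)
    else:
        checkpoint_path = checkpoint  # already a path on disk
    # Persist trainable state next to the checkpoint so a later
    # restore can rebuild iteration counters and other bookkeeping.
    with open(checkpoint_path + ".tune_metadata", "wb") as f:
        pickle.dump(state, f)
    return checkpoint_path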