Example #1
    def testFindCheckpointDir(self):
        checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
        os.makedirs(checkpoint_path)
        found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self.assertEqual(self.checkpoint_dir, found_dir)

        parent = os.path.dirname(found_dir)
        with self.assertRaises(FileNotFoundError):
            TrainableUtil.find_checkpoint_dir(parent)
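The test above pins down the lookup semantics: find_checkpoint_dir walks
upward from the given path to the registered checkpoint root, and raises
FileNotFoundError once asked to search above that root. A minimal,
self-contained sketch of such a walk-up, assuming a marker file named
".is_checkpoint" flags the root (the marker name is an assumption, not
taken from these excerpts):

import os

MARKER = ".is_checkpoint"  # assumed marker file name

def find_checkpoint_dir_sketch(path):
    """Walks upward from `path` to the nearest directory holding MARKER."""
    current = path if os.path.isdir(path) else os.path.dirname(path)
    while True:
        if os.path.exists(os.path.join(current, MARKER)):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            raise FileNotFoundError(
                "No checkpoint directory found above {}".format(path))
        current = parent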
Example #2
    def save_checkpoint(self, checkpoint):
        if isinstance(checkpoint, str):
            try:
                TrainableUtil.find_checkpoint_dir(checkpoint)
            except FileNotFoundError:
                logger.error("Checkpoint must be created with path given from "
                             "make_checkpoint_dir.")
                raise
        self._last_checkpoint = checkpoint
        self._fresh_checkpoint = True
Example #3
    def set_checkpoint(self, checkpoint, is_new=True):
        """Sets the checkpoint to be returned upon get_checkpoint.

        If this is a "new" checkpoint, it will notify Tune
        (via has_new_checkpoint). Otherwise, it will NOT notify Tune.
        """
        if isinstance(checkpoint, str):
            try:
                TrainableUtil.find_checkpoint_dir(checkpoint)
            except FileNotFoundError:
                logger.error("Checkpoint must be created with path given from "
                             "make_checkpoint_dir.")
                raise
        self._last_checkpoint = checkpoint
        if is_new:
            self._fresh_checkpoint = True
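Examples #2 and #3 both guard a setter around the same handshake: Tune polls
a "fresh checkpoint" flag, and reading the checkpoint consumes it. A minimal
sketch of that protocol, assuming get_checkpoint and has_new_checkpoint
behave as the docstring above suggests (this class is a reconstruction for
illustration, not the library's actual code):

class CheckpointFlagSketch:
    """Illustrates the fresh-checkpoint handshake hinted at above."""

    def __init__(self):
        self._last_checkpoint = None
        self._fresh_checkpoint = False

    def set_checkpoint(self, checkpoint, is_new=True):
        self._last_checkpoint = checkpoint
        if is_new:
            self._fresh_checkpoint = True  # notify Tune of new data

    def has_new_checkpoint(self):
        return self._fresh_checkpoint

    def get_checkpoint(self):
        # Reading the checkpoint consumes the "new" flag.
        self._fresh_checkpoint = False
        return self._last_checkpoint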
Example #4
    def save_checkpoint(self, checkpoint_dir: str = ""):
        if checkpoint_dir:
            raise ValueError(
                "Checkpoint dir should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()

        if not checkpoint:
            # We drop a marker here to indicate that the checkpoint is empty
            checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
            parent_dir = checkpoint
        elif isinstance(checkpoint, dict):
            return checkpoint
        elif isinstance(checkpoint, str):
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
            # When the trainable is restored, a temporary checkpoint
            # is created. However, when saved, it should become permanent.
            # Ideally, there are no save calls upon a temporary
            # checkpoint, but certain schedulers might.
            if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
                parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                    checkpoint_dir=parent_dir,
                    logdir=self.logdir,
                    step=self.training_iteration,
                )
        else:
            raise ValueError("Provided checkpoint was expected to have "
                             "type (str, dict). Got {}.".format(
                                 type(checkpoint)))

        return parent_dir
Example #5
    def delete_checkpoint(self, checkpoint_path):
        """Deletes checkpoint from both local and remote storage.

        Args:
            checkpoint_path (str): Local path to checkpoint.
        """
        super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
        local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self.storage_client.delete(self._storage_path(local_dirpath))
Example #6
    def save(self, checkpoint_path=None) -> str:
        if checkpoint_path:
            raise ValueError("Checkpoint path should not be used with function API.")

        checkpoint_path = self.save_checkpoint()

        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        self._maybe_save_to_cloud(parent_dir)

        return checkpoint_path
Example #7
    def save_all_states_remote(self, trial_state):
        """Saves all of AdaptDL's job state and returns it as an in-memory
        object."""
        checkpoint = save_all_states()
        parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        checkpoint_path = TrainableUtil.process_checkpoint(checkpoint,
                                                           parent_dir,
                                                           trial_state)
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
        # Done with the on-disk directory; remove it.
        shutil.rmtree(checkpoint_path)
        return checkpoint_obj
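Example #7 packs a checkpoint directory into an in-memory object and then
deletes the directory. The excerpt does not show how checkpoint_to_object
works; one plausible, self-contained way to round-trip a directory through
bytes (an illustration only, not Ray's implementation) is an in-memory tar
archive:

import io
import tarfile

def dir_to_object(checkpoint_dir):
    """Packs a directory tree into a bytes object (illustrative only)."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        tar.add(checkpoint_dir, arcname=".")
    return buffer.getvalue()

def object_to_dir(data, target_dir):
    """Unpacks bytes produced by dir_to_object into target_dir."""
    with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as tar:
        tar.extractall(target_dir)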
Example #8
    def testConvertTempToPermanent(self):
        checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        new_checkpoint_dir = FuncCheckpointUtil.create_perm_checkpoint(
            checkpoint_dir, self.logdir, step=4)
        assert new_checkpoint_dir == TrainableUtil.find_checkpoint_dir(
            new_checkpoint_dir)
        assert os.path.exists(new_checkpoint_dir)
        assert not FuncCheckpointUtil.is_temp_checkpoint_dir(
            new_checkpoint_dir)

        tmp_checkpoint_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
            self.logdir)
        assert tmp_checkpoint_dir != new_checkpoint_dir
Example #9
    def delete_checkpoint(self, checkpoint_path):
        """Deletes checkpoint from both local and remote storage.

        Args:
            checkpoint_path (str): Local path to checkpoint.
        """
        try:
            local_dirpath = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            logger.warning(
                "Trial %s: checkpoint path not found during "
                "garbage collection. See issue #6697.", self.trial_id)
        else:
            self.storage_client.delete(self._storage_path(local_dirpath))
        super(DurableTrainable, self).delete_checkpoint(checkpoint_path)
Example #10
    def save(self, checkpoint_path=None) -> str:
        if checkpoint_path:
            raise ValueError("Checkpoint path should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()
        state = self.get_state()

        if not checkpoint:
            state.update(iteration=0, timesteps_total=0, episodes_total=0)
            # We drop a marker here to indicate that the checkpoint is empty
            checkpoint = FuncCheckpointUtil.mk_null_checkpoint_dir(self.logdir)
            parent_dir = checkpoint
        elif isinstance(checkpoint, dict):
            parent_dir = TrainableUtil.make_checkpoint_dir(
                self.logdir, index=self.training_iteration
            )
        elif isinstance(checkpoint, str):
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
            # When the trainable is restored, a temporary checkpoint
            # is created. However, when saved, it should become permanent.
            # Ideally, there are no save calls upon a temporary
            # checkpoint, but certain schedulers might.
            if FuncCheckpointUtil.is_temp_checkpoint_dir(parent_dir):
                relative_path = os.path.relpath(checkpoint, parent_dir)
                parent_dir = FuncCheckpointUtil.create_perm_checkpoint(
                    checkpoint_dir=parent_dir,
                    logdir=self.logdir,
                    step=self.training_iteration,
                )
                checkpoint = os.path.abspath(os.path.join(parent_dir, relative_path))
        else:
            raise ValueError(
                "Provided checkpoint was expected to have "
                "type (str, dict). Got {}.".format(type(checkpoint))
            )

        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state
        )

        self._postprocess_checkpoint(checkpoint_path)

        self._maybe_save_to_cloud(parent_dir)

        return checkpoint_path
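The temp-to-permanent branch in Example #10 preserves the checkpoint's
position inside its directory: it records the path relative to the old
parent, then re-joins it under the new one. A standalone demonstration of
that round trip (all directory names below are made up):

import os

old_parent = "/tmp/logdir/tmp_checkpoint"
checkpoint = os.path.join(old_parent, "step_4", "model.pt")
relative_path = os.path.relpath(checkpoint, old_parent)  # "step_4/model.pt"

new_parent = "/tmp/logdir/checkpoint_000004"
moved = os.path.abspath(os.path.join(new_parent, relative_path))
print(moved)  # /tmp/logdir/checkpoint_000004/step_4/model.pt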
Example #11
    def save(self, checkpoint_path=None):
        if checkpoint_path:
            raise ValueError(
                "Checkpoint path should not be used with function API.")

        checkpoint = self._status_reporter.get_checkpoint()
        state = self.get_state()

        if not checkpoint:
            state.update(iteration=0, timesteps_total=0, episodes_total=0)
            parent_dir = self.create_default_checkpoint_dir()
        elif isinstance(checkpoint, dict):
            parent_dir = TrainableUtil.make_checkpoint_dir(
                self.logdir, index=self.training_iteration)
        else:
            parent_dir = TrainableUtil.find_checkpoint_dir(checkpoint)
        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint, parent_dir, state)
        return checkpoint_path
Example #12
    def delete(checkpoint):
        """Requests checkpoint deletion asynchronously.

        Args:
            checkpoint (Checkpoint): Checkpoint to delete.
        """
        if checkpoint.storage == Checkpoint.PERSISTENT and checkpoint.value:
            logger.debug("Trial %s: Deleting checkpoint %s", trial_id,
                         checkpoint.value)
            checkpoint_path = checkpoint.value
            # Delete local copy, if any exists.
            if os.path.exists(checkpoint_path):
                try:
                    checkpoint_dir = TrainableUtil.find_checkpoint_dir(
                        checkpoint_path)
                    shutil.rmtree(checkpoint_dir)
                except FileNotFoundError:
                    logger.warning("Checkpoint dir not found during deletion.")

            # TODO(ujvl): Batch remote deletes.
            runner.delete_checkpoint.remote(checkpoint.value)