示例#1
0
    def testFindCheckpointDir(self):
        """Check that ``find_checkpoint_dir`` resolves a nested path.

        A directory created below ``self.checkpoint_dir`` must resolve back
        to ``self.checkpoint_dir``, and looking up the parent of that
        directory must raise ``FileNotFoundError``.
        """
        checkpoint_path = os.path.join(self.checkpoint_dir, "my/nested/chkpt")
        os.makedirs(checkpoint_path)
        found_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        # ``assertEquals`` is a deprecated alias that was removed in
        # Python 3.12; ``assertEqual`` is the supported spelling.
        self.assertEqual(self.checkpoint_dir, found_dir)

        # Compute the parent outside the ``assertRaises`` body so that only
        # the lookup itself is expected to raise.
        parent = os.path.dirname(found_dir)
        with self.assertRaises(FileNotFoundError):
            TrainableUtil.find_checkpoint_dir(parent)
示例#2
0
    def delete_checkpoint(self, checkpoint_path):
        """Remove the local directory that holds a checkpoint.

        If ``self.uses_cloud_checkpointing`` is truthy, the location
        returned by ``self._storage_path`` for that directory is deleted
        through ``self.storage_client`` before the local directory is
        removed.

        Args:
            checkpoint_path (str): Path to checkpoint. A ``TrialCheckpoint``
                is also accepted, in which case its ``local_path`` is used.
        """
        # Unwrap TrialCheckpoint objects into a plain path first.
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        try:
            target_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # Nothing to clean up here: the checkpoint is absent locally
            # when the trial was rescheduled to another worker.
            logger.debug(
                f"Local checkpoint not found during garbage collection: "
                f"{self.trial_id} - {checkpoint_path}")
            return

        # Remote cleanup stays outside the ``try`` so its errors propagate.
        if self.uses_cloud_checkpointing:
            self.storage_client.delete(self._storage_path(target_dir))
            self.storage_client.wait_or_retry()

        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
示例#3
0
    def __call__(self, checkpoint: _TuneCheckpoint):
        """Delete a persistent checkpoint on the runner and locally.

        Does nothing when there is no runner, or when the checkpoint is not
        ``PERSISTENT`` or has an empty ``value``. The runner's
        ``delete_checkpoint`` is invoked remotely and its result is fetched
        with ``ray.get`` before any local copy is removed.

        Args:
            checkpoint: Checkpoint to delete.
        """
        if not self.runner:
            return

        is_persistent = checkpoint.storage == _TuneCheckpoint.PERSISTENT
        if not (is_persistent and checkpoint.value):
            return

        checkpoint_path = checkpoint.value
        logger.debug("Trial %s: Deleting checkpoint %s", self.trial_id,
                     checkpoint_path)

        # TODO(ujvl): Batch remote deletes.
        # The remote checkpoint goes first. If the runner is on the same
        # node as the driver, this also removes the local copy.
        ray.get(self.runner.delete_checkpoint.remote(checkpoint_path))

        # Clean up whatever local copy is still present.
        if not os.path.exists(checkpoint_path):
            return
        try:
            local_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
            shutil.rmtree(local_dir)
        except FileNotFoundError:
            logger.debug("Local checkpoint dir not found during deletion.")
示例#4
0
    def __call__(self, checkpoint):
        """Request deletion of a persistent checkpoint.

        Does nothing when there is no runner, or when the checkpoint is not
        ``PERSISTENT`` or has an empty ``value``. A local copy is removed
        only when ``runner_ip`` differs from ``node_ip``; the runner-side
        deletion is submitted as a remote call whose result is not awaited
        here.

        Args:
            checkpoint (Checkpoint): Checkpoint to delete.
        """
        if not self.runner:
            return

        is_persistent = checkpoint.storage == Checkpoint.PERSISTENT
        if not (is_persistent and checkpoint.value):
            return

        logger.debug("Trial %s: Deleting checkpoint %s", self.trial_id,
                     checkpoint.value)
        checkpoint_path = checkpoint.value

        # Remove the local copy, if one exists, when the runner is on a
        # different node.
        runner_is_elsewhere = self.runner_ip != self.node_ip
        if runner_is_elsewhere and os.path.exists(checkpoint_path):
            try:
                local_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
                shutil.rmtree(local_dir)
            except FileNotFoundError:
                logger.warning("Checkpoint dir not found during deletion.")

        # TODO(ujvl): Batch remote deletes.
        self.runner.delete_checkpoint.remote(checkpoint.value)
示例#5
0
    def to_air_checkpoint(self) -> Optional[Checkpoint]:
        """Build a ``Checkpoint`` from ``self.dir_or_data``.

        Returns:
            ``None`` when ``dir_or_data`` is falsy, otherwise a
            ``Checkpoint`` created from a directory (``str`` input), from
            pickled bytes (``bytes`` input) or from a ``dict``. An object
            reference is resolved with ``ray.get`` before dispatching.

        Raises:
            RuntimeError: If the data is of none of the handled types.
        """
        checkpoint_data = self.dir_or_data
        if not checkpoint_data:
            return None

        # Resolve object references to the underlying value first.
        if isinstance(checkpoint_data, ray.ObjectRef):
            checkpoint_data = ray.get(checkpoint_data)

        if isinstance(checkpoint_data, str):
            resolved_dir = TrainableUtil.find_checkpoint_dir(checkpoint_data)
            return Checkpoint.from_directory(resolved_dir)

        if isinstance(checkpoint_data, bytes):
            with tempfile.TemporaryDirectory() as tmpdir:
                TrainableUtil.create_from_pickle(checkpoint_data, tmpdir)
                # Round-trip through a dict so the result no longer refers
                # to the temporary directory once it is removed.
                as_dict = Checkpoint.from_directory(tmpdir).to_dict()
                return Checkpoint.from_dict(as_dict)

        if isinstance(checkpoint_data, dict):
            return Checkpoint.from_dict(checkpoint_data)

        raise RuntimeError(
            f"Unknown checkpoint data type: {type(checkpoint_data)}")
示例#6
0
            def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
                """Save via the parent class, then store the preprocessor.

                The preprocessor from ``self._merged_config`` is written
                with ``save_preprocessor_to_dir`` into the checkpoint
                directory resolved from the saved path, provided both are
                truthy.

                Args:
                    tmp_checkpoint_dir: Accepted but not used by this
                        override.

                Returns:
                    The path returned by the parent ``save_checkpoint``.
                """
                saved_path = super().save_checkpoint()
                checkpoint_root = TrainableUtil.find_checkpoint_dir(saved_path)
                configured = self._merged_config.get("preprocessor", None)

                if not (checkpoint_root and configured):
                    return saved_path

                save_preprocessor_to_dir(configured, checkpoint_root)
                return saved_path
示例#7
0
            def save_checkpoint(self, tmp_checkpoint_dir: str = ""):
                """Save via the parent class, then pickle the preprocessor.

                The preprocessor from ``self._merged_config`` is dumped with
                ``cpickle`` to a file named ``PREPROCESSOR_KEY`` inside the
                checkpoint directory resolved from the saved path, provided
                both are truthy.

                Args:
                    tmp_checkpoint_dir: Accepted but not used by this
                        override.

                Returns:
                    The path returned by the parent ``save_checkpoint``.
                """
                saved_path = super().save_checkpoint()
                checkpoint_root = TrainableUtil.find_checkpoint_dir(saved_path)
                configured = self._merged_config.get("preprocessor", None)

                if not (checkpoint_root and configured):
                    return saved_path

                target_file = os.path.join(checkpoint_root, PREPROCESSOR_KEY)
                with open(target_file, "wb") as out:
                    cpickle.dump(configured, out)
                return saved_path
示例#8
0
    def testConvertTempToPermanent(self):
        """Check conversion of a temporary checkpoint dir to a permanent one.

        The permanent directory must resolve to itself through
        ``find_checkpoint_dir``, exist on disk, and not be reported as
        temporary. A temporary directory created afterwards must differ
        from it.
        """
        temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(self.logdir)
        perm_dir = FuncCheckpointUtil.create_perm_checkpoint(
            temp_dir, self.logdir, step=4)

        resolved_dir = TrainableUtil.find_checkpoint_dir(perm_dir)
        assert perm_dir == resolved_dir
        assert os.path.exists(perm_dir)

        still_temp = FuncCheckpointUtil.is_temp_checkpoint_dir(perm_dir)
        assert not still_temp

        fresh_temp_dir = FuncCheckpointUtil.mk_temp_checkpoint_dir(
            self.logdir)
        assert fresh_temp_dir != perm_dir
示例#9
0
    def _trial_to_result(self, trial: Trial) -> Result:
        """Build a ``Result`` from a trial.

        Args:
            trial: Trial whose checkpoint, last result and error are read.

        Returns:
            A ``Result`` carrying a copy of ``trial.last_result`` as
            metrics, the value of ``self._populate_exception(trial)`` as
            error, and a directory-backed ``Checkpoint`` when
            ``trial.checkpoint.value`` is truthy (``None`` otherwise).
        """
        checkpoint = None
        if trial.checkpoint.value:
            resolved_dir = TrainableUtil.find_checkpoint_dir(trial.checkpoint.value)
            checkpoint = Checkpoint.from_directory(resolved_dir)

        return Result(
            checkpoint=checkpoint,
            metrics=trial.last_result.copy(),
            error=self._populate_exception(trial),
        )
示例#10
0
    def delete_checkpoint(self, checkpoint_path):
        """Remove the local directory that holds a checkpoint.

        A checkpoint that cannot be found is logged at debug level and
        otherwise ignored.

        Args:
            checkpoint_path (str): Path to checkpoint.
        """
        try:
            target_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # Nothing to clean up here: the checkpoint is absent locally
            # when the trial was rescheduled to another worker.
            logger.debug("Checkpoint not found during garbage collection.")
        else:
            if os.path.exists(target_dir):
                shutil.rmtree(target_dir)
示例#11
0
def create_checkpoint(preprocessor: Optional[Preprocessor] = None,
                      config: Optional[dict] = None) -> Checkpoint:
    """Create a dict-backed ``Checkpoint`` from a ``_DummyAlgo`` trainable.

    An ``RLTrainer`` is turned into a trainable, saved into a temporary
    directory, and the resulting checkpoint directory is converted to a
    dict before the temporary directory is removed.

    Args:
        preprocessor: Preprocessor passed to the ``RLTrainer``.
        config: Trainer config; an empty dict is used when falsy.

    Returns:
        A ``Checkpoint`` built from the saved checkpoint's dict form.
    """
    trainable_cls = RLTrainer(
        algorithm=_DummyAlgo,
        config=config or {},
        preprocessor=preprocessor,
    ).as_trainable()
    trainable = trainable_cls()

    with tempfile.TemporaryDirectory() as scratch_dir:
        saved_file = trainable.save(scratch_dir)
        saved_dir = TrainableUtil.find_checkpoint_dir(saved_file)
        # Convert while the temporary directory still exists.
        as_dict = Checkpoint.from_directory(saved_dir).to_dict()

    return Checkpoint.from_dict(as_dict)
示例#12
0
    def delete_checkpoint(self, checkpoint_path: str):
        """Remove the local directory that holds a checkpoint.

        If ``self.uses_cloud_checkpointing`` is truthy, the location
        returned by ``self._storage_path`` is deleted first: through
        ``self.custom_syncer`` when one is set, otherwise with
        ``delete_external_checkpoint`` wrapped in ``retry_fn``.

        Args:
            checkpoint_path: Path to checkpoint. A ``TrialCheckpoint`` is
                also accepted, in which case its ``local_path`` is used.
        """
        # Unwrap TrialCheckpoint objects into a plain path first.
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        try:
            target_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path)
        except FileNotFoundError:
            # Nothing to clean up here: the checkpoint is absent locally
            # when the trial was rescheduled to another worker.
            logger.debug(
                f"Local checkpoint not found during garbage collection: "
                f"{self.trial_id} - {checkpoint_path}"
            )
            return

        # Remote cleanup stays outside the ``try`` so its errors propagate.
        if self.uses_cloud_checkpointing:
            if self.custom_syncer:
                # Kept for backwards compatibility.
                self.custom_syncer.delete(self._storage_path(target_dir))
                self.custom_syncer.wait_or_retry()
            else:
                checkpoint_uri = self._storage_path(target_dir)

                def _delete_remote():
                    return delete_external_checkpoint(checkpoint_uri)

                retry_fn(
                    _delete_remote,
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )

        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)