Example #1
    def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
        if not self.uses_cloud_checkpointing:
            return False

        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path
        )
        external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
        local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

        if self.custom_syncer:
            # Only keep for backwards compatibility
            self.custom_syncer.sync_down(remote_dir=external_uri, local_dir=local_dir)
            self.custom_syncer.wait_or_retry()
            return True

        checkpoint = Checkpoint.from_uri(external_uri)
        retry_fn(
            lambda: checkpoint.to_directory(local_dir),
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
        )

        return True
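
For context, the `retry_fn(...)` call above simply retries the cloud download a fixed number of times before giving up. The snippet below is a minimal sketch of an equivalent helper, written only to mirror the call pattern in Example #1; it is not Ray's implementation, and the name `retry_with_sleep` is hypothetical.

import time
from typing import Callable, Type


def retry_with_sleep(
    fn: Callable[[], object],
    exception_type: Type[Exception],
    num_retries: int = 3,
    sleep_time: float = 1,
) -> bool:
    """Illustrative stand-in: call ``fn`` up to ``num_retries`` times,
    sleeping ``sleep_time`` seconds after each ``exception_type`` failure."""
    for _ in range(num_retries):
        try:
            fn()
            return True
        except exception_type:
            time.sleep(sleep_time)
    return False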
Example #2
    def restore(self, checkpoint_path):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match the return value of ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/
        checkpoint_00000/checkpoint`. Or,
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
        """
        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path)
            self.storage_client.sync_down(
                os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir),
                os.path.join(self.logdir, rel_checkpoint_dir),
            )
            self.storage_client.wait_or_retry()

        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info("Restored on %s from checkpoint: %s",
                    self.get_current_ip(), checkpoint_path)
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)
Example #3
def test_find_rel_checkpoint_dir(checkpoint_path, logdir):
    assert (
        TrainableUtil.find_rel_checkpoint_dir(logdir, checkpoint_path)
        == "checkpoint0/"
    )
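
The test above pins down the contract of `TrainableUtil.find_rel_checkpoint_dir`: given the trial `logdir` and a path inside one of its checkpoint directories, it returns that checkpoint directory relative to `logdir`, with a trailing slash. A rough sketch of that behaviour, assuming `checkpoint_path` always lives under `logdir`, might look like this; it is an illustration, not Ray's implementation.

import os


def find_rel_checkpoint_dir_sketch(logdir: str, checkpoint_path: str) -> str:
    # Illustration only: assumes checkpoint_path is located under logdir.
    rel = os.path.relpath(checkpoint_path, logdir)
    # Keep only the first component, i.e. the checkpoint directory itself.
    checkpoint_dir = rel.split(os.sep)[0]
    return checkpoint_dir + os.sep


# e.g. logdir=".../MyTrainable_abc" and
#      checkpoint_path=".../MyTrainable_abc/checkpoint0/checkpoint" -> "checkpoint0/"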
Example #4
    def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match the return value of ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/
        checkpoint_00000/checkpoint`. Or,
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

        Args:
            checkpoint_path: Path to restore checkpoint from. If this
                path does not exist on the local node, it will be fetched
                from external (cloud) storage if available, or restored
                from a remote node.
            checkpoint_node_ip: If given, try to restore
                checkpoint from this node if it doesn't exist locally or
                on cloud storage.

        """
        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path
            )
            external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
            local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

            if self.storage_client:
                # Only keep for backwards compatibility
                self.storage_client.sync_down(external_uri, local_dir)
                self.storage_client.wait_or_retry()
            else:
                checkpoint = Checkpoint.from_uri(external_uri)
                retry_fn(
                    lambda: checkpoint.to_directory(local_dir),
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )
        elif (
            # If a checkpoint source IP is given
            checkpoint_node_ip
            # And the checkpoint does not currently exist on the local node
            and not os.path.exists(checkpoint_path)
            # And the source IP is different to the current IP
            and checkpoint_node_ip != ray.util.get_node_ip_address()
        ):
            checkpoint = get_checkpoint_from_remote_node(
                checkpoint_path, checkpoint_node_ip
            )
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info(
            "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path
        )
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)