Example #1
from typing import Union


def _load_checkpoint(checkpoint: Union["Checkpoint", str]) -> "Checkpoint":
    from ray.air.checkpoint import Checkpoint

    if isinstance(checkpoint, str):
        checkpoint = Checkpoint.from_uri(checkpoint)
    assert isinstance(checkpoint, Checkpoint)
    return checkpoint
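
The helper above accepts either a `Checkpoint` object or a URI string and always hands back a `Checkpoint`. A minimal usage sketch, assuming the bucket URI below is a placeholder chosen for illustration:

from ray.air.checkpoint import Checkpoint

# "s3://my-bucket/ckpt" is a placeholder URI, not a real location.
ckpt = _load_checkpoint("s3://my-bucket/ckpt")  # str is wrapped via Checkpoint.from_uri
same = _load_checkpoint(ckpt)                   # an existing Checkpoint passes through
assert isinstance(ckpt, Checkpoint) and same is ckpt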
Example #2
    def _maybe_load_from_cloud(self, checkpoint_path: str) -> bool:
        if not self.uses_cloud_checkpointing:
            return False

        rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
            self.logdir, checkpoint_path)
        external_uri = os.path.join(self.remote_checkpoint_dir,
                                    rel_checkpoint_dir)
        local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

        if self.custom_syncer:
            # Only kept for backwards compatibility
            self.custom_syncer.sync_down(remote_dir=external_uri,
                                         local_dir=local_dir)
            self.custom_syncer.wait_or_retry()
            return True

        checkpoint = Checkpoint.from_uri(external_uri)
        retry_fn(
            lambda: checkpoint.to_directory(local_dir),
            subprocess.CalledProcessError,
            num_retries=3,
            sleep_time=1,
        )

        return True
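
The download is wrapped in `retry_fn`, which retries a flaky transfer a few times before giving up. A self-contained sketch of that retry pattern follows; it is an illustrative stand-in, not Ray's actual `retry_fn` implementation:

import time
from typing import Callable, Type


def retry(fn: Callable[[], None], exc_type: Type[Exception],
          num_retries: int = 3, sleep_time: float = 1.0) -> None:
    # Call `fn`, retrying up to `num_retries` times on `exc_type`,
    # sleeping between attempts and re-raising after the last failure.
    for attempt in range(num_retries):
        try:
            fn()
            return
        except exc_type:
            if attempt == num_retries - 1:
                raise
            time.sleep(sleep_time)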
Example #3
from typing import Union

from ray.air.checkpoint import Checkpoint


def _load_checkpoint(
    checkpoint: Union[Checkpoint, str],
) -> Checkpoint:
    if isinstance(checkpoint, str):
        checkpoint = Checkpoint.from_uri(checkpoint)
    assert isinstance(checkpoint, Checkpoint)
    return checkpoint
Example #4
    def testUriCheckpointSerde(self):
        # URI checkpoints keep the same internal representation, pointing to
        # a remote location

        checkpoint = Checkpoint.from_uri("s3://some/bucket")

        self._testCheckpointSerde(checkpoint, *checkpoint.get_internal_representation())
Example #5
    def get_best_checkpoint(
            self,
            trial: Trial,
            metric: Optional[str] = None,
            mode: Optional[str] = None) -> Optional[Checkpoint]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.

        Returns:
            :class:`Checkpoint <ray.air.Checkpoint>` object.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

        # Filter out nan. Sorting nan values leads to undefined behavior.
        checkpoint_paths = [(path, metric) for path, metric in checkpoint_paths
                            if not is_nan(metric)]

        if not checkpoint_paths:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        a = -1 if mode == "max" else 1
        best_path_metrics = sorted(checkpoint_paths, key=lambda x: a * x[1])

        best_path, best_metric = best_path_metrics[0]
        cloud_path = self._parse_cloud_path(best_path)

        if self._legacy_checkpoint:
            return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)

        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            return Checkpoint.from_uri(cloud_path)
        elif os.path.exists(best_path):
            return Checkpoint.from_directory(best_path)
        else:
            logger.error(
                f"No checkpoint locations for {trial} available on "
                f"this node. To avoid this, you "
                f"should enable checkpoint synchronization with the"
                f"`sync_config` argument in Ray Tune. "
                f"The checkpoint may be available on a different node - "
                f"please check this location on worker nodes: {best_path}")
            return None
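
A hedged sketch of how this method is typically reached, assuming it lives on a Ray Tune `ExperimentAnalysis` object (as the surrounding attributes suggest); the experiment path and metric name below are placeholders:

# Illustrative only; the path and metric are placeholders.
from ray.tune import ExperimentAnalysis

analysis = ExperimentAnalysis("~/ray_results/my_experiment")
best_trial = analysis.get_best_trial(metric="mean_accuracy", mode="max")
best_checkpoint = analysis.get_best_checkpoint(
    best_trial, metric="mean_accuracy", mode="max"
)
if best_checkpoint is not None:
    local_dir = best_checkpoint.to_directory()  # materialize for downstream use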
Example #6
    def test_fs_delete_at_uri(self):
        """Test that clear bucket utility works"""
        checkpoint = self._prepare_fs_checkpoint()

        # Convert into a URI checkpoint
        location = checkpoint.to_uri(self.cloud_uri)
        delete_at_uri(location)

        checkpoint = Checkpoint.from_uri(location)
        with self.assertRaises(FileNotFoundError):
            checkpoint.to_directory()
Example #7
    def test_fs_checkpoint_uri(self):
        """Test conversion from fs to cloud checkpoint and back."""
        checkpoint = self._prepare_fs_checkpoint()

        # Convert into a URI checkpoint
        location = checkpoint.to_uri(self.cloud_uri)
        self.assertIsInstance(location, str)
        self.assertIn("memory://", location)

        # Create from URI
        checkpoint = Checkpoint.from_uri(location)
        self.assertTrue(checkpoint._uri)

        self._assert_fs_checkpoint(checkpoint)
Example #8
    def test_fs_checkpoint_uri_pa(self):
        """Test conversion from fs to cloud checkpoint and back."""
        checkpoint = self._prepare_fs_checkpoint()

        # Clean up mock bucket
        delete_at_uri(self.cloud_uri_pa)
        _ensure_directory(self.cloud_uri_pa)

        # Convert into a URI checkpoint
        location = checkpoint.to_uri(self.cloud_uri_pa)
        self.assertIsInstance(location, str)
        self.assertIn("mock://", location)

        # Create from URI
        checkpoint = Checkpoint.from_uri(location)
        self.assertTrue(checkpoint._uri)

        self._assert_fs_checkpoint(checkpoint)
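
Outside the test harness, the same round trip looks roughly like the sketch below. The `memory:///...` URI mirrors the in-memory scheme these tests assert on and assumes an fsspec-style memory filesystem is available; in practice it would typically be an `s3://` or `gs://` location:

import os
import tempfile

from ray.air.checkpoint import Checkpoint

# Build a small directory checkpoint.
src_dir = tempfile.mkdtemp()
with open(os.path.join(src_dir, "data.txt"), "w") as f:
    f.write("payload")

# Upload it to the (placeholder) URI ...
uri = Checkpoint.from_directory(src_dir).to_uri("memory:///bucket/ckpt")

# ... and restore it anywhere that can reach that URI.
restored_dir = Checkpoint.from_uri(uri).to_directory()
assert os.path.exists(os.path.join(restored_dir, "data.txt"))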
Example #9
    def get_best_checkpoint(
        self,
        trial: Trial,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        return_path: bool = False,
    ) -> Optional[Union[Checkpoint, str]]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.
            return_path: If True, only returns the path (and not the
                ``Checkpoint`` object). If using Ray client, it is not
                guaranteed that this path is available on the local
                (client) node. Can also contain a cloud URI.

        Returns:
            :class:`Checkpoint <ray.air.Checkpoint>` object or string
            if ``return_path=True``.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

        # Filter out nan. Sorting nan values leads to undefined behavior.
        checkpoint_paths = [(path, metric) for path, metric in checkpoint_paths
                            if not is_nan(metric)]

        if not checkpoint_paths:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        a = -1 if mode == "max" else 1
        best_path_metrics = sorted(checkpoint_paths, key=lambda x: a * x[1])

        best_path, best_metric = best_path_metrics[0]
        cloud_path = self._parse_cloud_path(best_path)

        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            if return_path:
                return cloud_path
            return Checkpoint.from_uri(cloud_path)
        elif os.path.exists(best_path):
            if return_path:
                return best_path
            return Checkpoint.from_directory(best_path)
        else:
            if log_once("checkpoint_not_available"):
                logger.error(
                    f"The requested checkpoint for trial {trial} is not available on "
                    f"this node, most likely because you are using Ray client or "
                    f"disabled checkpoint synchronization. To avoid this, enable "
                    f"checkpoint synchronization to cloud storage by specifying a "
                    f"`SyncConfig`. The checkpoint may be available on a different "
                    f"node - please check this location on worker nodes: {best_path}"
                )
            if return_path:
                return best_path
            return None
Example #10
    def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match the return value of ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/
        checkpoint_00000/checkpoint`. Or,
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally correspond to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir`, in this case, is something like
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.

        Args:
            checkpoint_path: Path to restore checkpoint from. If this
                path does not exist on the local node, it will be fetched
                from external (cloud) storage if available, or restored
                from a remote node.
            checkpoint_node_ip: If given, try to restore
                checkpoint from this node if it doesn't exist locally or
                on cloud storage.

        """
        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path
            )
            external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
            local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

            if self.storage_client:
                # Only kept for backwards compatibility
                self.storage_client.sync_down(external_uri, local_dir)
                self.storage_client.wait_or_retry()
            else:
                checkpoint = Checkpoint.from_uri(external_uri)
                retry_fn(
                    lambda: checkpoint.to_directory(local_dir),
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )
        elif (
            # If a checkpoint source IP is given
            checkpoint_node_ip
            # And the checkpoint does not currently exist on the local node
            and not os.path.exists(checkpoint_path)
            # And the source IP is different from the current IP
            and checkpoint_node_ip != ray.util.get_node_ip_address()
        ):
            checkpoint = get_checkpoint_from_remote_node(
                checkpoint_path, checkpoint_node_ip
            )
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info(
            "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path
        )
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)