Esempio n. 1
0
    def testUriCheckpointSerde(self):
        """Run the shared serde check against a checkpoint created from a URI.

        The checkpoint's own internal representation is unpacked and passed
        to ``_testCheckpointSerde`` together with the checkpoint itself.
        """
        # URI checkpoints keep the same internal representation, pointing to
        # a remote location
        uri_checkpoint = Checkpoint.from_uri("s3://some/bucket")
        representation = uri_checkpoint.get_internal_representation()

        self._testCheckpointSerde(uri_checkpoint, *representation)
Esempio n. 2
0
    def get_best_checkpoint(
            self,
            trial: Trial,
            metric: Optional[str] = None,
            mode: Optional[str] = None) -> Optional[Checkpoint]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be
        filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.

        Returns:
            :class:`Checkpoint <ray.ml.Checkpoint>` object for the best
            checkpoint. If ``self._legacy_checkpoint`` is set, a
            ``TrialCheckpoint`` holding both the local and the cloud path is
            returned instead. ``None`` is returned (and an error is logged)
            if the trial has no checkpoints with a non-nan metric, or if the
            best checkpoint has no cloud path and its local path does not
            exist on this node.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

        # Filter out nan. Ordering nan values leads to undefined behavior.
        candidates = [(path, value) for path, value in checkpoint_paths
                      if not is_nan(value)]

        if not candidates:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Negate the metric for "max" so that the best entry is always the
        # smallest key. ``min`` returns the first of equally good entries,
        # which is the same entry a stable ascending sort would put first.
        sign = -1 if mode == "max" else 1
        best_path, _ = min(candidates, key=lambda item: sign * item[1])
        cloud_path = self._parse_cloud_path(best_path)

        if self._legacy_checkpoint:
            return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)

        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            return Checkpoint.from_uri(cloud_path)

        if os.path.exists(best_path):
            return Checkpoint.from_directory(best_path)

        logger.error(
            f"No checkpoint locations for {trial} available on "
            f"this node. To avoid this, you "
            f"should enable checkpoint synchronization with the "
            f"`sync_config` argument in Ray Tune. "
            f"The checkpoint may be available on a different node - "
            f"please check this location on worker nodes: {best_path}")
        return None
Esempio n. 3
0
    def test_fs_delete_at_uri(self):
        """Test that clear bucket utility works"""
        fs_checkpoint = self._prepare_fs_checkpoint()

        # Write the checkpoint to the cloud URI, then delete that location
        stored_uri = fs_checkpoint.to_uri(self.cloud_uri)
        delete_at_uri(stored_uri)

        # Materializing a checkpoint from the deleted location must fail
        stale_checkpoint = Checkpoint.from_uri(stored_uri)
        self.assertRaises(FileNotFoundError, stale_checkpoint.to_directory)
Esempio n. 4
0
    def test_fs_checkpoint_uri(self):
        """Test conversion from fs to cloud checkpoint and back."""
        fs_checkpoint = self._prepare_fs_checkpoint()

        # Write the checkpoint to the cloud URI; the result is a URI string
        stored_uri = fs_checkpoint.to_uri(self.cloud_uri)
        self.assertIsInstance(stored_uri, str)
        self.assertIn("memory://", stored_uri)

        # Re-create the checkpoint from that URI
        uri_checkpoint = Checkpoint.from_uri(stored_uri)
        self.assertTrue(uri_checkpoint._uri)

        self._assert_fs_checkpoint(uri_checkpoint)
Esempio n. 5
0
    def test_fs_checkpoint_uri(self):
        """Test conversion from fs to cloud checkpoint and back."""
        fs_checkpoint = self._prepare_fs_checkpoint()

        # ``subprocess.check_call`` is replaced by ``self.mock_s3`` for the
        # whole round trip.
        with patch("subprocess.check_call", self.mock_s3):
            # Write the checkpoint to the cloud URI; the result is a URI string
            stored_uri = fs_checkpoint.to_uri(self.cloud_uri)
            self.assertIsInstance(stored_uri, str)
            self.assertIn("s3://", stored_uri)

            # Re-create the checkpoint from that URI
            uri_checkpoint = Checkpoint.from_uri(stored_uri)
            self.assertTrue(uri_checkpoint._uri)

            self._assert_fs_checkpoint(uri_checkpoint)
Esempio n. 6
0
    def test_fs_checkpoint_uri_pa(self):
        """Test conversion from fs to cloud checkpoint and back."""
        fs_checkpoint = self._prepare_fs_checkpoint()

        # Start from an empty mock bucket: delete it, then re-create it
        target_uri = self.cloud_uri_pa
        delete_at_uri(target_uri)
        _ensure_directory(target_uri)

        # Write the checkpoint to the mock URI; the result is a URI string
        stored_uri = fs_checkpoint.to_uri(target_uri)
        self.assertIsInstance(stored_uri, str)
        self.assertIn("mock://", stored_uri)

        # Re-create the checkpoint from that URI
        uri_checkpoint = Checkpoint.from_uri(stored_uri)
        self.assertTrue(uri_checkpoint._uri)

        self._assert_fs_checkpoint(uri_checkpoint)
Esempio n. 7
0
    def restore(self, checkpoint_path: str, checkpoint_node_ip: Optional[str] = None):
        """Restores training state from a given model checkpoint.

        These checkpoints are returned from calls to save().

        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.

        `checkpoint_path` should match with the return from ``save()``.

        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/
        checkpoint_00000/checkpoint`. Or,
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

        `self.logdir` should generally be corresponding to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.

        `self.remote_checkpoint_dir` in this case, is something like,
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`

        Args:
            checkpoint_path: Path to restore checkpoint from. If this
                path does not exist on the local node, it will be fetched
                from external (cloud) storage if available, or restored
                from a remote node. A ``TrialCheckpoint`` is also accepted;
                its ``local_path`` is used in that case.
            checkpoint_node_ip: If given, try to restore
                checkpoint from this node if it doesn't exist locally or
                on cloud storage.

        Raises:
            OSError: If ``checkpoint_path + ".tune_metadata"`` (or, for
                checkpoints saved as a dict, ``checkpoint_path`` itself)
                cannot be opened, e.g. because it does not exist.

        """
        # Ensure TrialCheckpoints are converted
        if isinstance(checkpoint_path, TrialCheckpoint):
            checkpoint_path = checkpoint_path.local_path

        if self.uses_cloud_checkpointing:
            rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir(
                self.logdir, checkpoint_path
            )
            external_uri = os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir)
            local_dir = os.path.join(self.logdir, rel_checkpoint_dir)

            if self.storage_client:
                # Only keep for backwards compatibility
                self.storage_client.sync_down(external_uri, local_dir)
                self.storage_client.wait_or_retry()
            else:
                checkpoint = Checkpoint.from_uri(external_uri)
                retry_fn(
                    lambda: checkpoint.to_directory(local_dir),
                    subprocess.CalledProcessError,
                    num_retries=3,
                    sleep_time=1,
                )
        elif (
            # If a checkpoint source IP is given
            checkpoint_node_ip
            # And the checkpoint does not currently exist on the local node.
            # This must test the checkpoint *path*; testing the IP string
            # would be true virtually always and force a needless re-fetch.
            and not os.path.exists(checkpoint_path)
            # And the source IP is different to the current IP
            and checkpoint_node_ip != ray.util.get_node_ip_address()
        ):
            checkpoint = get_checkpoint_from_remote_node(
                checkpoint_path, checkpoint_node_ip
            )
            if checkpoint:
                checkpoint.to_directory(checkpoint_path)

        # NOTE: ``pickle.load`` can execute arbitrary code. Only restore from
        # checkpoints that come from a trusted source.
        with open(checkpoint_path + ".tune_metadata", "rb") as f:
            metadata = pickle.load(f)
        self._experiment_id = metadata["experiment_id"]
        self._iteration = metadata["iteration"]
        self._timesteps_total = metadata["timesteps_total"]
        self._time_total = metadata["time_total"]
        self._episodes_total = metadata["episodes_total"]
        saved_as_dict = metadata["saved_as_dict"]
        if saved_as_dict:
            # The checkpoint file itself is a pickled dict; hand it to
            # ``load_checkpoint`` together with the path it was read from.
            with open(checkpoint_path, "rb") as loaded_state:
                checkpoint_dict = pickle.load(loaded_state)
            checkpoint_dict.update(tune_checkpoint_path=checkpoint_path)
            self.load_checkpoint(checkpoint_dict)
        else:
            self.load_checkpoint(checkpoint_path)

        # Counters that track progress since the most recent restore start
        # from zero again.
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = True
        logger.info(
            "Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path
        )
        state = {
            "_iteration": self._iteration,
            "_timesteps_total": self._timesteps_total,
            "_time_total": self._time_total,
            "_episodes_total": self._episodes_total,
        }
        logger.info("Current state after restoring: %s", state)