示例#1
0
 def _priority(self, checkpoint):
     """Build the tuple used to rank ``checkpoint`` against other checkpoints.

     The score is looked up in the flattened ``checkpoint.result`` under
     ``self._checkpoint_score_attr`` and negated when
     ``self._checkpoint_score_desc`` is truthy.

     Returns:
         A 3-tuple ``(has_score, score, order)``. For a nan score the tuple
         is ``(False, 0, checkpoint.order)``; otherwise it is
         ``(True, score, checkpoint.order)``.
     """
     score = flatten_dict(checkpoint.result)[self._checkpoint_score_attr]
     if self._checkpoint_score_desc:
         score = -score

     # nan cannot be ordered meaningfully, so it is flagged by the leading
     # boolean and its score slot is filled with a neutral 0.
     if is_nan(score):
         return (False, 0, checkpoint.order)
     return (True, score, checkpoint.order)
示例#2
0
    def testGetBestCheckpointNan(self):
        """Tests if nan values are excluded from best checkpoint."""
        metric = "loss"

        def train(config):
            # The first step reports nan, every later step reports its own
            # index. A checkpoint directory is requested at every step.
            for step in range(config["steps"]):
                value = float("nan") if step == 0 else step
                with tune.checkpoint_dir(step=step):
                    pass
                tune.report(**{metric: value})

        analysis = tune.run(train, local_dir=self.test_dir, config={"steps": 3})
        best_trial = analysis.get_best_trial(metric, mode="min")
        best_checkpoint = analysis.get_best_checkpoint(best_trial, metric, mode="min")

        # Recompute the expected answer by hand: drop the nan-valued entries
        # and take the path whose metric is smallest.
        non_nan_entries = [
            entry
            for entry in analysis.get_trial_checkpoints_paths(best_trial, metric=metric)
            if not is_nan(entry[1])
        ]
        expected_checkpoint_no_nan = min(non_nan_entries, key=lambda entry: entry[1])[0]
        assert best_checkpoint == expected_checkpoint_no_nan
示例#3
0
    def get_best_checkpoint(
            self,
            trial: Trial,
            metric: Optional[str] = None,
            mode: Optional[str] = None) -> Optional[Checkpoint]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.

        Returns:
            :class:`Checkpoint <ray.ml.Checkpoint>` object, built from the
            cloud URI when one can be parsed from the best path and from the
            local directory otherwise. If ``self._legacy_checkpoint`` is set,
            a ``TrialCheckpoint`` carrying both the local and the cloud path
            is returned instead. ``None`` is returned (and an error is
            logged) when the trial has no checkpoint with a non-nan metric,
            or when the best checkpoint has no cloud path and its local path
            does not exist on this node.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        # Filter out nan. Ordering nan values leads to undefined behavior.
        # The loop variable is deliberately not called ``metric`` so it does
        # not read as if it replaced the metric *name* above.
        candidates = [
            (path, value)
            for path, value in self.get_trial_checkpoints_paths(trial, metric)
            if not is_nan(value)
        ]

        if not candidates:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Flip the sign for "max" so that the best entry is always the
        # minimum. ``min`` returns the first minimal element, which is the
        # same entry a stable ascending sort would put at index 0.
        sign = -1 if mode == "max" else 1
        best_path, _ = min(candidates, key=lambda item: sign * item[1])
        cloud_path = self._parse_cloud_path(best_path)

        if self._legacy_checkpoint:
            return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)

        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            return Checkpoint.from_uri(cloud_path)
        elif os.path.exists(best_path):
            return Checkpoint.from_directory(best_path)
        else:
            logger.error(
                f"No checkpoint locations for {trial} available on "
                f"this node. To avoid this, you "
                f"should enable checkpoint synchronization with the "
                f"`sync_config` argument in Ray Tune. "
                f"The checkpoint may be available on a different node - "
                f"please check this location on worker nodes: {best_path}")
            return None
示例#4
0
    def get_best_checkpoint(
        self,
        trial: Trial,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        return_path: bool = False,
    ) -> Optional[Union[Checkpoint, str]]:
        """Return the best persistent checkpoint recorded for a trial.

        Checkpoints whose associated metric value is ``nan`` are ignored.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: Key of the trial info to rank by, e.g. "mean_accuracy".
                Falls back to ``self.default_metric`` and then to
                "training_iteration" when not given.
            mode: One of [min, max]. Defaults to ``self.default_mode``.
            return_path: If True, return only the path instead of a
                ``Checkpoint`` object. When using Ray client, the path is
                not guaranteed to exist on the local (client) node. The
                returned value may also be a cloud URI.

        Returns:
            :class:`Checkpoint <ray.air.Checkpoint>` object, or a string
            if ``return_path=True``. ``None`` when the trial has no
            checkpoint with a non-nan metric, or when ``return_path`` is
            False and the best checkpoint has neither a cloud path nor an
            existing local path.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        # Drop nan-valued entries up front: ordering nan values is undefined.
        scored_paths = [
            (path, value)
            for path, value in self.get_trial_checkpoints_paths(trial, metric)
            if not is_nan(value)
        ]
        if not scored_paths:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Flip the sign for "max" so the best entry always sorts first.
        sign = -1 if mode == "max" else 1
        scored_paths.sort(key=lambda entry: sign * entry[1])
        best_path = scored_paths[0][0]

        cloud_path = self._parse_cloud_path(best_path)
        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            return cloud_path if return_path else Checkpoint.from_uri(cloud_path)

        if os.path.exists(best_path):
            return best_path if return_path else Checkpoint.from_directory(best_path)

        # Neither a cloud location nor a local directory is reachable here.
        if log_once("checkpoint_not_available"):
            logger.error(
                f"The requested checkpoint for trial {trial} is not available on "
                f"this node, most likely because you are using Ray client or "
                f"disabled checkpoint synchronization. To avoid this, enable "
                f"checkpoint synchronization to cloud storage by specifying a "
                f"`SyncConfig`. The checkpoint may be available on a different "
                f"node - please check this location on worker nodes: {best_path}"
            )
        return best_path if return_path else None