예제 #1
0
    def testUploadDefaultBoth(self):
        """``upload()`` on a checkpoint that has default local and cloud paths.

        Arguments passed to ``upload()`` override the constructor defaults;
        the return value is the cloud path that was uploaded to.
        """
        other_local_dir = "/tmp/other"
        other_cloud_dir = "memory:///other"

        delete_at_uri(other_cloud_dir)
        self._save_checkpoint_at(other_cloud_dir)
        # ``shutil.copytree`` raises FileExistsError when the destination
        # already exists. The path is hard-coded (and is also used as a
        # download target elsewhere), so clear leftovers first and remove the
        # copy again once this test finishes.
        shutil.rmtree(other_local_dir, ignore_errors=True)
        self.addCleanup(shutil.rmtree, other_local_dir, ignore_errors=True)
        shutil.copytree(self.local_dir, other_local_dir)

        # Case: Nothing is passed
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)

        path = checkpoint.upload()

        self.assertEqual(self.cloud_dir, path)

        # Case: Local dir is passed -> still uploads to the default cloud dir
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)

        path = checkpoint.upload(local_path=other_local_dir)

        self.assertEqual(self.cloud_dir, path)

        # Case: Both are passed -> the passed cloud dir wins
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)

        path = checkpoint.upload(local_path=other_local_dir,
                                 cloud_path=other_cloud_dir)

        self.assertEqual(other_cloud_dir, path)
예제 #2
0
    def testDownloadDefaultBoth(self):
        """``download()`` on a checkpoint that has default local and cloud paths.

        The returned local path is the ``local_path`` passed to ``download()``
        when given, otherwise the constructor default.
        """
        other_local_dir = "/tmp/other"
        other_cloud_dir = "memory:///other"

        for uri in (other_cloud_dir, self.cloud_dir):
            self._save_checkpoint_at(uri)

        expectations = (
            # Case: Nothing is passed
            ({}, self.local_dir),
            # Case: Local dir is passed
            ({"local_path": other_local_dir}, other_local_dir),
            # Case: Both are passed
            ({"local_path": other_local_dir, "cloud_path": other_cloud_dir},
             other_local_dir),
        )
        for download_kwargs, expected_path in expectations:
            checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                         cloud_path=self.cloud_dir)
            returned_path = checkpoint.download(**download_kwargs)
            self.assertEqual(expected_path, returned_path)
예제 #3
0
    def get_best_checkpoint(
        self, trial: Trial, metric: Optional[str] = None, mode: Optional[str] = None
    ) -> Optional[TrialCheckpoint]:
        """Gets best persistent checkpoint path of provided trial.

        Checkpoints whose associated metric value is NaN are ignored.

        Args:
            trial (Trial): The log directory of a trial, or a trial instance.
            metric (str): key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode (str): One of [min, max]. Defaults to ``self.default_mode``.

        Returns:
            :class:`TrialCheckpoint <ray.tune.cloud.TrialCheckpoint>` object,
            or ``None`` if no checkpoint with a non-NaN metric was found.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

        # Drop NaN metrics: NaN compares false against everything, which makes
        # the ordering below arbitrary. NaN is the only value not equal to
        # itself, so ``value == value`` keeps every non-NaN entry.
        checkpoint_paths = [(path, value) for path, value in checkpoint_paths
                            if value == value]

        if not checkpoint_paths:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Select the extreme directly instead of sorting the whole list.
        # ``min`` with the sign-flipped key returns the same element that
        # ``sorted(...)[0]`` would, including on ties (first occurrence wins).
        sign = -1 if mode == "max" else 1
        best_path, _ = min(checkpoint_paths, key=lambda x: sign * x[1])

        return TrialCheckpoint(
            local_path=best_path, cloud_path=self._parse_cloud_path(best_path)
        )
예제 #4
0
    def get_best_checkpoint(
            self,
            trial: Trial,
            metric: Optional[str] = None,
            mode: Optional[str] = None) -> Optional[Checkpoint]:
        """Gets best persistent checkpoint path of provided trial.

        Any checkpoints with an associated metric value of ``nan`` will be filtered out.

        Args:
            trial: The log directory of a trial, or a trial instance.
            metric: key of trial info to return, e.g. "mean_accuracy".
                "training_iteration" is used by default if no value was
                passed to ``self.default_metric``.
            mode: One of [min, max]. Defaults to ``self.default_mode``.

        Returns:
            :class:`Checkpoint <ray.ml.Checkpoint>` object (or a
            ``TrialCheckpoint`` when ``self._legacy_checkpoint`` is set), or
            ``None`` if no usable checkpoint was found.
        """
        metric = metric or self.default_metric or TRAINING_ITERATION
        mode = self._validate_mode(mode)

        checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

        # Filter out nan. Sorting nan values leads to undefined behavior.
        # Loop names differ from ``metric`` so the argument is not shadowed.
        checkpoint_paths = [(path, value) for path, value in checkpoint_paths
                            if not is_nan(value)]

        if not checkpoint_paths:
            logger.error(f"No checkpoints have been found for trial {trial}.")
            return None

        # Select the extreme directly instead of sorting the whole list.
        # ``min`` with the sign-flipped key returns the same element that
        # ``sorted(...)[0]`` would, including on ties (first occurrence wins).
        sign = -1 if mode == "max" else 1
        best_path, _ = min(checkpoint_paths, key=lambda x: sign * x[1])
        cloud_path = self._parse_cloud_path(best_path)

        if self._legacy_checkpoint:
            return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)

        if cloud_path:
            # Prefer cloud path over local path for downstream processing
            return Checkpoint.from_uri(cloud_path)
        elif os.path.exists(best_path):
            return Checkpoint.from_directory(best_path)
        else:
            logger.error(
                f"No checkpoint locations for {trial} available on "
                f"this node. To avoid this, you "
                f"should enable checkpoint synchronization with the "
                f"`sync_config` argument in Ray Tune. "
                f"The checkpoint may be available on a different node - "
                f"please check this location on worker nodes: {best_path}")
            return None
예제 #5
0
    def testUploadDefaultBoth(self):
        """``upload()`` with default local and cloud paths and a mocked ``aws`` CLI.

        ``subprocess.check_call`` is patched to record the command line, so
        the test checks which paths would have been handed to ``aws``.
        """
        recorded = {}

        def check_call(cmd, *args, **kwargs):
            # Capture the command instead of invoking the AWS CLI.
            recorded["cmd"] = cmd

        other_local_dir = "/tmp/other"
        other_cloud_dir = "s3://other"

        def upload_with_defaults(**upload_kwargs):
            ckpt = TrialCheckpoint(local_path=self.local_dir,
                                   cloud_path=self.cloud_dir)
            with patch("subprocess.check_call", check_call):
                return ckpt.upload(**upload_kwargs)

        # Case: Nothing is passed
        result = upload_with_defaults()
        self.assertEqual(self.cloud_dir, result)
        cmd = recorded["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(self.cloud_dir, cmd)

        # Case: Local dir is passed
        result = upload_with_defaults(local_path=other_local_dir)
        self.assertEqual(self.cloud_dir, result)
        cmd = recorded["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(other_local_dir, cmd)
        self.assertNotIn(self.local_dir, cmd)

        # Case: Both are passed
        result = upload_with_defaults(local_path=other_local_dir,
                                      cloud_path=other_cloud_dir)
        self.assertEqual(other_cloud_dir, result)
        cmd = recorded["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(other_local_dir, cmd)
        self.assertNotIn(self.local_dir, cmd)
        self.assertIn(other_cloud_dir, cmd)
        self.assertNotIn(self.cloud_dir, cmd)
예제 #6
0
    def testDownloadDefaultCloud(self):
        """``download()`` when only a default cloud path is set (mocked ``aws``).

        Without a local destination the call must fail; otherwise the
        recorded ``aws`` command must reference the expected paths.
        """
        recorded = {}

        def check_call(cmd, *args, **kwargs):
            # Capture the command instead of invoking the AWS CLI.
            recorded["cmd"] = cmd

        other_cloud_dir = "s3://other"

        def cloud_only():
            return TrialCheckpoint(cloud_path=self.cloud_dir)

        # Case: Nothing is passed
        checkpoint = cloud_only()
        with self.assertRaisesRegex(RuntimeError, "No local path"):
            checkpoint.download()

        # Case: Local dir is passed
        checkpoint = cloud_only()
        with patch("subprocess.check_call", check_call):
            result = checkpoint.download(local_path=self.local_dir)

        self.assertEqual(self.local_dir, result)
        cmd = recorded["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(self.local_dir, cmd)

        # Case: Cloud dir is passed
        checkpoint = cloud_only()
        with self.assertRaisesRegex(RuntimeError, "No local path"):
            checkpoint.download(cloud_path=other_cloud_dir)

        # Case: Both are passed
        checkpoint = cloud_only()
        with patch("subprocess.check_call", check_call):
            result = checkpoint.download(local_path=self.local_dir,
                                         cloud_path=other_cloud_dir)

        self.assertEqual(self.local_dir, result)
        cmd = recorded["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(other_cloud_dir, cmd)
        self.assertNotIn(self.cloud_dir, cmd)
예제 #7
0
    def testSaveCloudTarget(self):
        """``save()`` to cloud targets with a mocked ``aws`` CLI.

        The fake ``check_call`` records the command and, when element 6 of
        the command is not an ``s3`` URI, writes a ``checkpoint.txt`` into
        that directory to imitate a completed download.
        """
        seen = {}

        def check_call(cmd, *args, **kwargs):
            seen["cmd"] = cmd

            # Fake AWS-specific checkpoint download. Index 6 is assumed to be
            # the destination in the generated command line.
            destination = cmd[6]
            if destination.startswith("s3"):
                return
            with open(os.path.join(destination, "checkpoint.txt"),
                      "wt") as marker:
                marker.write("Checkpoint\n")

        other_cloud_dir = "s3://other"
        local_marker = os.path.join(self.local_dir, "checkpoint.txt")

        # Case: No defaults
        checkpoint = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, "No existing local"):
            checkpoint.save(self.cloud_dir)

        # Case: Default local dir
        # save() expects an existing local checkpoint, so create one first.
        with open(local_marker, "wt") as marker:
            marker.write("Checkpoint\n")

        checkpoint = TrialCheckpoint(local_path=self.local_dir)
        with patch("subprocess.check_call", check_call):
            result = checkpoint.save(self.cloud_dir)

        self.assertEqual(self.cloud_dir, result)
        self.assertIn(self.cloud_dir, seen["cmd"])
        self.assertIn(self.local_dir, seen["cmd"])

        # Remove the checkpoint created above.
        os.remove(local_marker)

        # Case: Default cloud dir, copy to other cloud
        checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", check_call):
            result = checkpoint.save(other_cloud_dir)

        self.assertEqual(other_cloud_dir, result)
        self.assertIn(other_cloud_dir, seen["cmd"])
        # The copy goes through a temporary dir, not the test's local dir.
        self.assertNotIn(self.local_dir, seen["cmd"])

        # Case: Default both, copy to other cloud
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", check_call):
            result = checkpoint.save(other_cloud_dir)

        self.assertEqual(other_cloud_dir, result)
        self.assertIn(other_cloud_dir, seen["cmd"])
        self.assertIn(self.local_dir, seen["cmd"])
예제 #8
0
    def testSaveLocalTarget(self):
        """``save()`` to local targets under each combination of defaults.

        Both ``subprocess.check_call`` (the ``aws`` CLI) and
        ``shutil.copytree`` are patched so only their arguments are checked.
        """
        calls = {}

        def check_call(cmd, *args, **kwargs):
            calls["cmd"] = cmd

        def copytree(source, dest):
            calls["copy_source"] = source
            calls["copy_dest"] = dest

        def assert_copied_to(dest):
            self.assertEqual(calls["copy_source"], self.local_dir)
            self.assertEqual(calls["copy_dest"], dest)

        other_local_dir = "/tmp/other"

        # Cases: No defaults / Default local dir -> no cloud path to save from
        for init_kwargs in ({}, {"local_path": self.local_dir}):
            checkpoint = TrialCheckpoint(**init_kwargs)
            with self.assertRaisesRegex(RuntimeError, "No cloud path"):
                checkpoint.save()

        # Case: Default cloud dir, no local dir passed
        checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
        with self.assertRaisesRegex(RuntimeError, "No target path"):
            checkpoint.save()

        # Case: Default cloud dir, pass local dir
        checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", check_call):
            result = checkpoint.save(self.local_dir, force_download=True)

        self.assertEqual(self.local_dir, result)
        cmd = calls["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(self.cloud_dir, cmd)
        self.assertIn(self.local_dir, cmd)

        # Case: Default local dir, pass local dir
        checkpoint = TrialCheckpoint(local_path=self.local_dir)
        self.ensureCheckpointFile()
        with patch("shutil.copytree", copytree):
            result = checkpoint.save(other_local_dir)

        self.assertEqual(other_local_dir, result)
        assert_copied_to(other_local_dir)

        # Case: Both default, no pass
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", check_call):
            result = checkpoint.save()

        self.assertEqual(self.local_dir, result)
        cmd = calls["cmd"]
        self.assertIn(self.cloud_dir, cmd)
        self.assertIn(self.local_dir, cmd)

        # Case: Both default, pass other local dir
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)
        with patch("shutil.copytree", copytree):
            result = checkpoint.save(other_local_dir)

        self.assertEqual(other_local_dir, result)
        assert_copied_to(other_local_dir)
        # Saving elsewhere must not move the checkpoint's own local path.
        self.assertEqual(checkpoint.local_path, self.local_dir)
예제 #9
0
    def testUploadNoDefaults(self):
        """``upload()`` on a checkpoint without defaults needs both paths passed."""
        captured = {}

        def check_call(cmd, *args, **kwargs):
            captured["cmd"] = cmd

        # Case: Nothing is passed
        checkpoint = TrialCheckpoint()
        with self.assertRaises(RuntimeError):
            checkpoint.upload()

        # Cases: only one of the two paths is passed
        for upload_kwargs, expected_error in (
                ({"local_path": self.local_dir}, "No cloud path"),
                ({"cloud_path": self.cloud_dir}, "No local path"),
        ):
            checkpoint = TrialCheckpoint()
            with self.assertRaisesRegex(RuntimeError, expected_error):
                checkpoint.upload(**upload_kwargs)

        # Case: Both are passed
        checkpoint = TrialCheckpoint()
        with patch("subprocess.check_call", check_call):
            result = checkpoint.upload(local_path=self.local_dir,
                                       cloud_path=self.cloud_dir)

        self.assertEqual(self.cloud_dir, result)
        cmd = captured["cmd"]
        self.assertEqual(cmd[0], "aws")
        self.assertIn(self.cloud_dir, cmd)
예제 #10
0
    def testDownloadDefaultCloud(self):
        """``download()`` when only a default cloud path is set (memory FS).

        A local destination must be passed; it is also what gets returned.
        """
        other_cloud_dir = "memory:///other"

        def cloud_only():
            return TrialCheckpoint(cloud_path=self.cloud_dir)

        # Case: Nothing is passed
        checkpoint = cloud_only()
        with self.assertRaisesRegex(RuntimeError, "No local path"):
            checkpoint.download()

        # Case: Local dir is passed
        result = cloud_only().download(local_path=self.local_dir)
        self.assertEqual(self.local_dir, result)

        # Case: Cloud dir is passed
        checkpoint = cloud_only()
        with self.assertRaisesRegex(RuntimeError, "No local path"):
            checkpoint.download(cloud_path=other_cloud_dir)

        # Case: Both are passed
        result = cloud_only().download(local_path=self.local_dir,
                                       cloud_path=other_cloud_dir)
        self.assertEqual(self.local_dir, result)
예제 #11
0
    def testDownloadDefaultLocal(self):
        """``download()`` when only a default local path is set.

        A cloud path must be passed; the returned path is the passed local
        path if given, otherwise the default.
        """
        other_local_dir = "/tmp/invalid"

        # Cases: Nothing passed / only a local dir passed -> no cloud source
        for download_kwargs in ({}, {"local_path": other_local_dir}):
            checkpoint = TrialCheckpoint(local_path=self.local_dir)
            with self.assertRaisesRegex(RuntimeError, "No cloud path"):
                checkpoint.download(**download_kwargs)

        # Cases: Cloud dir passed / both passed
        for download_kwargs, expected_path in (
                ({"cloud_path": self.cloud_dir}, self.local_dir),
                ({"local_path": other_local_dir, "cloud_path": self.cloud_dir},
                 other_local_dir),
        ):
            checkpoint = TrialCheckpoint(local_path=self.local_dir)
            result = checkpoint.download(**download_kwargs)
            self.assertEqual(expected_path, result)
예제 #12
0
    def testDownloadNoDefaults(self):
        """``download()`` on a checkpoint without defaults needs both paths passed."""
        # Case: Nothing is passed
        checkpoint = TrialCheckpoint()
        with self.assertRaises(RuntimeError):
            checkpoint.download()

        # Cases: only one of the two paths is passed
        for download_kwargs, expected_error in (
                ({"local_path": self.local_dir}, "No cloud path"),
                ({"cloud_path": self.cloud_dir}, "No local path"),
        ):
            checkpoint = TrialCheckpoint()
            with self.assertRaisesRegex(RuntimeError, expected_error):
                checkpoint.download(**download_kwargs)

        # Case: Both are passed
        checkpoint = TrialCheckpoint()
        result = checkpoint.download(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)
        self.assertEqual(self.local_dir, result)
예제 #13
0
    def testSaveCloudTarget(self):
        """``save()`` to cloud targets on the in-memory filesystem.

        The return value must be the cloud target passed to ``save()``.
        """
        other_cloud_dir = "memory:///other"
        local_marker = os.path.join(self.local_dir, "checkpoint.txt")

        delete_at_uri(other_cloud_dir)
        self._save_checkpoint_at(other_cloud_dir)

        # Case: No defaults
        checkpoint = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, "No existing local"):
            checkpoint.save(self.cloud_dir)

        # Case: Default local dir
        # save() expects an existing local checkpoint, so create one first.
        with open(local_marker, "wt") as marker:
            marker.write("Checkpoint\n")

        result = TrialCheckpoint(local_path=self.local_dir).save(self.cloud_dir)
        self.assertEqual(self.cloud_dir, result)

        # Remove the checkpoint created above.
        os.remove(local_marker)

        # Case: Default cloud dir, copy to other cloud
        result = TrialCheckpoint(cloud_path=self.cloud_dir).save(other_cloud_dir)
        self.assertEqual(other_cloud_dir, result)

        # Case: Default both, copy to other cloud
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)
        result = checkpoint.save(other_cloud_dir)
        self.assertEqual(other_cloud_dir, result)
예제 #14
0
 def testConstructTrialCheckpoint(self):
     """Constructing with any mix of present/absent paths must not raise."""
     for local_path, cloud_path in (
             (None, None),
             ("/tmp", None),
             (None, "memory:///invalid"),
             ("/remote/node/dir", None),
     ):
         TrialCheckpoint(local_path, cloud_path)