def testDownloadNoDefaults(self):
    """Without defaults, download() needs both a local and a cloud path."""
    recorded = {}

    def fake_check_call(cmd, *args, **kwargs):
        # Record the sync command instead of executing it.
        recorded["cmd"] = cmd

    # Neither path given.
    ckpt = TrialCheckpoint()
    with self.assertRaises(RuntimeError):
        ckpt.download()

    # Only the local target given: the cloud source is missing.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        ckpt.download(local_path=self.local_dir)

    # Only the cloud source given: the local target is missing.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.download(cloud_path=self.cloud_dir)

    # Both given: the transfer goes through the (patched) AWS CLI.
    ckpt = TrialCheckpoint()
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.download(
            local_path=self.local_dir, cloud_path=self.cloud_dir
        )

    self.assertEqual(self.local_dir, result)
    issued = recorded["cmd"]
    self.assertEqual(issued[0], "aws")
    self.assertIn(self.local_dir, issued)
def testDownloadDefaultLocal(self):
    """A default local path alone is not enough; a cloud source is required."""
    alt_local = "/tmp/invalid"

    def with_local_default():
        return TrialCheckpoint(local_path=self.local_dir)

    # No arguments: no cloud source to download from.
    ckpt = with_local_default()
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        ckpt.download()

    # Overriding only the local target still leaves no cloud source.
    ckpt = with_local_default()
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        ckpt.download(local_path=alt_local)

    # A cloud source is given: the default local path is the target.
    ckpt = with_local_default()
    result = ckpt.download(cloud_path=self.cloud_dir)
    self.assertEqual(self.local_dir, result)

    # Both given: the explicit local path wins over the default.
    ckpt = with_local_default()
    result = ckpt.download(local_path=alt_local, cloud_path=self.cloud_dir)
    self.assertEqual(alt_local, result)
def testDownloadDefaultCloud(self):
    """A default cloud path alone is not enough; a local target is required."""
    alt_cloud = "memory:///other"

    def with_cloud_default():
        return TrialCheckpoint(cloud_path=self.cloud_dir)

    # No arguments: nowhere local to download to.
    ckpt = with_cloud_default()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.download()

    # A local target is given: downloads from the default cloud path.
    ckpt = with_cloud_default()
    result = ckpt.download(local_path=self.local_dir)
    self.assertEqual(self.local_dir, result)

    # Overriding only the cloud source still leaves no local target.
    ckpt = with_cloud_default()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.download(cloud_path=alt_cloud)

    # Both given: the result is the explicit local path.
    ckpt = with_cloud_default()
    result = ckpt.download(local_path=self.local_dir, cloud_path=alt_cloud)
    self.assertEqual(self.local_dir, result)
def testUploadDefaultBoth(self):
    """upload() with both default paths set, overriding one or both of them."""
    other_local_dir = "/tmp/other"
    other_cloud_dir = "memory:///other"

    # Start from clean alternate locations on both sides.
    delete_at_uri(other_cloud_dir)
    self._save_checkpoint_at(other_cloud_dir)

    # shutil.copytree() raises FileExistsError if the destination already
    # exists (e.g. left over from a previous run, or from a test that
    # downloads into the same path), so clear it first and remove it again
    # when this test finishes.
    shutil.rmtree(other_local_dir, ignore_errors=True)
    self.addCleanup(shutil.rmtree, other_local_dir, ignore_errors=True)
    shutil.copytree(self.local_dir, other_local_dir)

    # Case: Nothing is passed -> uploads to the default cloud path.
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    path = checkpoint.upload()
    self.assertEqual(self.cloud_dir, path)

    # Case: Local dir is passed -> still uploads to the default cloud path.
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    path = checkpoint.upload(local_path=other_local_dir)
    self.assertEqual(self.cloud_dir, path)

    # Case: Both are passed -> explicit cloud path wins.
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    path = checkpoint.upload(local_path=other_local_dir, cloud_path=other_cloud_dir)
    self.assertEqual(other_cloud_dir, path)
def testDownloadDefaultBoth(self):
    """download() with both default paths set, overriding one or both."""
    alt_local = "/tmp/other"
    alt_cloud = "memory:///other"

    # Make sure both cloud locations hold a checkpoint to download.
    for uri in (alt_cloud, self.cloud_dir):
        self._save_checkpoint_at(uri)

    def with_both_defaults():
        return TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)

    # Nothing overridden: lands in the default local path.
    ckpt = with_both_defaults()
    result = ckpt.download()
    self.assertEqual(self.local_dir, result)

    # Local path overridden: lands in the alternate local path.
    ckpt = with_both_defaults()
    result = ckpt.download(local_path=alt_local)
    self.assertEqual(alt_local, result)

    # Both overridden: the explicit local path is returned.
    ckpt = with_both_defaults()
    result = ckpt.download(local_path=alt_local, cloud_path=alt_cloud)
    self.assertEqual(alt_local, result)
def testSaveCloudTarget(self):
    """save() to a cloud target using the (patched) AWS CLI."""
    recorded = {}

    def write_marker(directory):
        # Drop a dummy checkpoint file into ``directory``.
        with open(os.path.join(directory, "checkpoint.txt"), "wt") as fh:
            fh.write("Checkpoint\n")

    def fake_check_call(cmd, *args, **kwargs):
        recorded["cmd"] = cmd

        # Emulate an AWS download: argument 6 is the destination; when it
        # is a local directory, materialize a checkpoint there.
        target = cmd[6]
        if not target.startswith("s3"):
            write_marker(target)

    alt_cloud = "s3://other"

    # No defaults: there is no local checkpoint to upload.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No existing local"):
        ckpt.save(self.cloud_dir)

    # Default local dir: the local dir is assumed to hold a checkpoint.
    write_marker(self.local_dir)

    ckpt = TrialCheckpoint(local_path=self.local_dir)
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.save(self.cloud_dir)

    self.assertEqual(self.cloud_dir, result)
    self.assertIn(self.cloud_dir, recorded["cmd"])
    self.assertIn(self.local_dir, recorded["cmd"])

    os.remove(os.path.join(self.local_dir, "checkpoint.txt"))

    # Default cloud dir, copied to another cloud location via a temp dir.
    ckpt = TrialCheckpoint(cloud_path=self.cloud_dir)
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.save(alt_cloud)

    self.assertEqual(alt_cloud, result)
    self.assertIn(alt_cloud, recorded["cmd"])
    self.assertNotIn(self.local_dir, recorded["cmd"])  # a temp dir is used

    # Both defaults, copied to another cloud location from the local dir.
    ckpt = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.save(alt_cloud)

    self.assertEqual(alt_cloud, result)
    self.assertIn(alt_cloud, recorded["cmd"])
    self.assertIn(self.local_dir, recorded["cmd"])
def get_best_checkpoint(
    self, trial: Trial, metric: Optional[str] = None, mode: Optional[str] = None
) -> Optional[TrialCheckpoint]:
    """Gets best persistent checkpoint path of provided trial.

    Any checkpoints with an associated metric value of ``nan`` are ignored.

    Args:
        trial (Trial): The log directory of a trial, or a trial instance.
        metric (str): key of trial info to return, e.g. "mean_accuracy".
            "training_iteration" is used by default if no value was
            passed to ``self.default_metric``.
        mode (str): One of [min, max]. Defaults to ``self.default_mode``.

    Returns:
        :class:`TrialCheckpoint <ray.tune.cloud.TrialCheckpoint>` object,
        or ``None`` if no checkpoint with a non-nan metric was found.
    """
    import math

    metric = metric or self.default_metric or TRAINING_ITERATION
    mode = self._validate_mode(mode)

    checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)
    # nan compares false against every value, so ordering by a key that
    # contains nan gives an order-dependent (undefined) winner. Drop them.
    checkpoint_paths = [
        (path, value) for path, value in checkpoint_paths if not math.isnan(value)
    ]
    if not checkpoint_paths:
        logger.error(f"No checkpoints have been found for trial {trial}.")
        return None

    # Negate for "max" so the smallest key is always the best checkpoint.
    # min() returns the first of tied entries, like sorted(...)[0] did.
    sign = -1 if mode == "max" else 1
    best_path, _ = min(checkpoint_paths, key=lambda item: sign * item[1])
    return TrialCheckpoint(
        local_path=best_path, cloud_path=self._parse_cloud_path(best_path)
    )
def testUploadDefaultBoth(self):
    """upload() via the (patched) AWS CLI with both default paths set."""
    captured = {}

    def fake_check_call(cmd, *args, **kwargs):
        captured["cmd"] = cmd

    alt_local = "/tmp/other"
    alt_cloud = "s3://other"

    def upload_with_defaults(**overrides):
        # Fresh checkpoint per case; the AWS call is intercepted.
        ckpt = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", fake_check_call):
            return ckpt.upload(**overrides)

    # Nothing overridden.
    self.assertEqual(self.cloud_dir, upload_with_defaults())
    cmd = captured["cmd"]
    self.assertEqual(cmd[0], "aws")
    self.assertIn(self.cloud_dir, cmd)

    # Local source overridden.
    self.assertEqual(self.cloud_dir, upload_with_defaults(local_path=alt_local))
    cmd = captured["cmd"]
    self.assertEqual(cmd[0], "aws")
    self.assertIn(alt_local, cmd)
    self.assertNotIn(self.local_dir, cmd)

    # Both overridden.
    self.assertEqual(
        alt_cloud, upload_with_defaults(local_path=alt_local, cloud_path=alt_cloud)
    )
    cmd = captured["cmd"]
    self.assertEqual(cmd[0], "aws")
    self.assertIn(alt_local, cmd)
    self.assertNotIn(self.local_dir, cmd)
    self.assertIn(alt_cloud, cmd)
    self.assertNotIn(self.cloud_dir, cmd)
def get_best_checkpoint(
    self, trial: Trial, metric: Optional[str] = None, mode: Optional[str] = None
) -> Optional[Checkpoint]:
    """Gets best persistent checkpoint path of provided trial.

    Any checkpoints with an associated metric value of ``nan`` will be
    filtered out.

    Args:
        trial: The log directory of a trial, or a trial instance.
        metric: key of trial info to return, e.g. "mean_accuracy".
            "training_iteration" is used by default if no value was
            passed to ``self.default_metric``.
        mode: One of [min, max]. Defaults to ``self.default_mode``.

    Returns:
        :class:`Checkpoint <ray.ml.Checkpoint>` object (a
        :class:`TrialCheckpoint` when ``self._legacy_checkpoint`` is set),
        or ``None`` if no checkpoint is found or none of its locations is
        reachable from this node.
    """
    metric = metric or self.default_metric or TRAINING_ITERATION
    mode = self._validate_mode(mode)

    checkpoint_paths = self.get_trial_checkpoints_paths(trial, metric)

    # Filter out nan. Sorting nan values leads to undefined behavior.
    # (``value`` rather than ``metric`` to avoid shadowing the parameter.)
    checkpoint_paths = [
        (path, value) for path, value in checkpoint_paths if not is_nan(value)
    ]

    if not checkpoint_paths:
        logger.error(f"No checkpoints have been found for trial {trial}.")
        return None

    # Negate for "max" so the smallest key is always the best checkpoint.
    # min() returns the first of tied entries, like sorted(...)[0] did.
    sign = -1 if mode == "max" else 1
    best_path, _ = min(checkpoint_paths, key=lambda item: sign * item[1])

    cloud_path = self._parse_cloud_path(best_path)

    if self._legacy_checkpoint:
        return TrialCheckpoint(local_path=best_path, cloud_path=cloud_path)

    if cloud_path:
        # Prefer cloud path over local path for downstream processing
        return Checkpoint.from_uri(cloud_path)
    elif os.path.exists(best_path):
        return Checkpoint.from_directory(best_path)
    else:
        logger.error(
            f"No checkpoint locations for {trial} available on "
            f"this node. To avoid this, you "
            f"should enable checkpoint synchronization with the "
            f"`sync_config` argument in Ray Tune. "
            f"The checkpoint may be available on a different node - "
            f"please check this location on worker nodes: {best_path}"
        )
        return None
def testDownloadNoDefaults(self):
    """Without defaults, download() needs both a local and a cloud path."""
    # Neither path given.
    ckpt = TrialCheckpoint()
    with self.assertRaises(RuntimeError):
        ckpt.download()

    # Only the local target given: the cloud source is missing.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        ckpt.download(local_path=self.local_dir)

    # Only the cloud source given: the local target is missing.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.download(cloud_path=self.cloud_dir)

    # Both given: the local target is returned.
    ckpt = TrialCheckpoint()
    result = ckpt.download(local_path=self.local_dir, cloud_path=self.cloud_dir)
    self.assertEqual(self.local_dir, result)
def testSaveCloudTarget(self):
    """save() to cloud targets backed by the in-memory filesystem."""
    alt_cloud = "memory:///other"
    delete_at_uri(alt_cloud)
    self._save_checkpoint_at(alt_cloud)

    # No defaults: there is no local checkpoint to upload.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No existing local"):
        ckpt.save(self.cloud_dir)

    # Default local dir: the local dir is assumed to hold a checkpoint,
    # so put a dummy file there for the duration of this case.
    marker = os.path.join(self.local_dir, "checkpoint.txt")
    with open(marker, "wt") as fh:
        fh.write("Checkpoint\n")

    ckpt = TrialCheckpoint(local_path=self.local_dir)
    self.assertEqual(self.cloud_dir, ckpt.save(self.cloud_dir))

    os.remove(marker)

    # Default cloud dir, copied to another cloud location.
    ckpt = TrialCheckpoint(cloud_path=self.cloud_dir)
    self.assertEqual(alt_cloud, ckpt.save(alt_cloud))

    # Both defaults, copied to another cloud location.
    ckpt = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    self.assertEqual(alt_cloud, ckpt.save(alt_cloud))
def testUploadDefaultCloud(self):
    """upload() with only a default cloud path; a local source is required."""
    captured = {}

    def fake_check_call(cmd, *args, **kwargs):
        captured["cmd"] = cmd

    alt_cloud = "s3://other"

    def with_cloud_default():
        return TrialCheckpoint(cloud_path=self.cloud_dir)

    # No arguments: no local source to upload.
    ckpt = with_cloud_default()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.upload()

    # A local source is given: uploads to the default cloud path.
    ckpt = with_cloud_default()
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.upload(local_path=self.local_dir)

    self.assertEqual(self.cloud_dir, result)
    cmd = captured["cmd"]
    self.assertEqual(cmd[0], "aws")
    self.assertIn(self.cloud_dir, cmd)

    # Overriding only the cloud target still leaves no local source.
    ckpt = with_cloud_default()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.upload(cloud_path=alt_cloud)

    # Both given: the explicit cloud target wins.
    ckpt = with_cloud_default()
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.upload(local_path=self.local_dir, cloud_path=alt_cloud)

    self.assertEqual(alt_cloud, result)
    cmd = captured["cmd"]
    self.assertEqual(cmd[0], "aws")
    self.assertIn(alt_cloud, cmd)
    self.assertNotIn(self.cloud_dir, cmd)
def testSaveLocalTarget(self):
    """save() to local targets: AWS download or local copy, as appropriate."""
    recorded = {}

    def fake_check_call(cmd, *args, **kwargs):
        recorded["cmd"] = cmd

    def fake_copytree(source, dest):
        recorded["copy_source"] = source
        recorded["copy_dest"] = dest

    alt_local = "/tmp/other"

    # Cases that must fail before any transfer is attempted.
    ckpt = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        ckpt.save()

    ckpt = TrialCheckpoint(local_path=self.local_dir)
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        ckpt.save()

    ckpt = TrialCheckpoint(cloud_path=self.cloud_dir)
    with self.assertRaisesRegex(RuntimeError, "No target path"):
        ckpt.save()

    # Default cloud dir, explicit local target: downloaded via AWS.
    ckpt = TrialCheckpoint(cloud_path=self.cloud_dir)
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.save(self.local_dir, force_download=True)

    self.assertEqual(self.local_dir, result)
    cmd = recorded["cmd"]
    self.assertEqual(cmd[0], "aws")
    self.assertIn(self.cloud_dir, cmd)
    self.assertIn(self.local_dir, cmd)

    # Default local dir, other local target: copied locally.
    ckpt = TrialCheckpoint(local_path=self.local_dir)
    self.ensureCheckpointFile()

    with patch("shutil.copytree", fake_copytree):
        result = ckpt.save(alt_local)

    self.assertEqual(alt_local, result)
    self.assertEqual(recorded["copy_source"], self.local_dir)
    self.assertEqual(recorded["copy_dest"], alt_local)

    # Both defaults, no target: downloaded via AWS into the default dir.
    ckpt = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.save()

    self.assertEqual(self.local_dir, result)
    cmd = recorded["cmd"]
    self.assertIn(self.cloud_dir, cmd)
    self.assertIn(self.local_dir, cmd)

    # Both defaults, other local target: copied locally, default unchanged.
    ckpt = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    with patch("shutil.copytree", fake_copytree):
        result = ckpt.save(alt_local)

    self.assertEqual(alt_local, result)
    self.assertEqual(recorded["copy_source"], self.local_dir)
    self.assertEqual(recorded["copy_dest"], alt_local)
    self.assertEqual(ckpt.local_path, self.local_dir)
def testConstructTrialCheckpoint(self):
    """Each (local_path, cloud_path) combination below must construct."""
    combinations = (
        (None, None),
        ("/tmp", None),
        (None, "memory:///invalid"),
        ("/remote/node/dir", None),
    )
    for local_path, cloud_path in combinations:
        TrialCheckpoint(local_path, cloud_path)