def testSaveCloudTarget(self):
    """Check ``TrialCheckpoint.save`` when the target is an ``s3://`` URI.

    ``subprocess.check_call`` is swapped for a recorder, so no real transfer
    happens; each case inspects the last command the recorder saw.

    NOTE(review): a second ``testSaveCloudTarget`` (using ``memory://`` URIs)
    appears right after this one. If both are defined in the same class,
    this earlier definition is shadowed and never runs -- confirm which one
    should be kept, or rename one of them.
    """
    recorded = {}

    def fake_check_call(cmd, *args, **kwargs):
        recorded["cmd"] = cmd
        # Stand-in for an AWS CLI download: index 6 of the command is
        # treated as the destination; when it is not an s3 URI, create a
        # checkpoint file there as a real download would.
        destination = cmd[6]
        if destination.startswith("s3"):
            return
        with open(os.path.join(destination, "checkpoint.txt"), "wt") as fh:
            fh.write("Checkpoint\n")

    def patched_save(trial_checkpoint, target):
        # Run save() with the recorder installed in place of check_call.
        with patch("subprocess.check_call", fake_check_call):
            return trial_checkpoint.save(target)

    other_cloud_dir = "s3://other"

    # No local or cloud path configured: save() must refuse.
    empty = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No existing local"):
        empty.save(self.cloud_dir)

    # Local path only. This case assumes a checkpoint already exists
    # locally, so write one first and remove it afterwards.
    local_marker = os.path.join(self.local_dir, "checkpoint.txt")
    with open(local_marker, "wt") as fh:
        fh.write("Checkpoint\n")
    saved_to = patched_save(
        TrialCheckpoint(local_path=self.local_dir), self.cloud_dir
    )
    self.assertEqual(self.cloud_dir, saved_to)
    for expected in (self.cloud_dir, self.local_dir):
        self.assertIn(expected, recorded["cmd"])
    os.remove(local_marker)

    # Cloud path only, copied to another bucket: the save is expected to
    # stage through a temporary directory, so the recorded command must not
    # reference self.local_dir.
    saved_to = patched_save(
        TrialCheckpoint(cloud_path=self.cloud_dir), other_cloud_dir
    )
    self.assertEqual(other_cloud_dir, saved_to)
    self.assertIn(other_cloud_dir, recorded["cmd"])
    self.assertNotIn(self.local_dir, recorded["cmd"])

    # Both paths configured, copied to another bucket: the recorded command
    # is expected to reference the configured local dir.
    both = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    saved_to = patched_save(both, other_cloud_dir)
    self.assertEqual(other_cloud_dir, saved_to)
    self.assertIn(other_cloud_dir, recorded["cmd"])
    self.assertIn(self.local_dir, recorded["cmd"])
def testSaveCloudTarget(self):
    """Check ``TrialCheckpoint.save`` with cloud targets given as ``memory://`` URIs.

    Nothing is mocked here; each case asserts the path returned by
    ``save()``.
    """
    other_cloud_dir = "memory:///other"
    # delete_at_uri presumably clears anything already stored at the target;
    # the helper then seeds it with a checkpoint before the cases run.
    delete_at_uri(other_cloud_dir)
    self._save_checkpoint_at(other_cloud_dir)

    # Nothing configured: save() must refuse.
    unconfigured = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No existing local"):
        unconfigured.save(self.cloud_dir)

    # Local path only. This case relies on a checkpoint file existing
    # locally, so write one first and remove it afterwards.
    local_marker = os.path.join(self.local_dir, "checkpoint.txt")
    with open(local_marker, "wt") as handle:
        handle.write("Checkpoint\n")
    uploaded_to = TrialCheckpoint(local_path=self.local_dir).save(self.cloud_dir)
    self.assertEqual(self.cloud_dir, uploaded_to)
    os.remove(local_marker)

    # Cloud path only, then both paths: each copies to the other target.
    configurations = (
        {"cloud_path": self.cloud_dir},
        {"local_path": self.local_dir, "cloud_path": self.cloud_dir},
    )
    for config in configurations:
        copied_to = TrialCheckpoint(**config).save(other_cloud_dir)
        self.assertEqual(other_cloud_dir, copied_to)
def testSaveLocalTarget(self):
    """Check ``TrialCheckpoint.save`` with local (non-cloud) targets.

    ``subprocess.check_call`` and ``shutil.copytree`` are replaced by
    recorders, so each case verifies which command or copy ``save()``
    issued without transferring anything. The recorder is cleared before
    every patched case so an assertion can only observe what that case's
    ``save()`` did, never a value left over from an earlier case.
    """
    import shutil
    import tempfile

    state = {}

    def check_call(cmd, *args, **kwargs):
        state["cmd"] = cmd

    def copytree(source, dest):
        state["copy_source"] = source
        state["copy_dest"] = dest

    # Private, not-yet-existing target path. Unlike a fixed "/tmp/other" it
    # cannot collide with other tests/processes and works on any OS.
    tmp_root = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, tmp_root, ignore_errors=True)
    other_local_dir = os.path.join(tmp_root, "other")

    # Case: No defaults
    checkpoint = TrialCheckpoint()
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        checkpoint.save()

    # Case: Default local dir
    checkpoint = TrialCheckpoint(local_path=self.local_dir)
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        checkpoint.save()

    # Case: Default cloud dir, no local dir passed
    checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
    with self.assertRaisesRegex(RuntimeError, "No target path"):
        checkpoint.save()

    # Case: Default cloud dir, pass local dir -> forced download via "aws".
    state.clear()
    checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)
    with patch("subprocess.check_call", check_call):
        path = checkpoint.save(self.local_dir, force_download=True)
    self.assertEqual(self.local_dir, path)
    self.assertEqual(state["cmd"][0], "aws")
    self.assertIn(self.cloud_dir, state["cmd"])
    self.assertIn(self.local_dir, state["cmd"])

    # Case: Default local dir, pass local dir -> local tree copy.
    state.clear()
    checkpoint = TrialCheckpoint(local_path=self.local_dir)
    self.ensureCheckpointFile()
    with patch("shutil.copytree", copytree):
        path = checkpoint.save(other_local_dir)
    self.assertEqual(other_local_dir, path)
    self.assertEqual(state["copy_source"], self.local_dir)
    self.assertEqual(state["copy_dest"], other_local_dir)

    # Case: Both default, no pass
    state.clear()
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    with patch("subprocess.check_call", check_call):
        path = checkpoint.save()
    self.assertEqual(self.local_dir, path)
    # Only inspect a command that THIS save() issued. Without the reset above,
    # the command from the "pass local dir" case (which already contains both
    # paths) would satisfy these checks even if nothing was run here.
    # NOTE(review): ensureCheckpointFile() ran in the previous case, so the
    # local dir presumably already holds a checkpoint and save() may skip
    # the download entirely. Confirm the intended behavior and tighten this
    # to an unconditional assertIn / assertNotIn("cmd", state).
    if "cmd" in state:
        self.assertIn(self.cloud_dir, state["cmd"])
        self.assertIn(self.local_dir, state["cmd"])

    # Case: Both default, pass other local dir -> copy from the local dir,
    # leaving the checkpoint's own local_path untouched.
    state.clear()
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    with patch("shutil.copytree", copytree):
        path = checkpoint.save(other_local_dir)
    self.assertEqual(other_local_dir, path)
    self.assertEqual(state["copy_source"], self.local_dir)
    self.assertEqual(state["copy_dest"], other_local_dir)
    self.assertEqual(checkpoint.local_path, self.local_dir)