def testDownloadDefaultBoth(self):
    """Download a checkpoint that has both a local and a cloud default.

    Covers three call shapes: no arguments, only ``local_path``, and both
    ``local_path`` and ``cloud_path``. The returned path must be the
    explicitly passed local path, or the checkpoint's own ``local_path``
    when none is passed.

    NOTE(review): another method named ``testDownloadDefaultBoth`` appears
    next to this one; if both are defined in the same class, the later
    definition shadows this one and it never runs — confirm.
    """
    import os
    import shutil
    import tempfile

    # download() is presumably writing the checkpoint into the local path it
    # is given, so use a private, auto-removed scratch location instead of a
    # fixed, shared "/tmp/other" that leaks between runs and can collide when
    # tests run in parallel or under different users.
    scratch_dir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, scratch_dir, ignore_errors=True)
    other_local_dir = os.path.join(scratch_dir, "other")
    other_cloud_dir = "memory:///other"

    # Seed both cloud locations so every download below has a source.
    self._save_checkpoint_at(other_cloud_dir)
    self._save_checkpoint_at(self.cloud_dir)

    # Case: Nothing is passed -> the checkpoint's own local path is returned.
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    path = checkpoint.download()
    self.assertEqual(self.local_dir, path)

    # Case: Local dir is passed -> it overrides the default local path.
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    path = checkpoint.download(local_path=other_local_dir)
    self.assertEqual(other_local_dir, path)

    # Case: Both are passed -> both override the defaults.
    checkpoint = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
    path = checkpoint.download(local_path=other_local_dir, cloud_path=other_cloud_dir)
    self.assertEqual(other_local_dir, path)
def testDownloadDefaultBoth(self):
    """Download with both defaults set while intercepting the shell-out.

    ``subprocess.check_call`` is swapped for a recorder, so no real command
    runs. Each case checks the returned local path and the argument list
    that would have been handed to the ``aws`` CLI.

    NOTE(review): another method named ``testDownloadDefaultBoth`` appears
    next to this one; if both are defined in the same class, only the later
    definition runs — confirm.
    """
    captured = {}

    def fake_check_call(cmd, *args, **kwargs):
        # Keep the command list instead of executing it.
        captured["cmd"] = cmd

    def run_download(**overrides):
        # A fresh checkpoint per case, with both defaults populated.
        ckpt = TrialCheckpoint(local_path=self.local_dir, cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", fake_check_call):
            return ckpt.download(**overrides)

    alt_local = "/tmp/other"
    alt_cloud = "s3://other"

    # Case: Nothing is passed -> both defaults end up in the command.
    result = run_download()
    self.assertEqual(self.local_dir, result)
    self.assertEqual(captured["cmd"][0], "aws")
    self.assertIn(self.local_dir, captured["cmd"])

    # Case: Local dir is passed -> the override replaces the default local dir.
    result = run_download(local_path=alt_local)
    self.assertEqual(alt_local, result)
    self.assertEqual(captured["cmd"][0], "aws")
    self.assertIn(alt_local, captured["cmd"])
    self.assertNotIn(self.local_dir, captured["cmd"])

    # Case: Both are passed -> neither default appears in the command.
    result = run_download(local_path=alt_local, cloud_path=alt_cloud)
    self.assertEqual(alt_local, result)
    self.assertEqual(captured["cmd"][0], "aws")
    self.assertIn(alt_local, captured["cmd"])
    self.assertNotIn(self.local_dir, captured["cmd"])
    self.assertIn(alt_cloud, captured["cmd"])
    self.assertNotIn(self.cloud_dir, captured["cmd"])
def testDownloadNoDefaults(self):
    """Download from a checkpoint created without any paths.

    Both ``local_path`` and ``cloud_path`` must be supplied to ``download()``.
    Omitting either raises ``RuntimeError`` naming the missing side. With
    both given, the (intercepted) ``aws`` command receives the local path.

    NOTE(review): another method named ``testDownloadNoDefaults`` appears
    next to this one; if both are defined in the same class, the later
    definition shadows this one — confirm.
    """
    recorded = {}

    def fake_check_call(cmd, *args, **kwargs):
        # Record the command list rather than running it.
        recorded["cmd"] = cmd

    # Case: Nothing is passed
    ckpt = TrialCheckpoint()
    with self.assertRaises(RuntimeError):
        ckpt.download()

    # Cases: only one side is passed -> the other side is reported missing.
    missing_side_cases = (
        ({"local_path": self.local_dir}, "No cloud path"),
        ({"cloud_path": self.cloud_dir}, "No local path"),
    )
    for call_kwargs, expected_message in missing_side_cases:
        ckpt = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, expected_message):
            ckpt.download(**call_kwargs)

    # Case: Both are passed
    ckpt = TrialCheckpoint()
    with patch("subprocess.check_call", fake_check_call):
        result = ckpt.download(local_path=self.local_dir, cloud_path=self.cloud_dir)
    self.assertEqual(self.local_dir, result)
    self.assertEqual(recorded["cmd"][0], "aws")
    self.assertIn(self.local_dir, recorded["cmd"])
def testDownloadDefaultCloud(self):
    """Download from a checkpoint that only knows its cloud path.

    ``download()`` must always be given a local destination; without one it
    raises ``RuntimeError`` ("No local path"). The cloud source comes from
    the checkpoint unless ``cloud_path`` overrides it.

    NOTE(review): nothing is saved at ``self.cloud_dir`` or
    "memory:///other" inside this test before downloading; presumably setUp
    (not visible here) seeds them — confirm.
    """
    alt_cloud = "memory:///other"

    def cloud_only_checkpoint():
        # Fresh checkpoint with only the cloud default configured.
        return TrialCheckpoint(cloud_path=self.cloud_dir)

    # Case: Nothing is passed
    ckpt = cloud_only_checkpoint()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.download()

    # Case: Local dir is passed
    ckpt = cloud_only_checkpoint()
    result = ckpt.download(local_path=self.local_dir)
    self.assertEqual(self.local_dir, result)

    # Case: Cloud dir is passed -> still no local destination.
    ckpt = cloud_only_checkpoint()
    with self.assertRaisesRegex(RuntimeError, "No local path"):
        ckpt.download(cloud_path=alt_cloud)

    # Case: Both are passed
    ckpt = cloud_only_checkpoint()
    result = ckpt.download(local_path=self.local_dir, cloud_path=alt_cloud)
    self.assertEqual(self.local_dir, result)
def testDownloadDefaultLocal(self):
    """Download from a checkpoint that only knows its local path.

    ``download()`` must always be given a cloud source; without one it
    raises ``RuntimeError`` ("No cloud path"), even if a local override is
    passed. The returned path is the overriding ``local_path`` when given,
    otherwise the checkpoint's own.
    """
    import os
    import shutil
    import tempfile

    # The last case passes this path as a real download destination, so use
    # a private, auto-removed location rather than a fixed, shared
    # "/tmp/invalid" that leaks across runs and can collide between parallel
    # or multi-user test processes.
    scratch_dir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, scratch_dir, ignore_errors=True)
    other_local_dir = os.path.join(scratch_dir, "other")

    # Case: Nothing is passed
    checkpoint = TrialCheckpoint(local_path=self.local_dir)
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        checkpoint.download()

    # Case: Local dir is passed -> still no cloud source.
    checkpoint = TrialCheckpoint(local_path=self.local_dir)
    with self.assertRaisesRegex(RuntimeError, "No cloud path"):
        checkpoint.download(local_path=other_local_dir)

    # Case: Cloud dir is passed -> default local path is returned.
    checkpoint = TrialCheckpoint(local_path=self.local_dir)
    path = checkpoint.download(cloud_path=self.cloud_dir)
    self.assertEqual(self.local_dir, path)

    # Case: Both are passed -> the local override wins.
    checkpoint = TrialCheckpoint(local_path=self.local_dir)
    path = checkpoint.download(local_path=other_local_dir, cloud_path=self.cloud_dir)
    self.assertEqual(other_local_dir, path)
def testDownloadNoDefaults(self):
    """Download from a checkpoint constructed with no paths at all.

    Both sides must be passed to ``download()``. A missing side raises
    ``RuntimeError`` whose message names it. When both are given, the
    returned path is the passed ``local_path``.

    NOTE(review): another method named ``testDownloadNoDefaults`` appears
    next to this one; if both are defined in the same class, only the later
    definition runs — confirm.
    """
    # Case: Nothing is passed
    ckpt = TrialCheckpoint()
    with self.assertRaises(RuntimeError):
        ckpt.download()

    # Cases: exactly one side is passed -> the other one is reported missing.
    for call_kwargs, expected_message in (
        ({"local_path": self.local_dir}, "No cloud path"),
        ({"cloud_path": self.cloud_dir}, "No local path"),
    ):
        ckpt = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, expected_message):
            ckpt.download(**call_kwargs)

    # Case: Both are passed
    ckpt = TrialCheckpoint()
    result = ckpt.download(local_path=self.local_dir, cloud_path=self.cloud_dir)
    self.assertEqual(self.local_dir, result)