Ejemplo n.º 1
0
    def testDownloadDefaultBoth(self):
        # The checkpoint is constructed with both a default local path and a
        # default cloud path. Each download() call is expected to return the
        # local directory in effect: the default one when no local path is
        # passed, otherwise the explicitly passed one.
        alt_local_dir = "/tmp/other"
        alt_cloud_dir = "memory:///other"

        # Run the fixture helper for both cloud locations used below.
        for cloud_location in (alt_cloud_dir, self.cloud_dir):
            self._save_checkpoint_at(cloud_location)

        def checkpoint_with_defaults():
            # Every scenario starts from a freshly constructed checkpoint.
            return TrialCheckpoint(
                local_path=self.local_dir, cloud_path=self.cloud_dir)

        # Scenario: no arguments -> the default local path comes back.
        ckpt = checkpoint_with_defaults()
        returned = ckpt.download()
        self.assertEqual(self.local_dir, returned)

        # Scenario: only the local path is overridden.
        ckpt = checkpoint_with_defaults()
        returned = ckpt.download(local_path=alt_local_dir)
        self.assertEqual(alt_local_dir, returned)

        # Scenario: local and cloud paths are both overridden.
        ckpt = checkpoint_with_defaults()
        returned = ckpt.download(
            local_path=alt_local_dir, cloud_path=alt_cloud_dir)
        self.assertEqual(alt_local_dir, returned)
Ejemplo n.º 2
0
    def testDownloadDefaultBoth(self):
        # The checkpoint is constructed with both default paths. While
        # download() runs, subprocess.check_call is replaced by a recorder so
        # that the command handed to it can be inspected afterwards.
        recorded = {}

        def check_call_recorder(cmd, *args, **kwargs):
            # Only the most recent command is kept.
            recorded["cmd"] = cmd

        def patched_check_call():
            return patch("subprocess.check_call", check_call_recorder)

        def checkpoint_with_defaults():
            # Every scenario starts from a freshly constructed checkpoint.
            return TrialCheckpoint(
                local_path=self.local_dir, cloud_path=self.cloud_dir)

        alt_local_dir = "/tmp/other"
        alt_cloud_dir = "s3://other"

        # Scenario: no arguments -> defaults are used.
        ckpt = checkpoint_with_defaults()
        with patched_check_call():
            returned = ckpt.download()

        self.assertEqual(self.local_dir, returned)
        command = recorded["cmd"]
        self.assertEqual(command[0], "aws")
        self.assertIn(self.local_dir, command)

        # Scenario: only the local path is overridden.
        ckpt = checkpoint_with_defaults()
        with patched_check_call():
            returned = ckpt.download(local_path=alt_local_dir)

        self.assertEqual(alt_local_dir, returned)
        command = recorded["cmd"]
        self.assertEqual(command[0], "aws")
        self.assertIn(alt_local_dir, command)
        self.assertNotIn(self.local_dir, command)

        # Scenario: local and cloud paths are both overridden, so neither
        # default may show up in the recorded command.
        ckpt = checkpoint_with_defaults()
        with patched_check_call():
            returned = ckpt.download(
                local_path=alt_local_dir, cloud_path=alt_cloud_dir)

        self.assertEqual(alt_local_dir, returned)
        command = recorded["cmd"]
        self.assertEqual(command[0], "aws")
        self.assertIn(alt_local_dir, command)
        self.assertNotIn(self.local_dir, command)
        self.assertIn(alt_cloud_dir, command)
        self.assertNotIn(self.cloud_dir, command)
Ejemplo n.º 3
0
    def testDownloadNoDefaults(self):
        # The checkpoint is constructed without any default path. download()
        # is expected to raise RuntimeError unless both a local and a cloud
        # path are passed explicitly; in that last case the command handed to
        # the patched subprocess.check_call is inspected.
        recorded = {}

        def check_call_recorder(cmd, *args, **kwargs):
            recorded["cmd"] = cmd

        # Scenario: no arguments at all.
        ckpt = TrialCheckpoint()
        self.assertRaises(RuntimeError, ckpt.download)

        # Scenario: only a local path -> the cloud path is reported missing.
        ckpt = TrialCheckpoint()
        self.assertRaisesRegex(
            RuntimeError, "No cloud path",
            ckpt.download, local_path=self.local_dir)

        # Scenario: only a cloud path -> the local path is reported missing.
        ckpt = TrialCheckpoint()
        self.assertRaisesRegex(
            RuntimeError, "No local path",
            ckpt.download, cloud_path=self.cloud_dir)

        # Scenario: both paths are passed explicitly.
        ckpt = TrialCheckpoint()
        with patch("subprocess.check_call", check_call_recorder):
            returned = ckpt.download(
                local_path=self.local_dir, cloud_path=self.cloud_dir)

        self.assertEqual(self.local_dir, returned)
        command = recorded["cmd"]
        self.assertEqual(command[0], "aws")
        self.assertIn(self.local_dir, command)
Ejemplo n.º 4
0
    def testDownloadDefaultCloud(self):
        # The checkpoint is constructed with a default cloud path only.
        # download() is expected to raise RuntimeError matching
        # "No local path" whenever no local path is passed, and to return the
        # passed local path otherwise.
        alt_cloud_dir = "memory:///other"

        def cloud_only_checkpoint():
            # Every scenario starts from a freshly constructed checkpoint.
            return TrialCheckpoint(cloud_path=self.cloud_dir)

        # Scenario: no arguments.
        ckpt = cloud_only_checkpoint()
        with self.assertRaisesRegex(RuntimeError, "No local path"):
            ckpt.download()

        # Scenario: only a local path is passed.
        ckpt = cloud_only_checkpoint()
        returned = ckpt.download(local_path=self.local_dir)
        self.assertEqual(self.local_dir, returned)

        # Scenario: only another cloud path is passed.
        ckpt = cloud_only_checkpoint()
        with self.assertRaisesRegex(RuntimeError, "No local path"):
            ckpt.download(cloud_path=alt_cloud_dir)

        # Scenario: local and cloud paths are both passed.
        ckpt = cloud_only_checkpoint()
        returned = ckpt.download(
            local_path=self.local_dir, cloud_path=alt_cloud_dir)
        self.assertEqual(self.local_dir, returned)
Ejemplo n.º 5
0
    def testDownloadDefaultLocal(self):
        # The checkpoint is constructed with a default local path only.
        # download() is expected to raise RuntimeError matching
        # "No cloud path" whenever no cloud path is passed, and otherwise to
        # return the local path in effect (default or explicitly passed).
        alt_local_dir = "/tmp/invalid"

        def local_only_checkpoint():
            # Every scenario starts from a freshly constructed checkpoint.
            return TrialCheckpoint(local_path=self.local_dir)

        # Scenario: no arguments.
        ckpt = local_only_checkpoint()
        with self.assertRaisesRegex(RuntimeError, "No cloud path"):
            ckpt.download()

        # Scenario: only another local path is passed.
        ckpt = local_only_checkpoint()
        with self.assertRaisesRegex(RuntimeError, "No cloud path"):
            ckpt.download(local_path=alt_local_dir)

        # Scenario: only a cloud path is passed -> default local path is used.
        ckpt = local_only_checkpoint()
        returned = ckpt.download(cloud_path=self.cloud_dir)
        self.assertEqual(self.local_dir, returned)

        # Scenario: local and cloud paths are both passed.
        ckpt = local_only_checkpoint()
        returned = ckpt.download(
            local_path=alt_local_dir, cloud_path=self.cloud_dir)
        self.assertEqual(alt_local_dir, returned)
Ejemplo n.º 6
0
    def testDownloadNoDefaults(self):
        # The checkpoint is constructed without any default path. download()
        # is expected to raise RuntimeError unless both a local and a cloud
        # path are passed explicitly, in which case it returns the local one.

        # Scenario: no arguments at all (any RuntimeError is accepted).
        ckpt = TrialCheckpoint()
        with self.assertRaises(RuntimeError):
            ckpt.download()

        # Scenarios: exactly one path is passed; the error message is
        # expected to name the path that is still missing.
        single_path_cases = (
            ({"local_path": self.local_dir}, "No cloud path"),
            ({"cloud_path": self.cloud_dir}, "No local path"),
        )
        for download_kwargs, expected_message in single_path_cases:
            ckpt = TrialCheckpoint()
            with self.assertRaisesRegex(RuntimeError, expected_message):
                ckpt.download(**download_kwargs)

        # Scenario: both paths are passed explicitly.
        ckpt = TrialCheckpoint()
        returned = ckpt.download(
            local_path=self.local_dir, cloud_path=self.cloud_dir)
        self.assertEqual(self.local_dir, returned)