Example #1
0
    def testSaveCloudTarget(self):
        """Save a TrialCheckpoint to a cloud (``s3://``) target.

        ``subprocess.check_call`` is patched with a fake that remembers only
        the most recent command. When element 6 of that command does not
        start with ``s3``, the fake treats it as a local directory and writes
        a ``checkpoint.txt`` file there, standing in for the download the
        real command would perform.
        """
        recorded = {}
        other_cloud_dir = "s3://other"

        def write_marker(directory):
            # Create the dummy checkpoint file inside ``directory``.
            with open(os.path.join(directory, "checkpoint.txt"), "wt") as fp:
                fp.write("Checkpoint\n")

        def fake_check_call(cmd, *args, **kwargs):
            recorded["cmd"] = cmd
            # NOTE(review): relies on the command placing the download
            # destination at index 6 -- confirm if the CLI call changes.
            destination = cmd[6]
            if destination.startswith("s3"):
                return
            write_marker(destination)

        # No local and no cloud path: nothing can be saved.
        ckpt = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, "No existing local"):
            ckpt.save(self.cloud_dir)

        # Local path only. The local directory is assumed to already hold a
        # checkpoint, so create one before saving.
        write_marker(self.local_dir)
        ckpt = TrialCheckpoint(local_path=self.local_dir)
        with patch("subprocess.check_call", fake_check_call):
            result = ckpt.save(self.cloud_dir)
        self.assertEqual(self.cloud_dir, result)
        self.assertIn(self.cloud_dir, recorded["cmd"])
        self.assertIn(self.local_dir, recorded["cmd"])

        # Remove that checkpoint again before the remaining scenarios.
        os.remove(os.path.join(self.local_dir, "checkpoint.txt"))

        # Cloud path only, saved to a different cloud location. The last
        # command must not mention self.local_dir (a temporary directory is
        # used instead).
        ckpt = TrialCheckpoint(cloud_path=self.cloud_dir)
        with patch("subprocess.check_call", fake_check_call):
            result = ckpt.save(other_cloud_dir)
        self.assertEqual(other_cloud_dir, result)
        self.assertIn(other_cloud_dir, recorded["cmd"])
        self.assertNotIn(self.local_dir, recorded["cmd"])

        # Both paths set, saved to a different cloud location. This time the
        # last command must reference self.local_dir.
        ckpt = TrialCheckpoint(
            local_path=self.local_dir, cloud_path=self.cloud_dir
        )
        with patch("subprocess.check_call", fake_check_call):
            result = ckpt.save(other_cloud_dir)
        self.assertEqual(other_cloud_dir, result)
        self.assertIn(other_cloud_dir, recorded["cmd"])
        self.assertIn(self.local_dir, recorded["cmd"])
Example #2
0
    def testSaveCloudTarget(self):
        """Save a TrialCheckpoint to ``memory://`` cloud URIs.

        Covers: no paths at all (error), local path only, cloud path only,
        and both paths set.
        """
        target_uri = "memory:///other"

        # Start from an empty target, then seed it with a checkpoint.
        # NOTE(review): because the target is pre-populated here, the cases
        # below only verify the returned URI, not that save() wrote data.
        delete_at_uri(target_uri)
        self._save_checkpoint_at(target_uri)

        # Neither a local nor a cloud path: save() must raise.
        ckpt = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, "No existing local"):
            ckpt.save(self.cloud_dir)

        # Local path only. A checkpoint file is expected to exist locally,
        # so create one first and remove it once this case is done.
        marker = os.path.join(self.local_dir, "checkpoint.txt")
        with open(marker, "wt") as fh:
            fh.write("Checkpoint\n")
        ckpt = TrialCheckpoint(local_path=self.local_dir)
        self.assertEqual(self.cloud_dir, ckpt.save(self.cloud_dir))
        os.remove(marker)

        # Cloud path only: copy to the other URI.
        ckpt = TrialCheckpoint(cloud_path=self.cloud_dir)
        self.assertEqual(target_uri, ckpt.save(target_uri))

        # Both paths set: copy to the other URI.
        ckpt = TrialCheckpoint(
            local_path=self.local_dir, cloud_path=self.cloud_dir
        )
        self.assertEqual(target_uri, ckpt.save(target_uri))
Example #3
0
    def testSaveLocalTarget(self):
        """Save a TrialCheckpoint to a local directory.

        ``subprocess.check_call`` and ``shutil.copytree`` are patched (in the
        cases that use them) with fakes that only record their arguments in
        ``state``.
        """
        state = {}

        def check_call(cmd, *args, **kwargs):
            # Record the most recent command instead of executing it.
            state["cmd"] = cmd

        def copytree(source, dest, *args, **kwargs):
            # Accept any extra arguments the real ``shutil.copytree`` takes
            # (e.g. ``dirs_exist_ok``) so the fake records the call instead
            # of raising TypeError.
            state["copy_source"] = source
            state["copy_dest"] = dest

        # Copy target for the cases below; copytree is patched there.
        other_local_dir = "/tmp/other"

        # Case: No defaults
        checkpoint = TrialCheckpoint()
        with self.assertRaisesRegex(RuntimeError, "No cloud path"):
            checkpoint.save()

        # Case: Default local dir
        checkpoint = TrialCheckpoint(local_path=self.local_dir)

        with self.assertRaisesRegex(RuntimeError, "No cloud path"):
            checkpoint.save()

        # Case: Default cloud dir, no local dir passed
        checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)

        with self.assertRaisesRegex(RuntimeError, "No target path"):
            checkpoint.save()

        # Case: Default cloud dir, pass local dir
        checkpoint = TrialCheckpoint(cloud_path=self.cloud_dir)

        with patch("subprocess.check_call", check_call):
            path = checkpoint.save(self.local_dir, force_download=True)

        self.assertEqual(self.local_dir, path)
        self.assertEqual(state["cmd"][0], "aws")
        self.assertIn(self.cloud_dir, state["cmd"])
        self.assertIn(self.local_dir, state["cmd"])

        # Case: Default local dir, pass local dir
        checkpoint = TrialCheckpoint(local_path=self.local_dir)
        self.ensureCheckpointFile()

        with patch("shutil.copytree", copytree):
            path = checkpoint.save(other_local_dir)

        self.assertEqual(other_local_dir, path)
        self.assertEqual(state["copy_source"], self.local_dir)
        self.assertEqual(state["copy_dest"], other_local_dir)

        # Case: Both default, no pass
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)

        with patch("subprocess.check_call", check_call):
            path = checkpoint.save()

        self.assertEqual(self.local_dir, path)
        # NOTE(review): state["cmd"] still holds the command recorded in the
        # "Default cloud dir, pass local dir" case, which already contains
        # both directories, so these two assertions pass even if this save()
        # issues no command. Consider clearing state["cmd"] before this case
        # once it is confirmed that a command is expected here.
        self.assertIn(self.cloud_dir, state["cmd"])
        self.assertIn(self.local_dir, state["cmd"])

        # Case: Both default, pass other local dir
        checkpoint = TrialCheckpoint(local_path=self.local_dir,
                                     cloud_path=self.cloud_dir)

        # Drop the values recorded by the earlier copy case so the copy
        # assertions below can only pass if this save() calls copytree.
        state.pop("copy_source", None)
        state.pop("copy_dest", None)

        with patch("shutil.copytree", copytree):
            path = checkpoint.save(other_local_dir)

        self.assertEqual(other_local_dir, path)
        self.assertEqual(state.get("copy_source"), self.local_dir)
        self.assertEqual(state.get("copy_dest"), other_local_dir)
        self.assertEqual(checkpoint.local_path, self.local_dir)