Ejemplo n.º 1
0
 def setUp(self):
     self.test_dir = tempfile.gettempdir()
     self.experiment = PytorchExperiment(name="test_experiment",
                                         base_dir=self.test_dir,
                                         n_epochs=10)
Ejemplo n.º 2
0
    def test_resume(self):
        # TODO

        self.cntr = []
        self.experiment.train = lambda epoch: self.cntr.append(0)
        self.experiment.run()
        self.assertTrue(
            len(self.cntr) == 10,
            "Did not call train for the right number of epochs")

        exp2 = PytorchExperiment(name="test-exp2",
                                 base_dir=self.test_dir,
                                 resume=self.experiment.elog.work_dir)
        exp3 = PytorchExperiment(name="test-exp2",
                                 base_dir=self.test_dir,
                                 resume=self.experiment.elog.work_dir,
                                 resume_reset_epochs=False)

        exp2.prepare_resume()
        exp3.prepare_resume()

        self.assertTrue(exp2.exp_name == 'test_experiment',
                        "Did not restore exp_name")
        self.assertTrue(exp2._epoch_idx == 0, "Did not reset epochs")
        self.assertTrue(exp3._epoch_idx == 10, "Did reset epochs")

        exp2.train = lambda epoch: self.cntr.append(0)
        exp3.train = lambda epoch: self.cntr.append(0)
        exp2.run()
        self.assertTrue(len(self.cntr) == 20, "Did call not train for exp2")
        exp3.run()
        self.assertTrue(len(self.cntr) == 20, "Did call train for exp3")

        exp2._exp_state = "Ended"
        exp3._exp_state = "Ended"
Ejemplo n.º 3
0
class TestPytorchExperiment(unittest.TestCase):
    def setUp(self):
        self.test_dir = tempfile.gettempdir()
        self.experiment = PytorchExperiment(name="test_experiment",
                                            base_dir=self.test_dir,
                                            n_epochs=10)

    def tearDown(self):
        self.experiment._exp_state = "Ended"
        if os.path.exists(self.test_dir):
            shutil.rmtree(self.test_dir, ignore_errors=True)

    def test_train(self):
        self.cntr = []
        self.experiment.train = lambda epoch: self.cntr.append(0)
        self.experiment.run()
        self.assertTrue(
            len(self.cntr) == 10,
            "Did not call train for the right number of epochs")

    def test_validate(self):
        self.cntr = []
        self.experiment.validate = lambda epoch: self.cntr.append(0)
        self.experiment.run()
        self.assertTrue(
            len(self.cntr) == 10,
            "Did not call train for the right number of epochs")

    def test_update_attributes(self):
        results2 = ResultLogDict("results-log.json", base_dir=self.test_dir)
        results2["test"] = 0
        self.experiment.test_var = "test"

        self.experiment.update_attributes({
            "results": results2,
            "test_var": "test2"
        })

        self.assertTrue(self.experiment.test_var == "test2",
                        "Could not update simple var in update_attributes")
        self.assertTrue(self.experiment.results['test'] == 0,
                        "Could not result dict var in update_attributes")

    def test_get_pytorch_modules(self):
        module = Net()
        self.experiment.net = module

        pyt_module = self.experiment.get_pytorch_modules()

        self.assertTrue(
            "net" in pyt_module,
            "Name for module could not be correctly returned in get_pytorch_modules"
        )
        self.assertTrue(
            pyt_module['net'] == module,
            "Module could not be correctly returned in  get_pytorch_modules")

    def test_get_pytorch_optimizers(self):
        optimizer = torch.optim.Adam(params=[torch.tensor(1)])
        self.experiment.optim = optimizer

        pyt_optim = self.experiment.get_pytorch_optimizers()

        self.assertTrue(
            "optim" in pyt_optim,
            "Name for optimizer could not be correctly returned in "
            "get_pytorch_optimizers")
        self.assertTrue(
            pyt_optim['optim'] == optimizer,
            "Optimizer could not be correctly returned in  "
            "get_pytorch_optimizers")

    def test_get_simple_vars(self):
        self.experiment.a = "test"
        self.experiment.b = 1
        self.experiment.c = True
        self.experiment.d = (1, 2, 3, 4)

        var = self.experiment.get_simple_variables()

        self.assertTrue(
            "a" in var and "b" in var and "c" in var and "d" in var,
            "Names for attributes could not be correctly returned in get_simple_variables"
        )
        self.assertTrue(
            var['a'] == "test" and var['b'] == 1 and var['c'] == True
            and var['d'] == (1, 2, 3, 4),
            "variables could not be correctly returned in get_simple_variables"
        )

    def test_get_pytorch_vars(self):
        var = torch.tensor(1)
        self.experiment.pyt_var = var

        pyt_vars = self.experiment.get_pytorch_variables()

        self.assertTrue(
            "pyt_var" in pyt_vars,
            "Name for variable could not be correctly returned in "
            "get_pytorch_variables")
        self.assertTrue(
            pyt_vars['pyt_var'] == var,
            "Variable could not be correctly returned in  "
            "get_pytorch_variables")

    def test_save_results(self):
        self.experiment.add_result(0, "test")
        self.experiment.save_results("results-test.json")

        self.assertTrue(
            os.path.exists(
                os.path.join(self.experiment.elog.result_dir,
                             "results-test.json")),
            "result file could not be stored")

        with open(
                os.path.join(self.experiment.elog.result_dir,
                             "results-test.json"), "r") as f:
            content = f.read()
            self.assertTrue("test" in content and "0" in content,
                            "results content not sucessfully saved")

    def test_save_and_load_checkpoints(self):
        net = Net()
        optim = torch.optim.Adam(params=net.parameters())
        ptvar = torch.tensor(1)
        svar = "test"
        lvar = (1, 2, 3, 4)

        self.experiment.net = net
        self.experiment.optim = optim
        self.experiment.ptvar = ptvar
        self.experiment.svar = svar
        self.experiment.lvar = lvar
        self.experiment.results["test"] = 1

        self.experiment.save_checkpoint(name="test_checkpoint")

        self.experiment.net = Net()
        self.experiment.optim = torch.optim.Adam(params=Net().parameters(),
                                                 lr=1)
        self.experiment.ptvar = torch.tensor(0)
        self.experiment.svar = "test2"
        self.experiment.lvar = (-1, -2)
        self.experiment.results["test"] = 0

        self.experiment.load_checkpoint(name="test_checkpoint")

        self.assertTrue((list(self.experiment.net.parameters())[0] -
                         list(net.parameters())[0]).sum().item() < 0.00001,
                        "Net could not be restored from checkpoint")
        self.assertTrue(
            self.experiment.optim.state_dict()['param_groups'][0]['lr'] ==
            optim.state_dict()['param_groups'][0]['lr'],
            "optim could not be restored from checkpoint")
        self.assertTrue(self.experiment.ptvar == ptvar.item(),
                        "ptvar could not be restored from checkpoint")
        self.assertTrue(self.experiment.svar == svar,
                        "svar could not be restored from checkpoint")
        self.assertTrue(self.experiment.lvar == lvar,
                        "lvar could not be restored from checkpoint")
        self.assertTrue(self.experiment.results["test"] == 1,
                        "Results could not be restored from checkpoint")

    def test_add_results(self):
        self.experiment.add_result(name="test", value=1)

        self.assertTrue(self.experiment.results["test"] == 1,
                        "Result was not added")
        self.assertTrue(
            self.experiment.get_result("test") == 1, "Result was not added")

        self.assertTrue(
            os.path.exists(
                os.path.join(self.experiment.elog.result_dir,
                             "results-log.json")),
            "result file could not be stored")

        with open(
                os.path.join(self.experiment.elog.result_dir,
                             "results-log.json"), "r") as f:
            content = f.read()
        self.assertTrue("test" in content and "1" in content,
                        "results content not sucessfully saved")

    def test_save_tmp_results(self):
        self.experiment.add_result(name="test", value=1)
        self.experiment.run()

        self.assertTrue(
            os.path.exists(
                os.path.join(self.experiment.elog.result_dir, "results.json")),
            "result file could not be stored")

        with open(
                os.path.join(self.experiment.elog.result_dir, "results.json"),
                "r") as f:
            content = f.read()
        self.assertTrue("test" in content and "1" in content,
                        "results content not sucessfully temporarily saved")

    def test_save_tmp_checkpoint(self):
        self.experiment.test_var = "test"
        self.experiment.run()

        time.sleep(5)

        self.assertTrue(
            os.path.exists(
                os.path.join(self.experiment.elog.checkpoint_dir,
                             "checkpoint_current.pth.tar")),
            "Temp Checkpoint file could not be stored")
        self.assertTrue(
            os.path.exists(
                os.path.join(self.experiment.elog.checkpoint_dir,
                             "checkpoint_last.pth.tar")),
            "Last Checkpoint file could not be stored")

        self.experiment.test_var = "test2"
        self.experiment.load_checkpoint("checkpoint_current")
        self.assertTrue(self.experiment.test_var == "test",
                        "Temp Checkpoint file loading could not sucessful")
        self.experiment.test_var = "test2"
        self.experiment.load_checkpoint("checkpoint_last")
        self.assertTrue(self.experiment.test_var == "test",
                        "Temp Checkpoint file loading could not sucessful")

    def test_save_exp_info(self):
        self.experiment.run()

        with open(os.path.join(self.experiment.elog.config_dir, "exp.json"),
                  "r") as f:
            exp_info = json.load(f)

        self.assertTrue(exp_info['epoch'] == 10,
                        "Epoch not sucessfully stored in exp info")
        self.assertTrue(exp_info['name'] == 'test_experiment',
                        "Name not sucessfully stored in exp info")
        self.assertTrue(
            exp_info['state'] == 'Trained',
            "State not sucessfully stored as 'Trained' in exp info")

    def test_print(self):
        self.experiment.print("0o zD 0o")
        with open(os.path.join(self.experiment.elog.log_dir, "default.log"),
                  "r") as f:
            content = f.read()
        self.assertTrue("0o zD 0o" in content, "Print not sucessfully saved")

    def test_resume(self):
        # TODO

        self.cntr = []
        self.experiment.train = lambda epoch: self.cntr.append(0)
        self.experiment.run()
        self.assertTrue(
            len(self.cntr) == 10,
            "Did not call train for the right number of epochs")

        exp2 = PytorchExperiment(name="test-exp2",
                                 base_dir=self.test_dir,
                                 resume=self.experiment.elog.work_dir)
        exp3 = PytorchExperiment(name="test-exp2",
                                 base_dir=self.test_dir,
                                 resume=self.experiment.elog.work_dir,
                                 resume_reset_epochs=False)

        exp2.prepare_resume()
        exp3.prepare_resume()

        self.assertTrue(exp2.exp_name == 'test_experiment',
                        "Did not restore exp_name")
        self.assertTrue(exp2._epoch_idx == 0, "Did not reset epochs")
        self.assertTrue(exp3._epoch_idx == 10, "Did reset epochs")

        exp2.train = lambda epoch: self.cntr.append(0)
        exp3.train = lambda epoch: self.cntr.append(0)
        exp2.run()
        self.assertTrue(len(self.cntr) == 20, "Did call not train for exp2")
        exp3.run()
        self.assertTrue(len(self.cntr) == 20, "Did call train for exp3")

        exp2._exp_state = "Ended"
        exp3._exp_state = "Ended"