Beispiel #1
0
        def test_prediction_loaded_custom_trainer(self):
            """Validate reloading a checkpoint produced while fitting with a custom trainer.

            The model is fitted with a user-supplied ``pl.Trainer`` that reuses the
            model's own callbacks (``model.trainer_params["callbacks"]``), presumably
            including the checkpoint callback configured by ``save_checkpoints=True``.
            The automatically written checkpoint is then restored with
            ``RNNModel.load_from_checkpoint()``, and its 4-step forecast must equal
            the forecast of the original in-memory model.
            """
            auto_name = "test_save_automatic"
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                model_name=auto_name,
                work_dir=self.temp_work_dir,
                save_checkpoints=True,
                random_state=42,
            )

            # fit model with a custom trainer; passing the model's own callbacks
            # keeps the callbacks the model configured for itself
            trainer = pl.Trainer(
                max_epochs=1,
                enable_checkpointing=True,
                logger=False,
                callbacks=model.trainer_params["callbacks"],
                precision=32,
            )
            model.fit(self.series, trainer=trainer)

            # restore the automatically saved checkpoint with load_from_checkpoint();
            # best=False requests a checkpoint other than the "best" one
            # (presumably the most recent) -- confirm against the darts API
            model_loaded = RNNModel.load_from_checkpoint(
                model_name=auto_name, work_dir=self.temp_work_dir, best=False
            )

            # compare prediction of loaded model with original model
            self.assertEqual(model.predict(n=4), model_loaded.predict(n=4))
Beispiel #2
0
        def test_manual_save_and_load(self):
            """Validate manual saving against automatic checkpointing via predictions.

            Two identically configured RNNModels (same hyper-parameters and
            ``random_state``) are trained for one epoch. One has
            ``save_checkpoints=False`` and is saved explicitly with ``save_model()``.
            The other has ``save_checkpoints=True``. The test checks:

            * a ``checkpoints`` folder exists only for the auto-saving model;
            * ``save_model()`` writes the model file and a companion
              ``*_ptl-ckpt.pth.tar`` file;
            * the model restored with ``load_model()`` and the one restored with
              ``load_from_checkpoint()`` forecast the same values as the
              auto-saving model.
            """
            manual_name = "test_save_manual"
            auto_name = "test_save_automatic"

            def _make_model(model_name, save_checkpoints):
                # identical architecture and seed, so both models end up with the
                # same weights; only the name and checkpointing mode differ
                return RNNModel(
                    12,
                    "RNN",
                    10,
                    10,
                    model_name=model_name,
                    work_dir=self.temp_work_dir,
                    save_checkpoints=save_checkpoints,
                    random_state=42,
                )

            model_manual_save = _make_model(manual_name, save_checkpoints=False)
            model_auto_save = _make_model(auto_name, save_checkpoints=True)

            model_manual_save.fit(self.series, epochs=1)
            model_auto_save.fit(self.series, epochs=1)

            model_dir = self.temp_work_dir

            # check that no checkpoint folder was created without automatic saving
            self.assertFalse(
                os.path.exists(os.path.join(model_dir, manual_name, "checkpoints"))
            )
            # check that the checkpoint folder was created with automatic saving
            self.assertTrue(
                os.path.exists(os.path.join(model_dir, auto_name, "checkpoints"))
            )

            # create the folder for the manually saved model; exist_ok avoids a
            # spurious FileExistsError in case the folder already exists
            checkpoint_path_manual = os.path.join(model_dir, manual_name)
            os.makedirs(checkpoint_path_manual, exist_ok=True)

            checkpoint_file_name = "checkpoint_0.pth.tar"
            model_path_manual = os.path.join(
                checkpoint_path_manual, checkpoint_file_name
            )
            checkpoint_file_name_ckpt = "checkpoint_0_ptl-ckpt.pth.tar"
            model_path_manual_ckpt = os.path.join(
                checkpoint_path_manual, checkpoint_file_name_ckpt
            )

            # save the manually trained model explicitly
            model_manual_save.save_model(model_path_manual)
            self.assertTrue(os.path.exists(model_path_manual))

            # check that the PTL checkpoint file was written next to it
            self.assertTrue(os.path.exists(model_path_manual_ckpt))

            # restore the manual save and compare with the auto-saving model
            model_manual_loaded = RNNModel.load_model(model_path_manual)
            self.assertEqual(
                model_manual_loaded.predict(n=4), model_auto_save.predict(n=4)
            )

            # restore the automatically saved checkpoint with load_from_checkpoint()
            model_auto_loaded = RNNModel.load_from_checkpoint(
                model_name=auto_name, work_dir=self.temp_work_dir, best=False
            )

            # the restored checkpoint must agree with the restored manual save
            self.assertEqual(
                model_manual_loaded.predict(n=4), model_auto_loaded.predict(n=4)
            )