def test_prediction_loaded_custom_trainer(self):
    """Check that a model fitted through a custom trainer predicts the same once reloaded.

    A checkpoint-saving ``RNNModel`` is trained for one epoch with a
    hand-built ``pl.Trainer``, restored with ``load_from_checkpoint`` and the
    forecasts of the in-memory and restored models are compared.
    """
    checkpointed_name = "test_save_automatic"
    original = RNNModel(
        12,
        "RNN",
        10,
        10,
        model_name=checkpointed_name,
        work_dir=self.temp_work_dir,
        save_checkpoints=True,
        random_state=42,
    )

    # The custom trainer reuses the callbacks the model prepared for itself.
    trainer_options = {
        "max_epochs": 1,
        "enable_checkpointing": True,
        "logger": False,
        "callbacks": original.trainer_params["callbacks"],
        "precision": 32,
    }
    custom_trainer = pl.Trainer(**trainer_options)
    original.fit(self.series, trainer=custom_trainer)

    # Restore the model from the checkpoint written during training.
    restored = RNNModel.load_from_checkpoint(
        model_name=checkpointed_name,
        work_dir=self.temp_work_dir,
        best=False,
    )

    # Both models must produce the same forecast.
    expected_forecast = original.predict(n=4)
    restored_forecast = restored.predict(n=4)
    self.assertEqual(expected_forecast, restored_forecast)
def test_manual_save_and_load(self):
    """Validate manual saving against automatic checkpointing by comparing predictions.

    Two ``RNNModel`` instances with identical hyper-parameters and the same
    ``random_state`` are trained for one epoch: one with
    ``save_checkpoints=False`` (saved by hand through ``save_model``) and one
    with ``save_checkpoints=True``. The test checks which files appear on
    disk and that the manually saved, the automatically checkpointed and the
    in-memory models all produce equal forecasts.
    """
    manual_name = "test_save_manual"
    auto_name = "test_save_automatic"

    model_manual_save = RNNModel(
        12,
        "RNN",
        10,
        10,
        model_name=manual_name,
        work_dir=self.temp_work_dir,
        save_checkpoints=False,
        random_state=42,
    )
    model_auto_save = RNNModel(
        12,
        "RNN",
        10,
        10,
        model_name=auto_name,
        work_dir=self.temp_work_dir,
        save_checkpoints=True,
        random_state=42,
    )

    model_manual_save.fit(self.series, epochs=1)
    model_auto_save.fit(self.series, epochs=1)

    model_dir = self.temp_work_dir

    # Without automatic saving no checkpoints folder may have been created ...
    self.assertFalse(
        os.path.exists(os.path.join(model_dir, manual_name, "checkpoints"))
    )
    # ... whereas automatic saving must have produced one.
    self.assertTrue(
        os.path.exists(os.path.join(model_dir, auto_name, "checkpoints"))
    )

    # Create the target folder for the manual save. exist_ok=True keeps the
    # test from failing on a pre-existing model folder: only the absence of
    # the "checkpoints" sub-folder is asserted above.
    checkpoint_path_manual = os.path.join(model_dir, manual_name)
    os.makedirs(checkpoint_path_manual, exist_ok=True)

    checkpoint_file_name = "checkpoint_0.pth.tar"
    model_path_manual = os.path.join(checkpoint_path_manual, checkpoint_file_name)
    checkpoint_file_name_ckpt = "checkpoint_0_ptl-ckpt.pth.tar"
    model_path_manual_ckpt = os.path.join(
        checkpoint_path_manual, checkpoint_file_name_ckpt
    )

    # Save the model by hand; the test expects a companion PTL checkpoint
    # file to appear next to the main file.
    model_manual_save.save_model(model_path_manual)
    self.assertTrue(os.path.exists(model_path_manual))
    self.assertTrue(os.path.exists(model_path_manual_ckpt))

    # Reload the manual save and compare it with the in-memory automatic model.
    model_manual_save = RNNModel.load_model(model_path_manual)
    self.assertEqual(
        model_manual_save.predict(n=4), model_auto_save.predict(n=4)
    )

    # Restore the automatically saved model with load_from_checkpoint().
    model_auto_save1 = RNNModel.load_from_checkpoint(
        model_name=auto_name, work_dir=self.temp_work_dir, best=False
    )

    # The restored checkpoint must agree with the reloaded manual save.
    self.assertEqual(
        model_manual_save.predict(n=4), model_auto_save1.predict(n=4)
    )