def test_suppress_automatic_save(self, patch_save_model):
    """With ``save_checkpoints=False``, fitting and predicting never call the
    patched save hook; only an explicit ``save_model()`` call does."""
    name = "test_model"
    shared_kwargs = dict(
        model_name=name,
        work_dir=self.temp_work_dir,
        save_checkpoints=False,
    )
    first = RNNModel(12, "RNN", 10, 10, **shared_kwargs)
    second = RNNModel(12, "RNN", 10, 10, force_reset=True, **shared_kwargs)

    for fitted in (first, second):
        fitted.fit(self.series, epochs=1)
    first.predict(n=1)
    second.predict(n=2)

    # neither fitting nor predicting may trigger a save
    patch_save_model.assert_not_called()

    first.save_model(path=os.path.join(self.temp_work_dir, name))
    patch_save_model.assert_called()
def test_prediction_loaded_custom_trainer(self):
    """Validate that the checkpoint written automatically while fitting with a
    custom trainer can be restored via ``load_from_checkpoint()`` and yields
    the same predictions as the original model."""
    auto_name = "test_save_automatic"
    model = RNNModel(
        12,
        "RNN",
        10,
        10,
        model_name=auto_name,
        work_dir=self.temp_work_dir,
        save_checkpoints=True,
        random_state=42,
    )

    # fit model with a custom trainer; it reuses the model's own callbacks
    # (which include the built-in checkpoint callback when
    # save_checkpoints=True) so the automatic checkpoint is still written
    trainer = pl.Trainer(
        max_epochs=1,
        enable_checkpointing=True,
        logger=False,
        callbacks=model.trainer_params["callbacks"],
        precision=32,
    )
    model.fit(self.series, trainer=trainer)

    # load the automatically saved checkpoint (best=False)
    model_loaded = RNNModel.load_from_checkpoint(
        model_name=auto_name, work_dir=self.temp_work_dir, best=False
    )

    # compare prediction of loaded model with original model
    self.assertEqual(model.predict(n=4), model_loaded.predict(n=4))
def test_lr_schedulers(self):
    """Fitting succeeds with each of several torch learning-rate schedulers."""
    scheduler_configs = (
        (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
        (
            torch.optim.lr_scheduler.ReduceLROnPlateau,
            {"threshold": 0.001, "monitor": "train_loss"},
        ),
        (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.09}),
    )
    for scheduler_cls, scheduler_kwargs in scheduler_configs:
        model = RNNModel(
            12,
            "RNN",
            10,
            10,
            lr_scheduler_cls=scheduler_cls,
            lr_scheduler_kwargs=scheduler_kwargs,
        )
        # should not raise an error
        model.fit(self.series, epochs=1)
def test_create_instance_existing_model_with_name_force_fit_with_reset(
    self, patch_reset_model
):
    """Creating a second model under an already-used name with
    ``force_reset=True`` calls the patched reset hook exactly once."""
    model_name = "test_model"
    common = {
        "work_dir": self.temp_work_dir,
        "model_name": model_name,
        "save_checkpoints": True,
    }
    original = RNNModel(12, "RNN", 10, 10, **common)
    # no exception is raised
    original.fit(self.series, epochs=1)

    RNNModel(12, "RNN", 10, 10, force_reset=True, **common)
    patch_reset_model.assert_called_once()
def test_builtin_extended_trainer(self):
    """Extra ``pl_trainer_kwargs`` reach the built-in trainer: an unknown
    kwarg makes ``fit()`` raise ``TypeError``, a valid one trains fine."""
    invalid_trainer_kwarg = {"precisionn": 32}
    # The invalid kwarg is expected to be rejected only at training time, so
    # the model is constructed outside ``assertRaises``. That way a TypeError
    # raised by the constructor cannot make this check pass by accident.
    model = RNNModel(
        12,
        "RNN",
        10,
        10,
        random_state=42,
        pl_trainer_kwargs=invalid_trainer_kwarg,
    )
    with self.assertRaises(TypeError):
        model.fit(self.series, epochs=1)

    valid_trainer_kwargs = {
        "precision": 32,
    }
    # valid parameters shouldn't raise error
    model = RNNModel(
        12,
        "RNN",
        10,
        10,
        random_state=42,
        pl_trainer_kwargs=valid_trainer_kwargs,
    )
    model.fit(self.series, epochs=1)
def test_custom_callback(self):
    """User callbacks passed through ``pl_trainer_kwargs`` run during training
    and are placed after Darts' built-in ``ModelCheckpoint`` callback."""

    class CounterCallback(pl.callbacks.Callback):
        # counts the number of trained epochs starting from count_default
        def __init__(self, count_default):
            super().__init__()
            self.counter = count_default

        def on_train_epoch_end(self, *args, **kwargs):
            self.counter += 1

    my_counter_0 = CounterCallback(count_default=0)
    my_counter_2 = CounterCallback(count_default=2)
    model = RNNModel(
        12,
        "RNN",
        10,
        10,
        random_state=42,
        pl_trainer_kwargs={"callbacks": [my_counter_0, my_counter_2]},
    )
    # check if callbacks were added
    self.assertEqual(len(model.trainer_params["callbacks"]), 2)
    model.fit(self.series, epochs=2)
    # each counter is incremented once per trained epoch
    self.assertEqual(my_counter_0.counter, model.epochs_trained)
    self.assertEqual(my_counter_2.counter, model.epochs_trained + 2)

    # check that callbacks don't overwrite Darts' built-in checkpointer
    model = RNNModel(
        12,
        "RNN",
        10,
        10,
        random_state=42,
        work_dir=self.temp_work_dir,
        save_checkpoints=True,
        pl_trainer_kwargs={
            "callbacks": [CounterCallback(0), CounterCallback(2)]
        },
    )
    callbacks = model.trainer_params["callbacks"]
    # we expect 3 callbacks
    self.assertEqual(len(callbacks), 3)
    # first one is our Checkpointer
    self.assertIsInstance(callbacks[0], pl.callbacks.ModelCheckpoint)
    # second and third are CounterCallbacks
    for callback in callbacks[1:]:
        self.assertIsInstance(callback, CounterCallback)
def test_invalid_metrics(self):
    """An unknown entry in ``torch_metrics`` leads to ``AttributeError`` when
    creating or fitting the model."""
    bad_metrics = ["invalid"]
    with self.assertRaises(AttributeError):
        model = RNNModel(
            12,
            "RNN",
            10,
            10,
            n_epochs=1,
            torch_metrics=bad_metrics,
        )
        model.fit(self.series)
def test_train_from_0_n_epochs_20_no_fit_epochs(self):
    """Without an ``epochs`` argument, ``fit()`` trains for ``n_epochs``."""
    n_epochs = 20
    model = RNNModel(
        12, "RNN", 10, 10, n_epochs=n_epochs, work_dir=self.temp_work_dir
    )
    model.fit(self.series)
    self.assertEqual(n_epochs, model.epochs_trained)
def test_prediction_custom_trainer(self):
    """A model fit with a custom trainer predicts the same as an identically
    seeded model fit with the built-in trainer (``epochs=1``)."""
    custom_model = RNNModel(12, "RNN", 10, 10, random_state=42)
    builtin_model = RNNModel(12, "RNN", 10, 10, random_state=42)

    # fit the first model with a custom trainer ...
    custom_trainer = pl.Trainer(**self.trainer_params, precision=32)
    custom_model.fit(self.series, trainer=custom_trainer)
    # ... and the second with the built-in trainer
    builtin_model.fit(self.series, epochs=1)

    # both should produce identical prediction
    self.assertEqual(custom_model.predict(n=4), builtin_model.predict(n=4))
def test_train_from_10_n_epochs_20_fit_15_epochs(self):
    """After each ``fit()`` call, ``epochs_trained`` equals the ``epochs``
    passed to that call, regardless of the model's ``n_epochs``."""
    model = RNNModel(12, "RNN", 10, 10, n_epochs=20, work_dir=self.temp_work_dir)
    # first pass simulates the case that user interrupted training with
    # Ctrl-C after 10 epochs; the second pass fits again with 15 epochs
    for requested_epochs in (10, 15):
        model.fit(self.series, epochs=requested_epochs)
        self.assertEqual(requested_epochs, model.epochs_trained)
def test_custom_trainer_setup(self):
    """A custom trainer with ``precision=64`` is rejected with ``ValueError``;
    one with ``precision=32`` trains for the trainer's ``max_epochs``."""
    model = RNNModel(12, "RNN", 10, 10, random_state=42)

    # trainer with wrong precision should raise ValueError
    bad_trainer = pl.Trainer(**self.trainer_params, precision=64)
    with self.assertRaises(ValueError):
        model.fit(self.series, trainer=bad_trainer)

    # no error with correct precision
    good_trainer = pl.Trainer(**self.trainer_params, precision=32)
    model.fit(self.series, trainer=good_trainer)

    # check if number of epochs trained is same as trainer.max_epochs
    self.assertEqual(good_trainer.max_epochs, model.epochs_trained)
def test_metrics(self):
    """Fitting works with a single torch metric, a metric collection, and a
    single metric on a multivariate series."""
    metric = MeanAbsolutePercentageError()
    metric_collection = MetricCollection(
        [MeanAbsolutePercentageError(), MeanAbsoluteError()]
    )
    cases = (
        (metric, self.series),  # single metric
        (metric_collection, self.series),  # metric collection
        (metric, self.multivariate_series),  # multivariate series
    )
    for torch_metrics, series in cases:
        model = RNNModel(
            12, "RNN", 10, 10, n_epochs=1, torch_metrics=torch_metrics
        )
        model.fit(series)
def test_optimizers(self):
    """Fitting works with both Adam and SGD optimizers (lr=0.001)."""
    for optimizer_cls in (torch.optim.Adam, torch.optim.SGD):
        model = RNNModel(
            12,
            "RNN",
            10,
            10,
            optimizer_cls=optimizer_cls,
            optimizer_kwargs={"lr": 0.001},
        )
        # should not raise an error
        model.fit(self.series, epochs=1)
def test_early_stopping(self):
    """A user ``EarlyStopping`` callback with a huge threshold stops training
    after one epoch, and an unknown monitor variable raises ``RuntimeError``."""

    def _model_with_stopper(monitor):
        # RNN model whose trainer receives an EarlyStopping callback on
        # `monitor` with a very high stopping_threshold (1e9)
        stopper = pl.callbacks.early_stopping.EarlyStopping(
            monitor=monitor,
            stopping_threshold=1e9,
        )
        return RNNModel(
            12,
            "RNN",
            10,
            10,
            nr_epochs_val_period=1,
            random_state=42,
            pl_trainer_kwargs={"callbacks": [stopper]},
        )

    fit_kwargs = dict(val_series=self.series, epochs=100, verbose=True)

    # training should stop immediately with high stopping_threshold
    model = _model_with_stopper("val_loss")
    model.fit(self.series, **fit_kwargs)
    self.assertEqual(model.epochs_trained, 1)

    # check that early stopping only takes valid monitor variables
    model = _model_with_stopper("invalid_variable")
    with self.assertRaises(RuntimeError):
        model.fit(self.series, **fit_kwargs)
def test_manual_save_and_load(self):
    """Validate manual saving against automatic checkpointing: a model saved
    with ``save_model()`` and reloaded with ``load_model()`` must predict the
    same as an identically seeded, automatically checkpointed model, both
    in memory and when reloaded with ``load_from_checkpoint()``."""
    manual_name = "test_save_manual"
    auto_name = "test_save_automatic"
    model_manual_save = RNNModel(
        12,
        "RNN",
        10,
        10,
        model_name=manual_name,
        work_dir=self.temp_work_dir,
        save_checkpoints=False,
        random_state=42,
    )
    model_auto_save = RNNModel(
        12,
        "RNN",
        10,
        10,
        model_name=auto_name,
        work_dir=self.temp_work_dir,
        save_checkpoints=True,
        random_state=42,
    )
    model_manual_save.fit(self.series, epochs=1)
    model_auto_save.fit(self.series, epochs=1)

    model_dir = self.temp_work_dir

    # check that no checkpoints folder was created without automatic saving
    self.assertFalse(
        os.path.exists(os.path.join(model_dir, manual_name, "checkpoints"))
    )
    # check that the checkpoints folder was created with automatic saving
    self.assertTrue(
        os.path.exists(os.path.join(model_dir, auto_name, "checkpoints"))
    )

    # create the folder that will hold the manually saved model
    checkpoint_path_manual = os.path.join(model_dir, manual_name)
    os.mkdir(checkpoint_path_manual)

    checkpoint_file_name = "checkpoint_0.pth.tar"
    model_path_manual = os.path.join(checkpoint_path_manual, checkpoint_file_name)
    checkpoint_file_name_ckpt = "checkpoint_0_ptl-ckpt.pth.tar"
    model_path_manual_ckpt = os.path.join(
        checkpoint_path_manual, checkpoint_file_name_ckpt
    )

    # save manually saved model
    model_manual_save.save_model(model_path_manual)
    self.assertTrue(os.path.exists(model_path_manual))

    # check that the PTL checkpoint path is also there
    self.assertTrue(os.path.exists(model_path_manual_ckpt))

    # load manual save model and compare with automatic model results
    model_manual_loaded = RNNModel.load_model(model_path_manual)
    self.assertEqual(
        model_manual_loaded.predict(n=4), model_auto_save.predict(n=4)
    )

    # load automatically saved model with load_from_checkpoint()
    model_auto_loaded = RNNModel.load_from_checkpoint(
        model_name=auto_name, work_dir=self.temp_work_dir, best=False
    )

    # compare loaded checkpoint with manual save
    self.assertEqual(
        model_manual_loaded.predict(n=4), model_auto_loaded.predict(n=4)
    )