コード例 #1
0
        def test_suppress_automatic_save(self, patch_save_model):
            """With ``save_checkpoints=False``, fit/predict must never reach ``save_model``.

            ``patch_save_model`` is presumably a mock patched over ``save_model`` by a
            decorator outside this view — TODO confirm. It must stay untouched until
            ``save_model`` is invoked explicitly at the end.
            """
            model_name = "test_model"
            shared_kwargs = dict(
                model_name=model_name,
                work_dir=self.temp_work_dir,
                save_checkpoints=False,
            )
            first = RNNModel(12, "RNN", 10, 10, **shared_kwargs)
            second = RNNModel(12, "RNN", 10, 10, force_reset=True, **shared_kwargs)

            for candidate in (first, second):
                candidate.fit(self.series, epochs=1)

            first.predict(n=1)
            second.predict(n=2)

            # nothing above may have triggered an automatic save
            patch_save_model.assert_not_called()

            target_path = os.path.join(self.temp_work_dir, model_name)
            first.save_model(path=target_path)
            patch_save_model.assert_called()
コード例 #2
0
        def test_prediction_loaded_custom_trainer(self):
            """Check that a checkpoint written while fitting with a custom trainer reloads correctly.

            The model is fit with a user-supplied ``pl.Trainer`` that reuses the
            model's own callbacks. With ``save_checkpoints=True`` these callbacks
            include Darts' checkpointer. The automatically saved checkpoint is then
            restored with ``load_from_checkpoint`` and must produce the same
            forecast as the in-memory model.
            """
            auto_name = "test_save_automatic"
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                model_name=auto_name,
                work_dir=self.temp_work_dir,
                save_checkpoints=True,
                random_state=42,
            )

            # fit with a custom trainer; the model's callbacks are passed through so
            # the automatic checkpointing configured on the model still runs
            trainer = pl.Trainer(
                max_epochs=1,
                enable_checkpointing=True,
                logger=False,
                callbacks=model.trainer_params["callbacks"],
                precision=32,
            )
            model.fit(self.series, trainer=trainer)

            # restore the automatically saved checkpoint
            # (best=False: presumably the latest rather than the best one — TODO confirm)
            model_loaded = RNNModel.load_from_checkpoint(
                model_name=auto_name, work_dir=self.temp_work_dir, best=False
            )

            # the reloaded model must forecast exactly like the original
            self.assertEqual(model.predict(n=4), model_loaded.predict(n=4))
コード例 #3
0
        def test_lr_schedulers(self):
            """Each listed LR scheduler can be configured on a model and trained for one epoch."""
            step_lr = (torch.optim.lr_scheduler.StepLR, {"step_size": 10})
            plateau = (
                torch.optim.lr_scheduler.ReduceLROnPlateau,
                # the "monitor" entry names the logged metric the scheduler reacts to
                {"threshold": 0.001, "monitor": "train_loss"},
            )
            exponential = (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.09})

            for scheduler_cls, scheduler_kwargs in (step_lr, plateau, exponential):
                model = RNNModel(
                    12,
                    "RNN",
                    10,
                    10,
                    lr_scheduler_cls=scheduler_cls,
                    lr_scheduler_kwargs=scheduler_kwargs,
                )
                # any exception here fails the test
                model.fit(self.series, epochs=1)
コード例 #4
0
        def test_create_instance_existing_model_with_name_force_fit_with_reset(
                self, patch_reset_model):
            """Re-creating an already-fitted model name with ``force_reset=True`` resets it once.

            ``patch_reset_model`` is presumably a mock patched over the model-reset
            routine by a decorator outside this view — TODO confirm.
            """
            model_name = "test_model"
            common = {
                "work_dir": self.temp_work_dir,
                "model_name": model_name,
                "save_checkpoints": True,
            }

            first_model = RNNModel(12, "RNN", 10, 10, **common)
            # fitting the first model must not raise
            first_model.fit(self.series, epochs=1)

            RNNModel(12, "RNN", 10, 10, force_reset=True, **common)
            patch_reset_model.assert_called_once()
コード例 #5
0
        def test_builtin_extended_trainer(self):
            """Check that ``pl_trainer_kwargs`` are forwarded to the built-in trainer.

            An unknown trainer keyword (``precisionn``) is accepted when the model is
            constructed. It must fail with ``TypeError`` once ``fit`` builds the
            trainer. A valid keyword trains normally.
            """
            invalid_trainer_kwarg = {"precisionn": 32}

            # construction is kept outside assertRaises so the test pins the error
            # to training time rather than accepting it from the constructor too
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                random_state=42,
                pl_trainer_kwargs=invalid_trainer_kwarg,
            )
            with self.assertRaises(TypeError):
                model.fit(self.series, epochs=1)

            valid_trainer_kwargs = {
                "precision": 32,
            }

            # valid parameters shouldn't raise an error
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                random_state=42,
                pl_trainer_kwargs=valid_trainer_kwargs,
            )
            model.fit(self.series, epochs=1)
コード例 #6
0
        def test_custom_callback(self):
            """Check that user callbacks from ``pl_trainer_kwargs`` are registered and run.

            Two counting callbacks must each advance once per trained epoch. Adding
            user callbacks must also not displace the ``ModelCheckpoint`` that the
            model installs itself when ``save_checkpoints=True``.
            """

            class CounterCallback(pl.callbacks.Callback):
                """Count finished training epochs, starting from ``count_default``."""

                def __init__(self, count_default):
                    # initialise the Lightning base class before adding our own state
                    super().__init__()
                    self.counter = count_default

                def on_train_epoch_end(self, *args, **kwargs):
                    self.counter += 1

            my_counter_0 = CounterCallback(count_default=0)
            my_counter_2 = CounterCallback(count_default=2)

            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                random_state=42,
                pl_trainer_kwargs={"callbacks": [my_counter_0, my_counter_2]},
            )

            # only the two user callbacks are expected (checkpointing not enabled here)
            self.assertEqual(len(model.trainer_params["callbacks"]), 2)
            model.fit(self.series, epochs=2)

            # each counter advanced once per trained epoch from its own start value
            self.assertEqual(my_counter_0.counter, model.epochs_trained)
            self.assertEqual(my_counter_2.counter, model.epochs_trained + 2)

            # with checkpointing enabled, user callbacks must not overwrite Darts'
            # built-in checkpointer
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                random_state=42,
                work_dir=self.temp_work_dir,
                save_checkpoints=True,
                pl_trainer_kwargs={
                    "callbacks": [CounterCallback(0), CounterCallback(2)]
                },
            )
            callbacks = model.trainer_params["callbacks"]

            # the checkpointer plus the two user callbacks
            self.assertEqual(len(callbacks), 3)

            # first one is our checkpointer ...
            self.assertIsInstance(callbacks[0], pl.callbacks.ModelCheckpoint)

            # ... and the second and third are the CounterCallbacks
            for callback in callbacks[1:3]:
                self.assertIsInstance(callback, CounterCallback)
コード例 #7
0
 def test_invalid_metrics(self):
     """An unknown metric name passed as ``torch_metrics`` must raise AttributeError."""
     bad_metrics = ["invalid"]
     with self.assertRaises(AttributeError):
         model = RNNModel(
             12, "RNN", 10, 10, n_epochs=1, torch_metrics=bad_metrics
         )
         model.fit(self.series)
コード例 #8
0
        def test_train_from_0_n_epochs_20_no_fit_epochs(self):
            """Without an ``epochs`` argument, fit() trains for the model's ``n_epochs``."""
            expected_epochs = 20
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                n_epochs=expected_epochs,
                work_dir=self.temp_work_dir,
            )

            model.fit(self.series)

            self.assertEqual(expected_epochs, model.epochs_trained)
コード例 #9
0
        def test_prediction_custom_trainer(self):
            """A custom trainer and the built-in trainer yield identical forecasts.

            Both models share ``random_state=42``. ``self.trainer_params`` is
            presumably configured for a single epoch to match ``epochs=1`` below —
            TODO confirm.
            """
            seeded = dict(random_state=42)
            custom_model = RNNModel(12, "RNN", 10, 10, **seeded)
            builtin_model = RNNModel(12, "RNN", 10, 10, **seeded)

            custom_trainer = pl.Trainer(**self.trainer_params, precision=32)
            custom_model.fit(self.series, trainer=custom_trainer)

            builtin_model.fit(self.series, epochs=1)

            self.assertEqual(custom_model.predict(n=4), builtin_model.predict(n=4))
コード例 #10
0
        def test_train_from_10_n_epochs_20_fit_15_epochs(self):
            """An explicit ``epochs`` passed to fit() takes precedence over ``n_epochs``.

            The first fit (10 epochs) simulates a user interrupting training with
            Ctrl-C. After each fit, ``epochs_trained`` must equal the ``epochs``
            value of that call.
            """
            model = RNNModel(
                12,
                "RNN",
                10,
                10,
                n_epochs=20,
                work_dir=self.temp_work_dir,
            )

            for requested in (10, 15):
                model.fit(self.series, epochs=requested)
                self.assertEqual(requested, model.epochs_trained)
コード例 #11
0
        def test_custom_trainer_setup(self):
            """fit() rejects a 64-bit-precision custom trainer and accepts a 32-bit one."""
            model = RNNModel(12, "RNN", 10, 10, random_state=42)

            bad_trainer = pl.Trainer(**self.trainer_params, precision=64)
            with self.assertRaises(ValueError):
                model.fit(self.series, trainer=bad_trainer)

            good_trainer = pl.Trainer(**self.trainer_params, precision=32)
            model.fit(self.series, trainer=good_trainer)

            # training length follows the custom trainer's max_epochs
            self.assertEqual(good_trainer.max_epochs, model.epochs_trained)
コード例 #12
0
        def test_metrics(self):
            """Models train with a single torchmetrics metric or a MetricCollection.

            Three cases are covered:
            - a single metric on a univariate series;
            - a collection on a univariate series;
            - a single metric on a multivariate series.
            """
            single_metric = MeanAbsolutePercentageError()
            collection = MetricCollection(
                [MeanAbsolutePercentageError(), MeanAbsoluteError()]
            )

            cases = (
                (single_metric, self.series),
                (collection, self.series),
                (single_metric, self.multivariate_series),
            )
            for metrics, training_series in cases:
                model = RNNModel(
                    12,
                    "RNN",
                    10,
                    10,
                    n_epochs=1,
                    torch_metrics=metrics,
                )
                model.fit(training_series)
コード例 #13
0
        def test_optimizers(self):
            """Adam and SGD can each be configured as the optimizer and trained for one epoch."""
            for optimizer_cls in (torch.optim.Adam, torch.optim.SGD):
                # a fresh kwargs dict per model, as in a per-optimizer table
                optimizer_kwargs = {"lr": 0.001}
                model = RNNModel(
                    12,
                    "RNN",
                    10,
                    10,
                    optimizer_cls=optimizer_cls,
                    optimizer_kwargs=optimizer_kwargs,
                )
                # any exception here fails the test
                model.fit(self.series, epochs=1)
コード例 #14
0
        def test_early_stopping(self):
            """EarlyStopping callbacks are honoured and must monitor a valid variable."""

            def build_model(monitor):
                # one fresh stopper per model, watching the given variable
                stopper = pl.callbacks.early_stopping.EarlyStopping(
                    monitor=monitor,
                    stopping_threshold=1e9,
                )
                return RNNModel(
                    12,
                    "RNN",
                    10,
                    10,
                    nr_epochs_val_period=1,
                    random_state=42,
                    pl_trainer_kwargs={"callbacks": [stopper]},
                )

            # such a high stopping_threshold should end training after one epoch
            stopping_model = build_model("val_loss")
            stopping_model.fit(
                self.series, val_series=self.series, epochs=100, verbose=True
            )
            self.assertEqual(stopping_model.epochs_trained, 1)

            # monitoring an unknown variable must fail during fit
            invalid_model = build_model("invalid_variable")
            with self.assertRaises(RuntimeError):
                invalid_model.fit(
                    self.series, val_series=self.series, epochs=100, verbose=True
                )
コード例 #15
0
        def test_manual_save_and_load(self):
            """Compare a manually saved model with an automatically checkpointed one.

            Two identically seeded models are fit for one epoch:
            - one with ``save_checkpoints=False``, saved explicitly via ``save_model``;
            - one with ``save_checkpoints=True``.

            The test then checks three things:
            - only the automatic model gets a ``checkpoints`` folder;
            - ``save_model`` writes both the model file and its PTL checkpoint file;
            - models restored with ``load_model`` and ``load_from_checkpoint``
              forecast identically.
            """
            manual_name = "test_save_manual"
            auto_name = "test_save_automatic"
            model_manual_save = RNNModel(
                12,
                "RNN",
                10,
                10,
                model_name=manual_name,
                work_dir=self.temp_work_dir,
                save_checkpoints=False,
                random_state=42,
            )
            model_auto_save = RNNModel(
                12,
                "RNN",
                10,
                10,
                model_name=auto_name,
                work_dir=self.temp_work_dir,
                save_checkpoints=True,
                random_state=42,
            )

            model_manual_save.fit(self.series, epochs=1)
            model_auto_save.fit(self.series, epochs=1)

            model_dir = self.temp_work_dir

            # no checkpoints folder may exist for the model without automatic saving
            self.assertFalse(
                os.path.exists(os.path.join(model_dir, manual_name, "checkpoints"))
            )
            # the automatically saving model must have created one
            self.assertTrue(
                os.path.exists(os.path.join(model_dir, auto_name, "checkpoints"))
            )

            # create the directory that will hold the manually saved model
            manual_save_dir = os.path.join(model_dir, manual_name)
            os.mkdir(manual_save_dir)

            checkpoint_file_name = "checkpoint_0.pth.tar"
            model_path_manual = os.path.join(manual_save_dir, checkpoint_file_name)
            checkpoint_file_name_ckpt = "checkpoint_0_ptl-ckpt.pth.tar"
            model_path_manual_ckpt = os.path.join(
                manual_save_dir, checkpoint_file_name_ckpt
            )

            # save the manually saved model
            model_manual_save.save_model(model_path_manual)
            self.assertTrue(os.path.exists(model_path_manual))

            # save_model must also write the PTL checkpoint file next to it
            self.assertTrue(os.path.exists(model_path_manual_ckpt))

            # reload the manual save and compare with the auto-saving model's forecast
            model_manual_loaded = RNNModel.load_model(model_path_manual)
            self.assertEqual(
                model_manual_loaded.predict(n=4), model_auto_save.predict(n=4)
            )

            # restore the automatically saved model with load_from_checkpoint()
            model_auto_loaded = RNNModel.load_from_checkpoint(
                model_name=auto_name, work_dir=self.temp_work_dir, best=False
            )

            # the restored checkpoint must forecast like the reloaded manual save
            self.assertEqual(
                model_manual_loaded.predict(n=4), model_auto_loaded.predict(n=4)
            )