Example 1
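Resumes training from a checkpoint with a CosineWithRestarts learning-rate scheduler and checks that the scheduler's last_epoch is restored along with the epoch count.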
    def test_trainer_can_resume_with_lr_scheduler(self):
        lr_scheduler = CosineWithRestarts(self.optimizer, t_initial=5)
        trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            learning_rate_scheduler=lr_scheduler,
            validation_data_loader=self.validation_data_loader,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        new_lr_scheduler = CosineWithRestarts(self.optimizer, t_initial=5)
        new_trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            learning_rate_scheduler=new_lr_scheduler,
            validation_data_loader=self.validation_data_loader,
            num_epochs=4,
            serialization_dir=self.TEST_DIR,
        )
        epoch = new_trainer._restore_checkpoint()
        assert epoch == 2
        assert new_trainer._learning_rate_scheduler.last_epoch == 1
        new_trainer.train()
Example 2
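Rewrites the saved training states on disk to mimic an older checkpoint format (no metric_tracker, a val_metric_per_epoch list instead), then checks that restoring still rebuilds the metric tracker correctly.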
    def test_restoring_works_with_older_checkpointing(self):
        trainer = Trainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        for index in range(3):
            path = str(self.TEST_DIR /
                       "training_state_epoch_{}.th".format(index))
            state = torch.load(path)
            state.pop("metric_tracker")
            state.pop("batch_num_total")
            state["val_metric_per_epoch"] = [0.4, 0.1, 0.8]
            torch.save(state, path)

        next_epoch = trainer._restore_checkpoint()
        best_epoch = trainer._metric_tracker.best_epoch

        # Loss decreases over the 3 epochs, but because we hard-coded the val metrics above:
        assert next_epoch == 3
        assert best_epoch == 1
        assert trainer._metric_tracker._best_so_far == 0.1
        assert trainer._metric_tracker._epochs_with_no_improvement == 1
Example 3
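Trains for one epoch, then verifies that a second Trainer pointed at the same serialization directory restores the checkpoint and can continue training.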
    def test_trainer_can_resume_training(self):
        trainer = Trainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()
        new_trainer = Trainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
        )

        epoch = new_trainer._restore_checkpoint()
        assert epoch == 1

        tracker = trainer._metric_tracker
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None

        new_trainer.train()
Example 4
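Same resume scenario as above, but with an ExponentialMovingAverage of the model parameters attached to both trainers.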
    def test_trainer_can_resume_training_for_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(
            self.model.named_parameters())

        trainer = Trainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
            moving_average=moving_average,
        )
        trainer.train()

        new_moving_average = ExponentialMovingAverage(
            self.model.named_parameters())
        new_trainer = Trainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
            moving_average=new_moving_average,
        )

        epoch = new_trainer._restore_checkpoint()
        assert epoch == 1

        tracker = trainer._metric_tracker
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None

        new_trainer.train()
Example 5
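Resumes training with an exponential learning-rate scheduler built via LearningRateScheduler.from_params, using the older iterator/dataset-based Trainer signature.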
    def test_trainer_can_resume_with_lr_scheduler(self):
        lr_scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({
                "type": "exponential",
                "gamma": 0.5
            }))
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          learning_rate_scheduler=lr_scheduler,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2,
                          serialization_dir=self.TEST_DIR)
        trainer.train()

        new_lr_scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({
                "type": "exponential",
                "gamma": 0.5
            }))
        new_trainer = Trainer(model=self.model,
                              optimizer=self.optimizer,
                              iterator=self.iterator,
                              learning_rate_scheduler=new_lr_scheduler,
                              train_dataset=self.instances,
                              validation_dataset=self.instances,
                              num_epochs=4,
                              serialization_dir=self.TEST_DIR)
        epoch = new_trainer._restore_checkpoint()
        assert epoch == 2
        assert new_trainer._learning_rate_scheduler.lr_scheduler.last_epoch == 1
        new_trainer.train()
Example 6
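Resume test with an ExponentialMovingAverage, using the older iterator-based Trainer signature.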
    def test_trainer_can_resume_training_for_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(
            self.model.named_parameters())

        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          validation_dataset=self.instances,
                          num_epochs=1,
                          serialization_dir=self.TEST_DIR,
                          moving_average=moving_average)
        trainer.train()

        new_moving_average = ExponentialMovingAverage(
            self.model.named_parameters())
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              moving_average=new_moving_average)

        epoch = new_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 1

        tracker = trainer._metric_tracker  # pylint: disable=protected-access
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None  # pylint: disable=protected-access

        new_trainer.train()
Example 7
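Basic resume test using the older iterator-based Trainer signature.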
    def test_trainer_can_resume_training(self):
        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          validation_dataset=self.instances,
                          num_epochs=1,
                          serialization_dir=self.TEST_DIR)
        trainer.train()
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR)

        epoch = new_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 1

        tracker = trainer._metric_tracker  # pylint: disable=protected-access
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None  # pylint: disable=protected-access

        new_trainer.train()
Example 8
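Resumes training with an ExponentialLearningRateScheduler and checks the wrapped scheduler's last_epoch after restoring the checkpoint.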
    def test_trainer_can_resume_with_lr_scheduler(self):
        lr_scheduler = ExponentialLearningRateScheduler(self.optimizer, gamma=0.5)
        trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            iterator=self.iterator,
            learning_rate_scheduler=lr_scheduler,
            train_dataset=self.instances,
            validation_dataset=self.instances,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        new_lr_scheduler = ExponentialLearningRateScheduler(self.optimizer, gamma=0.5)
        new_trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            iterator=self.iterator,
            learning_rate_scheduler=new_lr_scheduler,
            train_dataset=self.instances,
            validation_dataset=self.instances,
            num_epochs=4,
            serialization_dir=self.TEST_DIR,
        )
        epoch = new_trainer._restore_checkpoint()
        assert epoch == 2
        assert new_trainer._learning_rate_scheduler.lr_scheduler.last_epoch == 1
        new_trainer.train()
Example 9
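Runs four epochs with an inverted_triangular momentum scheduler, then restores into a six-epoch trainer and checks the momentum scheduler's state.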
    def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
        scheduler = MomentumScheduler.from_params(
            optimizer=self.optimizer,
            params=Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}),
        )
        trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            momentum_scheduler=scheduler,
            validation_metric="-loss",
            validation_data_loader=self.validation_data_loader,
            num_epochs=4,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        new_scheduler = MomentumScheduler.from_params(
            optimizer=self.optimizer,
            params=Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}),
        )
        new_trainer = Trainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            momentum_scheduler=new_scheduler,
            validation_metric="-loss",
            validation_data_loader=self.validation_data_loader,
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
        )
        epoch = new_trainer._restore_checkpoint()
        assert epoch == 4
        assert new_trainer._momentum_scheduler.last_epoch == 3
        new_trainer.train()
Example 10
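Uses model_save_interval to write timestamped intra-epoch checkpoints, deletes the end-of-epoch files, and verifies that restoring falls back to the timestamped checkpoints.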
    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(
            self.model,
            self.optimizer,
            iterator,
            self.instances,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            model_save_interval=0.0001,
        )

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = "model_state_epoch_*"
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [
            re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
            for fname in file_names
        ]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == "1"
        assert "." in epochs[0]

        # Now make certain we can restore from a timestamped checkpoint.
        # To do so, remove the end-of-epoch checkpoints for both epochs, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(
                os.path.join(self.TEST_DIR,
                             "model_state_epoch_{}.th".format(k)))
            os.remove(
                os.path.join(self.TEST_DIR,
                             "training_state_epoch_{}.th".format(k)))
        os.remove(os.path.join(self.TEST_DIR, "best.th"))

        restore_trainer = Trainer(
            self.model,
            self.optimizer,
            self.iterator,
            self.instances,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            model_save_interval=0.0001,
        )
        epoch = restore_trainer._restore_checkpoint()
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2
Example 11
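Trains with validation_metric="+loss", then restores into a second trainer and verifies that the best epoch and its validation metrics are preserved across the restore.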
    def test_trainer_saves_and_loads_best_validation_metrics_correctly_2(self):
        # Use +loss and run 1 epoch of original training, and one of restored training.
        # Run 1 epoch of original training.
        trainer = Trainer(
            self.model,
            self.optimizer,
            self.iterator,
            self.instances,
            validation_dataset=self.instances,
            validation_metric="+loss",
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        _ = trainer._restore_checkpoint()
        best_epoch_1 = trainer._metric_tracker.best_epoch
        best_validation_metrics_epoch_1 = trainer._metric_tracker.best_epoch_metrics
        # best_validation_metrics_epoch_1: {'accuracy': 0.75, 'accuracy3': 1.0, 'loss': 0.6243013441562653}
        assert isinstance(best_validation_metrics_epoch_1, dict)
        assert "loss" in best_validation_metrics_epoch_1

        # Run 1 more epoch of restored training.
        restore_trainer = Trainer(
            self.model,
            self.optimizer,
            self.iterator,
            self.instances,
            validation_dataset=self.instances,
            validation_metric="+loss",
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
        )
        restore_trainer.train()
        _ = restore_trainer._restore_checkpoint()
        best_epoch_2 = restore_trainer._metric_tracker.best_epoch
        best_validation_metrics_epoch_2 = restore_trainer._metric_tracker.best_epoch_metrics

        # Because we use +loss, the 2nd epoch won't beat the 1st, so the best val metrics should be unchanged.
        assert best_epoch_1 == best_epoch_2 == 0
        assert best_validation_metrics_epoch_2 == best_validation_metrics_epoch_1