示例#1
0
    def test_trainer_can_resume_with_lr_scheduler(self):
        """Check that a second trainer resumes from where a two-epoch run stopped.

        A first trainer runs two epochs with an exponential LR scheduler. A second
        trainer then uses the same model, optimizer, data and serialization dir,
        with a fresh scheduler and a four-epoch budget. After
        ``_restore_checkpoint()`` it must report epoch 2, and its wrapped scheduler
        must have ``last_epoch == 1``. It then continues training.
        """
        # Everything except the scheduler and the epoch budget is shared between
        # the two runs. Sharing serialization_dir lets the second trainer restore
        # what the first one left there.
        common_kwargs = dict(
            model=self.model,
            optimizer=self.optimizer,
            iterator=self.iterator,
            train_dataset=self.instances,
            validation_dataset=self.instances,
            serialization_dir=self.TEST_DIR,
        )

        # Initial run: two epochs.
        initial_scheduler = ExponentialLearningRateScheduler(self.optimizer, gamma=0.5)
        initial_trainer = Trainer(
            learning_rate_scheduler=initial_scheduler, num_epochs=2, **common_kwargs
        )
        initial_trainer.train()

        # Resumed run: restore state before training any further.
        resumed_scheduler = ExponentialLearningRateScheduler(self.optimizer, gamma=0.5)
        resumed_trainer = Trainer(
            learning_rate_scheduler=resumed_scheduler, num_epochs=4, **common_kwargs
        )
        restored_epoch = resumed_trainer._restore_checkpoint()
        assert restored_epoch == 2
        assert resumed_trainer._learning_rate_scheduler.lr_scheduler.last_epoch == 1
        resumed_trainer.train()
示例#2
0
 def test_trainer_can_run_with_lr_scheduler(self):
     """Smoke test: GradientDescentTrainer trains for two epochs with an exponential LR scheduler.

     No result is inspected. The test passes as long as ``train()`` does not raise.
     """
     scheduler = ExponentialLearningRateScheduler(self.optimizer, gamma=0.5)
     trainer_kwargs = {
         "model": self.model,
         "optimizer": self.optimizer,
         "data_loader": self.data_loader,
         "validation_data_loader": self.validation_data_loader,
         "learning_rate_scheduler": scheduler,
         "validation_metric": "-loss",
         "num_epochs": 2,
     }
     runner = GradientDescentTrainer(**trainer_kwargs)
     runner.train()
示例#3
0
 def test_trainer_can_run_with_lr_scheduler(self):
     """Smoke test: Trainer (iterator-based API) trains for two epochs with an exponential LR scheduler.

     The same instances are used for training and validation. No result is
     inspected; the test passes as long as ``train()`` does not raise.
     """
     scheduler = ExponentialLearningRateScheduler(self.optimizer, gamma=0.5)
     trainer_kwargs = {
         "model": self.model,
         "optimizer": self.optimizer,
         "iterator": self.iterator,
         "train_dataset": self.instances,
         "validation_dataset": self.instances,
         "learning_rate_scheduler": scheduler,
         "validation_metric": "-loss",
         "num_epochs": 2,
     }
     runner = Trainer(**trainer_kwargs)
     runner.train()