コード例 #1
0
    def test_trainer_can_resume_training_for_exponential_moving_average(self):
        """Check that a ``CallbackTrainer`` with an EMA callback can resume a previous run.

        A first trainer, using an ``UpdateMovingAverage`` callback, is trained with
        ``num_epochs=1`` and ``serialization_dir=self.TEST_DIR``. A second trainer is then
        built over the same model and the same serialization directory, with a fresh
        ``ExponentialMovingAverage``. Only ``Events.TRAINING_START`` is fired on it, and its
        ``epoch_number`` is expected to be 1. That is, it should pick up after the epoch the
        first run completed (presumably by loading what the first run wrote to
        ``self.TEST_DIR``; confirm against the checkpointing callback). Finally, the second
        trainer, configured with ``num_epochs=3``, is trained.
        """
        moving_average = ExponentialMovingAverage(self.model.named_parameters())
        callbacks = self.default_callbacks() + [UpdateMovingAverage(moving_average)]

        trainer = CallbackTrainer(self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=1, serialization_dir=self.TEST_DIR,
                                  callbacks=callbacks)
        trainer.train()

        # A new EMA object for the resumed run, so it shares no in-memory state with the
        # EMA used by the first run.
        new_moving_average = ExponentialMovingAverage(self.model.named_parameters())
        new_callbacks = self.default_callbacks() + [UpdateMovingAverage(new_moving_average)]

        new_trainer = CallbackTrainer(self.model,
                                      training_data=self.instances,
                                      iterator=self.iterator,
                                      optimizer=self.optimizer,
                                      num_epochs=3, serialization_dir=self.TEST_DIR,
                                      callbacks=new_callbacks)

        # Fire just TRAINING_START (instead of calling train()) and check that the epoch
        # counter already reflects the single epoch completed by the first run.
        new_trainer.handler.fire_event(Events.TRAINING_START)
        assert new_trainer.epoch_number == 1

        # NOTE(review): this inspects the *first* trainer's metric tracker, so on its own it
        # does not show that new_trainer restored tracker state. Consider asserting on
        # new_trainer's tracker once it is confirmed to be populated at TRAINING_START.
        tracker = trainer.metric_tracker
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None  # pylint: disable=protected-access

        new_trainer.train()
コード例 #2
0
 def test_trainer_can_run_ema_from_params(self):
     """Smoke-test an EMA callback built through ``UpdateMovingAverage.from_params``.

     The callback comes from a ``Params`` config whose ``moving_average`` entry sets only
     ``decay``, with ``self.model`` passed alongside the config. It is placed after the
     default callbacks, and a ``CallbackTrainer`` configured with ``num_epochs=2`` is trained.
     There are no explicit assertions: the test passes as long as training raises nothing.
     """
     ema_config = Params({"moving_average": {"decay": 0.9999}})
     # Unpacking keeps the original evaluation order: default callbacks first, then EMA.
     all_callbacks = [*self.default_callbacks(),
                      UpdateMovingAverage.from_params(ema_config, self.model)]
     trainer_kwargs = dict(model=self.model,
                           training_data=self.instances,
                           iterator=self.iterator,
                           optimizer=self.optimizer,
                           num_epochs=2,
                           callbacks=all_callbacks)
     ema_trainer = CallbackTrainer(**trainer_kwargs)
     ema_trainer.train()
コード例 #3
0
 def test_trainer_can_run_exponential_moving_average(self):
     """Smoke-test training with a directly constructed EMA callback.

     An ``ExponentialMovingAverage`` over ``self.model.named_parameters()`` with
     ``decay=0.9999`` is wrapped in ``UpdateMovingAverage`` and placed after the default
     callbacks. A ``CallbackTrainer`` configured with ``num_epochs=2`` is then trained.
     There are no explicit assertions: the test passes as long as training raises nothing.
     """
     ema = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
     # Unpacking keeps the original evaluation order: default callbacks first, then EMA.
     all_callbacks = [*self.default_callbacks(), UpdateMovingAverage(ema)]
     trainer_kwargs = dict(model=self.model,
                           training_data=self.instances,
                           iterator=self.iterator,
                           optimizer=self.optimizer,
                           num_epochs=2,
                           callbacks=all_callbacks)
     ema_trainer = CallbackTrainer(**trainer_kwargs)
     ema_trainer.train()