def test_without_init(self):
    """Run PolynomialShift constructed without an ``init`` value.

    Before the run, the optimizer's ``x`` is seeded with ``self.init``.
    The extension is then driven through ``self._run_trainer`` against
    ``self.expect``.

    NOTE(review): presumably the extension takes its starting value from
    the optimizer when ``init`` is omitted. Confirm against PolynomialShift.
    """
    self.optimizer.x = self.init
    shift = extensions.PolynomialShift(
        'x', self.rate, self.max_count, target=self.target)
    self._run_trainer(shift, self.expect)
 def test_with_optimizer(self):
     """Run PolynomialShift with an explicitly supplied optimizer.

     A standalone ``mock.Mock`` optimizer (``x`` starting at 0) is passed to
     the extension. The same object is also handed to ``self._run_trainer``
     so the run is checked against it. ``setUp`` wires a different optimizer
     into the trainer's updater.
     """
     explicit_optimizer = mock.Mock()
     explicit_optimizer.x = 0
     options = dict(init=self.init, target=self.target,
                    optimizer=explicit_optimizer)
     shift = extensions.PolynomialShift(
         'x', self.rate, self.max_count, **options)
     self._run_trainer(shift, self.expect, explicit_optimizer)
    def setUp(self):
        """Build the shared optimizer, extension, trigger and trainer.

        - ``self.optimizer`` is a ``MagicMock`` that the trainer's updater
          returns from ``get_optimizer``.
        - ``self.extension`` is a PolynomialShift on attribute ``'x'`` using
          every parameterized value (rate, max_count, init, target) and
          that optimizer.
        - ``self.expect`` is expanded so each value appears ``self.interval``
          (4) times in a row.
        - ``self.trigger`` fires every ``self.interval`` iterations.
        """
        self.optimizer = mock.MagicMock()
        self.extension = extensions.PolynomialShift(
            'x', self.rate, self.max_count,
            init=self.init, target=self.target, optimizer=self.optimizer)

        self.interval = 4
        # Repeat every expected value once per iteration of its interval.
        repeated = []
        for value in self.expect:
            repeated.extend([value] * self.interval)
        self.expect = repeated
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        trainer = testing.get_trainer_with_mock_updater(self.trigger)
        trainer.updater.get_optimizer.return_value = self.optimizer
        self.trainer = trainer
    def test_resume(self):
        """Check that a fresh PolynomialShift picks up state from a saved run.

        The original trainer runs with ``self.extension``. Its state is
        passed to a new trainer through ``testing.save_and_load_npz``. After
        initialization, the new optimizer's ``x`` must equal the original
        optimizer's ``x`` and must be a ``float``.
        """
        resumed_optimizer = mock.Mock()
        resumed_extension = extensions.PolynomialShift(
            'x', self.rate, self.max_count,
            init=self.init, target=self.target, optimizer=resumed_optimizer)

        # Run the original trainer with the original extension attached.
        self.trainer.extend(self.extension)
        self.trainer.run()

        resumed_trainer = testing.get_trainer_with_mock_updater(
            (3, 'iteration'))
        resumed_trainer.extend(resumed_extension)
        # Presumably serializes self.trainer to npz and loads it into
        # resumed_trainer. Confirm in chainer.testing.
        testing.save_and_load_npz(self.trainer, resumed_trainer)

        resumed_extension.initialize(resumed_trainer)
        restored = resumed_optimizer.x
        self.assertEqual(restored, self.optimizer.x)
        self.assertIsInstance(restored, float)