def test_without_init(self):
    """Run the schedule with no explicit ``init`` passed to the extension.

    The starting value is seeded directly into the optimizer's first
    param group under the ``'x'`` key instead.
    """
    first_group = self.optimizer.param_groups[0]
    first_group['x'] = self.init
    shift = extensions.InverseShift(
        'x', self.gamma, self.power, target=self.target)
    self._run_trainer(shift, self.expect)
 def test_with_optimizer(self):
     """Run the schedule against an optimizer handed directly to the extension.

     A fresh mock, separate from ``self.optimizer``, is passed both to
     ``InverseShift`` (via ``optimizer=``) and to ``_run_trainer``.
     """
     explicit_optimizer = mock.Mock()
     explicit_optimizer.param_groups = [dict(x=0)]
     shift = extensions.InverseShift(
         'x', self.gamma, self.power,
         init=self.init, target=self.target, optimizer=explicit_optimizer)
     self._run_trainer(shift, self.expect, explicit_optimizer)
    def setUp(self):
        """Build the shared optimizer, extension, trigger and trainer fixtures.

        Each entry of ``self.expect`` is repeated ``self.interval`` times so
        the expectation list has one entry per iteration of the trigger
        period.
        """
        self.interval = 4
        self.expect = [
            value for value in self.expect for _ in range(self.interval)]
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.optimizer = mock.MagicMock()
        self.optimizer.param_groups = [{'x': None}]
        self.extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target,
            self.optimizer)

        # Make the mock updater hand back our optimizer when asked for one.
        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        self.trainer.updater.get_optimizer.return_value = self.optimizer
    def test_resume(self):
        """A second extension picks up the first one's state after a resume.

        The fixture trainer runs to completion; its state is then
        round-tripped via ``testing.save_and_load_pth`` into a new trainer
        that carries a separately built ``InverseShift``. After
        initialization, that extension's optimizer must hold the same
        ``'x'`` value as the original optimizer, as a ``float``.
        """
        self.trainer.extend(self.extension)
        self.trainer.run()

        restored_optimizer = mock.Mock()
        restored_optimizer.param_groups = [{'x': None}]
        restored_extension = extensions.InverseShift(
            'x', self.gamma, self.power, self.init, self.target,
            restored_optimizer)

        resumed_trainer = testing.get_trainer_with_mock_updater(
            (3, 'iteration'))
        resumed_trainer.extend(restored_extension)
        testing.save_and_load_pth(self.trainer, resumed_trainer)

        restored_extension.initialize(resumed_trainer)
        restored_value = restored_optimizer.param_groups[0]['x']
        self.assertEqual(restored_value, self.optimizer.param_groups[0]['x'])
        self.assertIsInstance(restored_value, float)
 def test_negative_rate(self):
     """Constructing ``InverseShift`` with a negative ``gamma`` raises ValueError."""
     # Positional order is (attr, gamma, power): -1.0 is the gamma argument.
     self.assertRaises(ValueError, extensions.InverseShift, 'x', -1.0, 1.0)