コード例 #1
0
 def test_basic(self):
     """Run ExponentialShift on 'x' of the fixture optimizer's first param group.

     The value is seeded with 0 and the extension is built with the
     parameterized rate/init/target. ``self._run_trainer`` (defined
     elsewhere) is given ``self.expect``, presumably to compare the
     observed values against it.
     """
     param_groups = self.optimizer.param_groups
     param_groups[0]['x'] = 0
     shift = extensions.ExponentialShift(
         'x', self.rate, init=self.init, target=self.target)
     self._run_trainer(shift, self.expect)
コード例 #2
0
 def test_with_optimizer(self):
     """Run ExponentialShift against an optimizer passed explicitly.

     A fresh mock optimizer whose single param group has 'x' = 0 is handed
     to the extension via ``optimizer=`` and also forwarded to
     ``self._run_trainer`` (defined elsewhere).
     """
     explicit_optimizer = mock.Mock()
     explicit_optimizer.param_groups = [{'x': 0}]
     shift = extensions.ExponentialShift(
         'x',
         self.rate,
         init=self.init,
         target=self.target,
         optimizer=explicit_optimizer,
     )
     self._run_trainer(shift, self.expect, explicit_optimizer)
コード例 #3
0
    def setUp(self):
        """Create the mock optimizer, the extension under test and a trainer.

        Each expected value in ``self.expect`` is repeated ``self.interval``
        times, and the trainer's mock updater is wired to return
        ``self.optimizer`` from ``get_optimizer``.
        """
        self.optimizer = mock.MagicMock()
        self.optimizer.param_groups = [{'x': None}]
        self.extension = extensions.ExponentialShift(
            'x',
            self.rate,
            init=self.init,
            target=self.target,
            optimizer=self.optimizer,
        )

        self.interval = 4
        # Repeat every expected value `interval` times; presumably because the
        # value only changes once per `interval` iterations of the trigger.
        repeated_expect = []
        for value in self.expect:
            repeated_expect.extend([value] * self.interval)
        self.expect = repeated_expect
        self.trigger = util.get_trigger((self.interval, 'iteration'))

        self.trainer = testing.get_trainer_with_mock_updater(self.trigger)
        updater = self.trainer.updater
        updater.get_optimizer.return_value = self.optimizer
コード例 #4
0
    def test_resume(self):
        """Check that a restored extension reproduces the trained value of 'x'.

        The fixture trainer is run, its state is transferred to a second
        trainer via ``testing.save_and_load_pth`` (presumably save-then-load),
        and after ``initialize`` the second optimizer's 'x' must equal the
        first one's and be a float.
        """
        restored_optimizer = mock.Mock()
        restored_optimizer.param_groups = [{'x': None}]
        restored_extension = extensions.ExponentialShift(
            'x',
            self.rate,
            init=self.init,
            target=self.target,
            optimizer=restored_optimizer,
        )

        # Drive the original trainer so there is state to transfer.
        self.trainer.extend(self.extension)
        self.trainer.run()

        restored_trainer = testing.get_trainer_with_mock_updater(
            (3, 'iteration'))
        restored_trainer.extend(restored_extension)
        testing.save_and_load_pth(self.trainer, restored_trainer)

        restored_extension.initialize(restored_trainer)
        resumed_value = restored_optimizer.param_groups[0]['x']
        self.assertEqual(resumed_value, self.optimizer.param_groups[0]['x'])
        self.assertIsInstance(resumed_value, float)
コード例 #5
0
 def test_negative_rate(self):
     """Constructing ExponentialShift with a negative rate raises ValueError."""
     self.assertRaises(ValueError, extensions.ExponentialShift, 'x', -1.0)