def test_basic(self):
    """Run ExponentialShift on 'x' without passing an explicit optimizer.

    The hyperparameter 'x' in the first param group of ``self.optimizer`` is
    reset to 0 first. The extension and ``self.expect`` are then handed to
    ``self._run_trainer``, which presumably checks the per-step values.
    """
    first_group = self.optimizer.param_groups[0]
    first_group['x'] = 0
    shift = extensions.ExponentialShift(
        'x', self.rate,
        init=self.init,
        target=self.target,
    )
    self._run_trainer(shift, self.expect)
def test_with_optimizer(self):
    """Run ExponentialShift with an optimizer passed explicitly.

    A stand-alone mock optimizer whose single param group starts with
    'x' == 0 is given both to the extension and to ``self._run_trainer``.
    """
    explicit_optimizer = mock.Mock()
    explicit_optimizer.param_groups = [{'x': 0}]
    shift = extensions.ExponentialShift(
        'x', self.rate,
        init=self.init,
        target=self.target,
        optimizer=explicit_optimizer,
    )
    self._run_trainer(shift, self.expect, explicit_optimizer)
def setUp(self):
    """Create a mock optimizer, the extension under test and a mock trainer.

    ``self.expect`` is widened so that every expected value is repeated
    ``self.interval`` times. This matches a trigger that fires every
    ``self.interval`` iterations.
    """
    optimizer = mock.MagicMock()
    optimizer.param_groups = [{'x': None}]
    self.optimizer = optimizer

    # Positional order: attribute name, rate, init, target, optimizer.
    self.extension = extensions.ExponentialShift(
        'x', self.rate, self.init, self.target, self.optimizer)

    self.interval = 4
    repeated = []
    for value in self.expect:
        repeated.extend([value] * self.interval)
    self.expect = repeated

    self.trigger = util.get_trigger((self.interval, 'iteration'))
    trainer = testing.get_trainer_with_mock_updater(self.trigger)
    # Make the mocked updater hand back the same optimizer the extension uses.
    trainer.updater.get_optimizer.return_value = self.optimizer
    self.trainer = trainer
def test_resume(self):
    """Check that a fresh extension resumes from a saved trainer's state.

    After loading the trained state and calling ``initialize``, the new
    optimizer's 'x' must equal the original optimizer's 'x' and be a float.
    """
    restored_optimizer = mock.Mock()
    restored_optimizer.param_groups = [{'x': None}]
    restored_extension = extensions.ExponentialShift(
        'x', self.rate, self.init, self.target, restored_optimizer)

    # Run the original trainer with the extension from setUp.
    self.trainer.extend(self.extension)
    self.trainer.run()

    # Build a second trainer carrying the new extension, then move state
    # from the first trainer into it via testing.save_and_load_pth.
    resumed_trainer = testing.get_trainer_with_mock_updater((3, 'iteration'))
    resumed_trainer.extend(restored_extension)
    testing.save_and_load_pth(self.trainer, resumed_trainer)

    restored_extension.initialize(resumed_trainer)
    restored_value = restored_optimizer.param_groups[0]['x']
    self.assertEqual(restored_value, self.optimizer.param_groups[0]['x'])
    self.assertIsInstance(restored_value, float)
def test_negative_rate(self):
    """Constructing ExponentialShift with a negative rate raises ValueError."""
    self.assertRaises(ValueError, extensions.ExponentialShift, 'x', -1.0)