def test_unrecognized_optimizer_type(self):
    """Rejects an unknown optimizer type name with a ValueError.

    The raised message must match 'Unrecognized' followed somewhere by the
    offending name 'unk'; the whole pattern is matched case-insensitively
    via the inline (?i) flag.
    """
    params = np.array([1., 2.])
    expected_error = '(?i)Unrecognized.*unk'
    with self.assertRaisesRegex(ValueError, expected_error):
        models.init_optimizer_by_type(params, 'unk')
def test_init_optimizer_def_lamb(self):
    """The 'lamb' type name yields an optimizer whose definition is optim.LAMB."""
    params = np.array([1., 2.])
    optimizer = models.init_optimizer_by_type(params, 'lamb')
    self.assertIsInstance(optimizer.optimizer_def, optim.LAMB)
def test_init_optimizer_def_gradient_descent(self):
    """'gradient_descent' yields an optimizer defined by optim.GradientDescent."""
    type_name = 'gradient_descent'
    optimizer = models.init_optimizer_by_type(np.array([1., 2.]), type_name)
    definition = optimizer.optimizer_def
    self.assertIsInstance(definition, optim.GradientDescent)
def test_init_optimizer_def_adam(self):
    """The 'adam' type name yields an optimizer whose definition is optim.Adam."""
    weights = np.array([1., 2.])
    adam_opt = models.init_optimizer_by_type(weights, 'adam')
    self.assertIsInstance(adam_opt.optimizer_def, optim.Adam)
def test_init_optimizer_target_adafactor(self):
    """The 'adafactor' optimizer must carry the given model as its target.

    Compares shape and values position by position. The previous
    assertCountEqual check only compared the multiset of elements, so it
    would also have passed for a reordered target.
    """
    optimizer_type = 'adafactor'
    model = np.array([1., 2.])
    opt = models.init_optimizer_by_type(model, optimizer_type)
    np.testing.assert_array_equal(opt.target, model)