def test_bad_call(self, err_msg):
  """Test that calling an uninitialized scheduler raises.

  A 'GammaBetaDecreasingStep' scheduler is constructed and called directly,
  without first calling `initialize`; the call must raise an exception whose
  message matches `err_msg`.

  Args:
    err_msg: The expected error message (regex) from the scheduler bad call.
  """
  scheduler = opt.GammaBetaDecreasingStep()
  # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
  # assertRaisesRegex is the supported name.
  with self.assertRaisesRegex(Exception, err_msg):
    scheduler(1)
def test_call(self, step, res):
  """Test that an initialized scheduler returns the expected value.

  Initializes a 'GammaBetaDecreasingStep' scheduler with beta=2 and gamma=1
  (both float32 tensors), calls it at `step`, and checks the returned value
  is close to `res`.

  Args:
    step: step number passed to the 'GammaBetaDecreasingStep' scheduler.
    res: expected result from calling the scheduler at `step`.
  """
  beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
  gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
  scheduler = opt.GammaBetaDecreasingStep()
  scheduler.initialize(beta, gamma)
  # Keep the raw parameter intact; pass the scheduler a float32 tensor.
  step_tensor = _ops.convert_to_tensor_v2(step, dtype=tf.float32)
  lr = scheduler(step_tensor)
  self.assertAllClose(lr.numpy(), res)