예제 #1
0
    def test_bad_call(self, err_msg):
        """Test that calling a scheduler without initializing it raises.

        A fresh `GammaBetaDecreasingStep` is called directly, without first
        calling `initialize()`, and the call is expected to raise.

        Args:
          err_msg: Regular expression that the raised exception's message
            must match (searched with `re.search` by `assertRaisesRegex`).
        """
        scheduler = opt.GammaBetaDecreasingStep()
        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12;
        # `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(Exception, err_msg):
            scheduler(1)
예제 #2
0
    def test_call(self, step, res):
        """Test the value returned by an initialized scheduler at a given step.

        Initializes a `GammaBetaDecreasingStep` with beta=2 and gamma=1
        (passed to `initialize` in that order), calls it with `step`
        converted to a float32 tensor, and checks the result against `res`.

        Args:
          step: Step number passed to the `GammaBetaDecreasingStep` call.
          res: Expected value returned by the scheduler for `step`.
        """
        # Use the public conversion API rather than the private
        # `_ops.convert_to_tensor_v2`, which may move between TF releases.
        beta = tf.convert_to_tensor(2, dtype=tf.float32)
        gamma = tf.convert_to_tensor(1, dtype=tf.float32)
        scheduler = opt.GammaBetaDecreasingStep()
        scheduler.initialize(beta, gamma)
        step = tf.convert_to_tensor(step, dtype=tf.float32)
        lr = scheduler(step)
        # `.numpy()` requires eager execution.
        self.assertAllClose(lr.numpy(), res)