Example 1
 def test_exponential_decay_tf(self):
   """Test creating an optimizer with an exponentially decaying learning rate."""
   rate = optimizers.ExponentialDecay(
       initial_rate=0.001, decay_rate=0.99, decay_steps=10000)
   opt = optimizers.Adam(learning_rate=rate)
   global_step = tf.Variable(0)
   tfopt = opt._create_tf_optimizer(global_step)
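
The snippet assumes module-level imports of tensorflow as tf and of DeepChem's optimizer wrappers (presumably from deepchem.models import optimizers). For reference, a minimal sketch of what the wrapper is assumed to build under the hood, using only standard Keras APIs; the schedule parameters mirror the ones passed to ExponentialDecay above:

import tensorflow as tf

# Assumed native-Keras equivalent of the DeepChem wrapper: an exponentially
# decaying schedule passed directly to the Keras Adam optimizer.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001, decay_steps=10000, decay_rate=0.99)
tfopt = tf.keras.optimizers.Adam(learning_rate=schedule)
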
Example 2
 def test_exponential_decay_pytorch(self):
   """Test creating an optimizer with an exponentially decaying learning rate."""
   rate = optimizers.ExponentialDecay(
       initial_rate=0.001, decay_rate=0.99, decay_steps=10000)
   opt = optimizers.Adam(learning_rate=rate)
   params = [torch.nn.Parameter(torch.Tensor([1.0]))]
   torchopt = opt._create_pytorch_optimizer(params)
   schedule = rate._create_pytorch_schedule(torchopt)
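
Here torchopt and schedule are ordinary PyTorch objects, so they are driven the usual way in a training loop (optimizer step, then scheduler step). Below is a minimal sketch built directly on PyTorch APIs, under the assumption that the wrapper corresponds to Adam plus a LambdaLR scheduler that scales the base rate by 0.99 ** (step / 10000):

import torch

# Assumed raw-PyTorch equivalent of the wrapper above.
params = [torch.nn.Parameter(torch.Tensor([1.0]))]
torchopt = torch.optim.Adam(params, lr=0.001)
schedule = torch.optim.lr_scheduler.LambdaLR(
    torchopt, lr_lambda=lambda step: 0.99**(step / 10000))

# Typical training-loop usage: backprop, optimizer step, then scheduler step.
loss = (params[0]**2).sum()
loss.backward()
torchopt.step()
schedule.step()
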
Example 3
 def test_exponential_decay_jax(self):
     """Test creating an optimizer with an exponentially decaying learning rate."""
     import optax
     rate = optimizers.ExponentialDecay(initial_rate=0.001,
                                        decay_rate=0.99,
                                        decay_steps=10000)
     opt = optimizers.Adam(learning_rate=rate)
     jaxopt = opt._create_jax_optimizer()
     assert isinstance(jaxopt, optax.GradientTransformation)
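
The assertion confirms that _create_jax_optimizer() yields an optax GradientTransformation. A minimal sketch of the assumed optax equivalent, and of how such a transformation is typically applied (init, update, apply_updates); the parameter and gradient pytrees are placeholders for illustration:

import jax.numpy as jnp
import optax

# Assumed native-optax equivalent of the wrapper above: an exponential-decay
# schedule fed to optax.adam, which returns a GradientTransformation.
schedule = optax.exponential_decay(
    init_value=0.001, transition_steps=10000, decay_rate=0.99)
jaxopt = optax.adam(learning_rate=schedule)

# Standard GradientTransformation usage: initialize state, transform the raw
# gradients into updates, then apply the updates to the parameters.
params = {'w': jnp.array([1.0])}
grads = {'w': jnp.array([0.5])}
opt_state = jaxopt.init(params)
updates, opt_state = jaxopt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
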