Exemplo n.º 1
0
 def test_linearCosine_decay_pytorch(self):
     """Test creating a PyTorch optimizer with a linear cosine decay learning rate.

     Builds an Adam optimizer whose learning rate is a ``LinearCosineDecay``
     schedule (initial rate 0.1, 10000 decay steps). It converts that to a
     PyTorch optimizer over a single one-element parameter, builds the
     matching PyTorch schedule, and checks that both objects were produced
     and are linked.
     """
     rate = optimizers.LinearCosineDecay(initial_rate=0.1, decay_steps=10000)
     opt = optimizers.Adam(learning_rate=rate)
     params = [torch.nn.Parameter(torch.Tensor([1.0]))]
     torchopt = opt._create_pytorch_optimizer(params)
     schedule = rate._create_pytorch_schedule(torchopt)
     # Without these checks the test would only verify that construction does
     # not raise; a wrong or missing return value would go unnoticed.
     assert isinstance(torchopt, torch.optim.Optimizer)
     # NOTE(review): assumes _create_pytorch_schedule returns a
     # torch.optim.lr_scheduler object, which stores the optimizer it drives
     # as `.optimizer` -- confirm against the optimizers module.
     assert schedule.optimizer is torchopt
Exemplo n.º 2
0
 def test_linearCosine_decay_tf(self):
     """Test creating a TensorFlow optimizer with a linear cosine decay learning rate.

     Builds an Adam optimizer whose learning rate is a ``LinearCosineDecay``
     schedule (initial rate 0.1, 10000 decay steps). It converts that to a
     TensorFlow optimizer driven by a fresh ``tf.Variable`` global step, then
     checks that an optimizer object was returned.
     """
     rate = optimizers.LinearCosineDecay(initial_rate=0.1, decay_steps=10000)
     opt = optimizers.Adam(learning_rate=rate)
     global_step = tf.Variable(0)
     tfopt = opt._create_tf_optimizer(global_step)
     # Without an assertion the test would only verify that construction does
     # not raise.
     assert tfopt is not None
     # NOTE(review): duck-typed check; assumes the returned object is a
     # TensorFlow/Keras optimizer exposing `apply_gradients` (true for both
     # legacy and current Keras optimizers) -- confirm against the module.
     assert hasattr(tfopt, "apply_gradients")
Exemplo n.º 3
0
 def test_linearCosine_decay_jax(self):
     """Check that an Adam optimizer with a linear cosine decay rate builds for JAX.

     The learning rate is a ``LinearCosineDecay`` schedule (initial rate 0.1,
     10000 decay steps). The test requires the JAX optimizer built from it to
     be an ``optax.GradientTransformation``.
     """
     import optax
     decay_schedule = optimizers.LinearCosineDecay(
         initial_rate=0.1, decay_steps=10000)
     adam = optimizers.Adam(learning_rate=decay_schedule)
     transformation = adam._create_jax_optimizer()
     assert isinstance(transformation, optax.GradientTransformation)