def test_linearCosine_decay_pytorch(self):
    """Smoke-test the PyTorch backend with a linear cosine decay schedule.

    Wraps a ``LinearCosineDecay`` learning rate (initial rate 0.1 over
    10000 decay steps) in an ``Adam`` optimizer, builds the PyTorch
    optimizer for a single one-element parameter, and then builds the
    PyTorch schedule on top of that optimizer. Nothing is asserted about
    the returned objects; the test passes when construction raises no
    exception.
    """
    decay = optimizers.LinearCosineDecay(initial_rate=0.1, decay_steps=10000)
    adam = optimizers.Adam(learning_rate=decay)
    weight = torch.nn.Parameter(torch.Tensor([1.0]))
    torch_optimizer = adam._create_pytorch_optimizer([weight])
    # The schedule is created from the optimizer built just above; its
    # return value is intentionally discarded.
    decay._create_pytorch_schedule(torch_optimizer)
def test_linearCosine_decay_tf(self):
    """Smoke-test the TensorFlow backend with a linear cosine decay schedule.

    Wraps a ``LinearCosineDecay`` learning rate (initial rate 0.1 over
    10000 decay steps) in an ``Adam`` optimizer and asks it for a
    TensorFlow optimizer driven by a fresh global-step variable. Nothing
    is asserted about the returned object; the test passes when
    construction raises no exception.
    """
    adam = optimizers.Adam(
        learning_rate=optimizers.LinearCosineDecay(
            initial_rate=0.1,
            decay_steps=10000,
        )
    )
    step_counter = tf.Variable(0)
    # The return value is intentionally discarded.
    adam._create_tf_optimizer(step_counter)
def test_linearCosine_decay_jax(self):
    """Check the JAX backend with a linear cosine decay schedule.

    Wraps a ``LinearCosineDecay`` learning rate (initial rate 0.1 over
    10000 decay steps) in an ``Adam`` optimizer and verifies that the JAX
    optimizer it creates is an ``optax.GradientTransformation``.
    """
    # Imported locally, as in the original, so optax is only needed when
    # this test actually runs.
    import optax

    decay = optimizers.LinearCosineDecay(initial_rate=0.1, decay_steps=10000)
    adam = optimizers.Adam(learning_rate=decay)
    assert isinstance(
        adam._create_jax_optimizer(),
        optax.GradientTransformation,
    )