Example #1
0
 def test_adamw_tf(self):
     """The AdamW wrapper should build a TensorFlow Addons AdamW optimizer."""
     optimizer = optimizers.AdamW(learning_rate=0.01)
     step = tf.Variable(0)
     created = optimizer._create_tf_optimizer(step)
     # The backend object must be the tfa implementation of AdamW.
     assert isinstance(created, tfa.optimizers.AdamW)
Example #2
0
 def test_adamw_pytorch(self):
     """The AdamW wrapper should build a torch.optim.AdamW optimizer."""
     optimizer = optimizers.AdamW(learning_rate=0.01)
     model_params = [torch.nn.Parameter(torch.Tensor([1.0]))]
     created = optimizer._create_pytorch_optimizer(model_params)
     # The backend object must be PyTorch's native AdamW implementation.
     assert isinstance(created, torch.optim.AdamW)
Example #3
0
 def test_adamw_jax(self):
     """The AdamW wrapper should build an optax gradient transformation."""
     import optax
     optimizer = optimizers.AdamW(learning_rate=0.01)
     created = optimizer._create_jax_optimizer()
     # Optax optimizers are exposed as GradientTransformation objects.
     assert isinstance(created, optax.GradientTransformation)