Esempio n. 1
0
 def test_rmsprop(self):
     """Check that the RMSProp wrapper builds a TF1 ``tf.train.RMSPropOptimizer``.

     The optimizer is created inside the test-case session with a fresh
     integer step variable passed to ``_create_optimizer``.
     """
     rmsprop = optimizers.RMSProp(learning_rate=0.01)
     with self.session():
         step_var = tf.Variable(0)
         built = rmsprop._create_optimizer(step_var)
     # The type check is pure Python, so it does not need the session open.
     assert isinstance(built, tf.train.RMSPropOptimizer)
Esempio n. 2
0
 def test_rmsprop_pytorch(self):
     """Check that the RMSProp wrapper builds a ``torch.optim.RMSprop``.

     The optimizer is built over a single one-element trainable parameter.
     """
     weight = torch.nn.Parameter(torch.Tensor([1.0]))
     rmsprop = optimizers.RMSProp(learning_rate=0.01)
     built = rmsprop._create_pytorch_optimizer([weight])
     assert isinstance(built, torch.optim.RMSprop)
Esempio n. 3
0
 def test_rmsprop_tf(self):
     """Check that the RMSProp wrapper builds a Keras ``RMSprop`` optimizer.

     A fresh integer ``tf.Variable`` is handed to ``_create_tf_optimizer``
     as the global step.
     """
     rmsprop = optimizers.RMSProp(learning_rate=0.01)
     built = rmsprop._create_tf_optimizer(tf.Variable(0))
     assert isinstance(built, tf.keras.optimizers.RMSprop)
Esempio n. 4
0
 def test_rmsprop_jax(self):
     """Check that the RMSProp wrapper builds an optax ``GradientTransformation``."""
     # Imported here rather than at module level; presumably so the rest of
     # the test module does not require optax — confirm against the suite setup.
     import optax

     built = optimizers.RMSProp(learning_rate=0.01)._create_jax_optimizer()
     assert isinstance(built, optax.GradientTransformation)