Code example #1
0
 def test_gradient_descent(self):
     """Test creating a Gradient Descent optimizer (TF1-style API).

     Verifies that the framework-agnostic ``optimizers.GradientDescent``
     wrapper produces a ``tf.train.GradientDescentOptimizer``.
     """
     opt = optimizers.GradientDescent(learning_rate=0.01)
     # The session context is required for TF variable/optimizer creation;
     # the session object itself is never used, so it is not bound to a name.
     with self.session():
         global_step = tf.Variable(0)
         tfopt = opt._create_optimizer(global_step)
         assert isinstance(tfopt, tf.train.GradientDescentOptimizer)
Code example #2
0
 def test_gradient_descent_pytorch(self):
     """Test creating a Gradient Descent optimizer for PyTorch.

     Checks that the framework-agnostic wrapper maps to ``torch.optim.SGD``.
     """
     optimizer = optimizers.GradientDescent(learning_rate=0.01)
     model_params = [torch.nn.Parameter(torch.Tensor([1.0]))]
     created = optimizer._create_pytorch_optimizer(model_params)
     assert isinstance(created, torch.optim.SGD)
Code example #3
0
 def test_gradient_descent_tf(self):
     """Test creating a Gradient Descent optimizer for TensorFlow (Keras API).

     Checks that the framework-agnostic wrapper maps to
     ``tf.keras.optimizers.SGD``.
     """
     optimizer = optimizers.GradientDescent(learning_rate=0.01)
     step = tf.Variable(0)
     created = optimizer._create_tf_optimizer(step)
     assert isinstance(created, tf.keras.optimizers.SGD)
Code example #4
0
 def test_gradient_descent_jax(self):
     """Test creating a Gradient Descent optimizer for JAX.

     Checks that the framework-agnostic wrapper produces an optax
     ``GradientTransformation``.
     """
     # Local import: optax is an optional JAX dependency, only needed here.
     import optax
     opt = optimizers.GradientDescent(learning_rate=0.01)
     jaxopt = opt._create_jax_optimizer()
     assert isinstance(jaxopt, optax.GradientTransformation)