def test_gradient_descent(self):
    """Test creating a Gradient Descent optimizer (legacy TF1 API)."""
    opt = optimizers.GradientDescent(learning_rate=0.01)
    with self.session():
        global_step = tf.Variable(0)
        tfopt = opt._create_optimizer(global_step)
        assert isinstance(tfopt, tf.train.GradientDescentOptimizer)
def test_gradient_descent_pytorch(self):
    """Test creating a Gradient Descent optimizer for PyTorch."""
    opt = optimizers.GradientDescent(learning_rate=0.01)
    params = [torch.nn.Parameter(torch.Tensor([1.0]))]
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.SGD)
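# Hedged follow-up sketch (not in the original suite): check that the optimizer
# returned by _create_pytorch_optimizer can actually take a descent step. The
# test name and the quadratic toy loss are hypothetical; this assumes the
# returned torch.optim.SGD is bound to the parameters passed in.
def test_gradient_descent_pytorch_step(self):
    """Sketch: one SGD step on x^2 should move x=1.0 toward 0."""
    opt = optimizers.GradientDescent(learning_rate=0.01)
    param = torch.nn.Parameter(torch.Tensor([1.0]))
    torchopt = opt._create_pytorch_optimizer([param])
    loss = (param ** 2).sum()  # simple quadratic loss, minimized at 0
    loss.backward()
    torchopt.step()
    # x <- x - lr * 2x = 1.0 - 0.01 * 2.0 = 0.98
    assert param.item() < 1.0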
def test_gradient_descent_tf(self):
    """Test creating a Gradient Descent optimizer for TensorFlow."""
    opt = optimizers.GradientDescent(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    assert isinstance(tfopt, tf.keras.optimizers.SGD)
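# Hedged follow-up sketch (not in the original suite): apply one update with
# the Keras SGD optimizer created above. The test name and toy loss are
# hypothetical; only the standard tf.keras.optimizers.SGD API is used.
def test_gradient_descent_tf_step(self):
    """Sketch: one apply_gradients call on x^2 should move x=1.0 toward 0."""
    opt = optimizers.GradientDescent(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    var = tf.Variable(1.0)
    with tf.GradientTape() as tape:
        loss = var * var  # simple quadratic loss, minimized at 0
    grads = tape.gradient(loss, [var])
    tfopt.apply_gradients(zip(grads, [var]))
    # x <- x - lr * 2x = 1.0 - 0.01 * 2.0 = 0.98
    assert float(var.numpy()) < 1.0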
def test_gradient_descent_jax(self):
    """Test creating a Gradient Descent optimizer for JAX."""
    import optax
    opt = optimizers.GradientDescent(learning_rate=0.01)
    jaxopt = opt._create_jax_optimizer()
    assert isinstance(jaxopt, optax.GradientTransformation)
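# Hedged follow-up sketch (not in the original suite): drive the returned optax
# GradientTransformation through its init/update cycle. The test name and toy
# loss are hypothetical; only the standard optax and jax.grad APIs are used.
def test_gradient_descent_jax_step(self):
    """Sketch: one optax update on x^2 should move x=1.0 toward 0."""
    import jax
    import jax.numpy as jnp
    import optax
    opt = optimizers.GradientDescent(learning_rate=0.01)
    jaxopt = opt._create_jax_optimizer()
    params = jnp.array([1.0])
    state = jaxopt.init(params)
    grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)  # grad of x^2 is 2x
    updates, state = jaxopt.update(grads, state)
    params = optax.apply_updates(params, updates)
    # x <- x - lr * 2x = 1.0 - 0.01 * 2.0 = 0.98
    assert params[0] < 1.0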