def test_adagrad_tf(self):
    """Test creating an AdaGrad optimizer with the TensorFlow backend."""
    opt = optimizers.AdaGrad(learning_rate=0.01)
    global_step = tf.Variable(0)
    tfopt = opt._create_tf_optimizer(global_step)
    assert isinstance(tfopt, tf.keras.optimizers.Adagrad)
def test_adagrad_pytorch(self):
    """Test creating an AdaGrad optimizer with the PyTorch backend."""
    opt = optimizers.AdaGrad(learning_rate=0.01)
    params = [torch.nn.Parameter(torch.Tensor([1.0]))]
    torchopt = opt._create_pytorch_optimizer(params)
    assert isinstance(torchopt, torch.optim.Adagrad)
def test_adagrad_jax(self):
    """Test creating an AdaGrad optimizer with the JAX backend."""
    import optax
    opt = optimizers.AdaGrad(learning_rate=0.01)
    jaxopt = opt._create_jax_optimizer()
    assert isinstance(jaxopt, optax.GradientTransformation)
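
# Not part of the original tests: a minimal sketch of how the object returned by
# _create_jax_optimizer() is applied, to show why the JAX test above only checks
# its type. In optax an optimizer is an (init, update) pair rather than a stateful
# object. The helper name and the parameter/gradient values below are illustrative
# assumptions, not taken from the test suite.
def _example_jax_adagrad_step():
    import jax.numpy as jnp
    import optax
    tx = optimizers.AdaGrad(learning_rate=0.01)._create_jax_optimizer()
    params = {"w": jnp.array([1.0])}
    grads = {"w": jnp.array([0.5])}
    state = tx.init(params)  # per-parameter AdaGrad accumulator state
    updates, state = tx.update(grads, state, params)
    return optax.apply_updates(params, updates)  # params after one gradient step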