def test_init(self):
    """Check the state produced by ``ADADP.init`` for the template params.

    Expects step counter 0, the given params and learning rate passed
    through unchanged, a zero-filled stepped slot, and a previous-params
    slot with the same tree structure as the template.
    """
    lr_init = 1.
    optimizer = ADADP(lr_init, 1.)
    params = self.template

    step, state = optimizer.init(params)
    params_out, lr_out, stepped_out, prev_out = state

    zeros = self.same_tree_with_value(self.template, 0.)
    self.assertEqual(0, step)
    self.assertTreeAllClose(params, params_out)
    self.assertEqual(lr_init, lr_out)
    self.assertTreeAllClose(zeros, stepped_out)
    # Only the structure of the previous-params slot is checked, not values.
    self.assertTreeStructure(self.template, prev_out)
def test_update_step_2_with_stability_check(self):
    """Second update (from step 1) with the stability check enabled.

    With gradient 3 and tol=5 the update is expected to be rejected: the
    returned params equal the previous params, and the learning rate is
    expected to be 0.9.
    """
    base_lr = 1.
    optimizer = ADADP(base_lr, tol=5., stability_check=True)
    prev_params = self.same_tree_with_value(self.template, 0.)
    grad = self.same_tree_with_value(self.template, 3.)

    # State after a first step: current params, lr, full-step params, previous params.
    half_stepped = self.same_tree_with_value(prev_params, -0.5)
    full_stepped = self.same_tree_with_value(prev_params, -1.)
    state_in = (1, (half_stepped, base_lr, full_stepped, prev_params))

    step, (params_out, lr_out, _stepped, _prev) = optimizer.update(
        grad, state_in)

    # Unclipped value would be 0.72005267; alpha_min clips it to .9.
    expected_lr = .9
    self.assertEqual(2, step)
    self.assertTreeAllClose(prev_params, params_out)  # update rejected
    self.assertTrue(jnp.allclose(expected_lr, lr_out))
def test_update_step_1(self):
    """First update (from step 0) with unit gradient and learning rate 1.

    Expects the current params to hold the half-step result (-0.5), the
    stepped slot to hold the full-step result (-1.), the learning rate to
    be unchanged, and the starting params to be kept as the previous value.
    """
    base_lr = 1.
    optimizer = ADADP(base_lr, 1.)
    start = self.same_tree_with_value(self.template, 0.)
    grad = self.same_tree_with_value(self.template, 1.)

    step, state = optimizer.update(grad, (0, (start, base_lr, start, start)))
    params_out, lr_out, stepped_out, prev_out = state

    expected_full = self.same_tree_with_value(self.template, -1.)
    expected_half = self.same_tree_with_value(self.template, -0.5)

    self.assertEqual(1, step)
    self.assertTreeAllClose(expected_half, params_out)
    self.assertEqual(base_lr, lr_out)
    self.assertTreeAllClose(expected_full, stepped_out)
    self.assertTreeAllClose(start, prev_out)
def test_update_step_2_no_stability_check(self):
    """Second update (from step 1) with the stability check disabled.

    Starting from params at -0.5 with gradient 2 and learning rate 1, the
    expected params after two half steps are -1.5, and the learning rate
    is expected to become 1.018308251.
    """
    base_lr = 1.
    optimizer = ADADP(base_lr, tol=5., stability_check=False)
    prev_params = self.same_tree_with_value(self.template, 0.)
    grad = self.same_tree_with_value(self.template, 2.)

    # State after a first step: current params, lr, full-step params, previous params.
    state_in = (
        1,
        (
            self.same_tree_with_value(prev_params, -0.5),
            base_lr,
            self.same_tree_with_value(prev_params, -1.),
            prev_params,
        ),
    )
    step, (params_out, lr_out, _stepped, _prev) = optimizer.update(
        grad, state_in)

    expected_params = self.same_tree_with_value(self.template, -1.5)
    expected_lr = 1.018308251
    self.assertEqual(2, step)
    self.assertTreeAllClose(expected_params, params_out)
    self.assertTrue(jnp.allclose(expected_lr, lr_out))