def test_trace_should_keep_track_of_momentum(self):
  """Checks that `optix.trace` (no Nesterov) accumulates a momentum trace.

  With zero params and unit updates, the trace starts at zero, is all ones
  after one `update`, and calling the optimizer again yields 1.99 per element.
  (Presumably trace = 0.99 * trace + update; confirm against `optix.trace`.)
  In this version of the API the call order is `(params, updates)`.
  """
  params = np.zeros(6)
  updates = np.ones(6)
  opt = state.init(optix.trace(0.99, False))(
      random.PRNGKey(0), params, updates)
  # The expected zeros/ones are exactly representable, so these checks stay
  # exact.
  onp.testing.assert_array_equal(opt.trace, np.zeros(6))
  opt = opt.update(params, updates)
  onp.testing.assert_array_equal(opt.trace, np.ones(6))
  # 1.99 is not exactly representable in binary floating point, so the
  # optimizer's computed value may differ in the last bit between backends
  # (e.g. with fused multiply-add); compare with a float32-scale tolerance.
  onp.testing.assert_allclose(
      opt(params, updates), 1.99 * np.ones(6), rtol=1e-6)
Example #2
0
 def test_trace_should_keep_track_of_momentum_with_nesterov(self):
   """Checks that `optix.trace` with Nesterov momentum tracks the trace.

   With unit updates and zero params, the trace starts at zero and is all ones
   after one `update`. Calling the optimizer then yields `1.99 + 0.99**2`
   per element, which equals `1 + 0.99 * 1.99`. Presumably this is the update
   plus the decayed look-ahead trace; confirm against `optix.trace`.
   In this version of the API the call order is `(updates, params)`.
   """
   params = jnp.zeros(6)
   updates = jnp.ones(6)
   opt = state.init(optix.trace(0.99, True))(
       random.PRNGKey(0), updates, params)
   # The expected zeros/ones are exactly representable, so these checks stay
   # exact.
   np.testing.assert_array_equal(opt.trace, jnp.zeros(6))
   opt = opt.update(updates, params)
   np.testing.assert_array_equal(opt.trace, jnp.ones(6))
   # 1.99 + 0.99**2 is not exactly representable in binary floating point, and
   # the expected value is computed in Python floats while the optimizer
   # computes it on-device; allow a float32-scale tolerance instead of
   # requiring bit-for-bit equality.
   np.testing.assert_allclose(
       opt(updates, params), (1.99 + 0.99**2) * jnp.ones(6), rtol=1e-6)