def test_adjoint(self):
    """Check that adjoint gradients agree with direct dopri5 gradients.

    The same problem is constructed twice: once differentiated straight
    through ``tfdiffeq.odeint`` and once through ``tfdiffeq.odeint_adjoint``
    (with the initial state wrapped in a 1-tuple).  The gradients with
    respect to the time points and the parameters ``a`` and ``b`` of the ODE
    function must match to within 1.2e-7 in max-abs difference.
    """
    tf.compat.v1.set_random_seed(0)

    # Reference pass: differentiate directly through the dopri5 solver.
    ode_fn, state0, times, _ = problems.construct_problem(TEST_DEVICE)
    state0 = tf.cast(state0, tf.float64)
    times = tf.cast(times, tf.float64)
    with tf.GradientTape() as direct_tape:
        # ``times`` is a plain tensor after the cast, so it must be watched
        # explicitly for a gradient to be produced.
        direct_tape.watch(times)
        direct_ys = tfdiffeq.odeint(ode_fn, state0, times, method='dopri5')
    direct_grads = direct_tape.gradient(
        direct_ys, [times, ode_fn.a, ode_fn.b])

    # Adjoint pass: rebuild the problem and feed the state as a 1-tuple.
    ode_fn, state0, times, _ = problems.construct_problem(TEST_DEVICE)
    state0 = (tf.cast(state0, tf.float64),)
    times = tf.cast(times, tf.float64)
    with tf.GradientTape() as adjoint_tape:
        adjoint_tape.watch(times)
        adjoint_ys = tfdiffeq.odeint_adjoint(
            ode_fn, state0, times, method='dopri5')
    adjoint_grads = adjoint_tape.gradient(
        adjoint_ys, [times, ode_fn.a, ode_fn.b])

    # Compare in the order (t_points, a, b), exactly as the sources above.
    for direct_grad, adjoint_grad in zip(direct_grads, adjoint_grads):
        self.assertLess(max_abs(direct_grad - adjoint_grad), 1.2e-7)
def test_dopri5_adjoint_against_dopri5(self):
    """Compare ``odeint_adjoint`` gradients with plain ``odeint`` gradients.

    Both solvers run dopri5 on a problem from ``self.problem()``.  The same
    random upstream gradient is fed to both backward passes, and the
    gradients with respect to ``y0``, ``t_points`` and ``func.A`` must agree
    within per-quantity tolerances.  The variables of ``func.unused_module``
    must receive no gradient from the adjoint pass.
    """
    tf.keras.backend.set_floatx('float64')
    tf.compat.v1.set_random_seed(0)

    # Adjoint pass.  The tape is persistent because gradient() is called
    # twice on it below.
    with tf.GradientTape(persistent=True) as tape:
        func, y0, t_points = self.problem()
        tape.watch(t_points)
        tape.watch(y0)
        adjoint_ys = tfdiffeq.odeint_adjoint(
            func, y0, t_points, method='dopri5')

    # Random upstream gradient, reused for the direct pass further down.
    gradys = 0.1 * tf.random.uniform(shape=adjoint_ys.shape, dtype=tf.float64)
    adjoint_grads = tape.gradient(
        adjoint_ys, [y0, t_points, func.A], output_gradients=gradys)

    unused_w_grad, unused_b_grad = tape.gradient(
        adjoint_ys, func.unused_module.variables)
    self.assertIsNone(unused_w_grad)
    self.assertIsNone(unused_b_grad)

    # Direct pass on a freshly built problem.
    with tf.GradientTape() as tape:
        func, y0, t_points = self.problem()
        tape.watch(y0)
        tape.watch(t_points)
        direct_ys = tfdiffeq.odeint(func, y0, t_points, method='dopri5')
    direct_grads = tape.gradient(
        direct_ys, [y0, t_points, func.A], output_gradients=gradys)

    # Tolerances are listed in source order: y0, t_points, A.
    tolerances = (3e-4, 1e-4, 2e-3)
    for direct_grad, adjoint_grad, tol in zip(
            direct_grads, adjoint_grads, tolerances):
        self.assertLess(max_abs(direct_grad - adjoint_grad), tol)
def test_adjoint(self):
    """Run ``odeint_adjoint`` (dopri5) on every problem in ``problems.PROBLEMS``.

    Each problem is constructed with ``reverse=True``, its initial state,
    time points and reference solution are cast to float64, and the relative
    error between the reference and the computed solution must stay below
    ``error_tol``.

    Every problem runs in its own ``subTest`` so that a failure or an
    exception for one problem is reported under that problem's name and does
    not prevent the remaining problems from being checked.
    """
    for ode in problems.PROBLEMS:
        with self.subTest(ode=ode):
            # The problem name has to be forwarded; otherwise every
            # iteration would build the same default problem and the
            # subTest labels would be misleading.
            # NOTE(review): assumes construct_problem takes the problem
            # name through an ``ode`` keyword -- confirm against problems.py.
            f, y0, t_points, sol = problems.construct_problem(
                TEST_DEVICE, ode=ode, reverse=True)
            y0 = tf.cast(y0, tf.float64)
            t_points = tf.cast(t_points, tf.float64)
            sol = tf.cast(sol, tf.float64)

            y = tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            self.assertLess(rel_error(sol, y), error_tol)
T = tf.constant(2., dtype=dtype) t = tf.cast(tf.linspace(0., T, 2), dtype) odemodel = ODE(a, b, dtype) for rtol in np.logspace(-13, 0, 14)[::-1]: print('rtol:', rtol) # Run forward and backward passes, while tracking the time with tf.device('/gpu:0'): t0 = time.time() with tf.GradientTape() as g: y_sol = odeint(odemodel, x_0, t, rtol=rtol, atol=1e-10)[-1] t1 = time.time() dYdX_backprop = g.gradient(y_sol, odemodel.b).numpy() t2 = time.time() with tf.GradientTape() as g: y_sol_adj = odeint_adjoint(odemodel, x_0, t, rtol=rtol, atol=1e-10)[-1] t3 = time.time() dYdX_adjoint = g.gradient(y_sol_adj, odemodel.b).numpy() t4 = time.time() dYdX_exact = exact_derivative(a, b, T).numpy() rel_err_adj = abs(dYdX_adjoint-dYdX_exact)/dYdX_exact rel_err_bp = abs(dYdX_backprop-dYdX_exact)/dYdX_exact print('Adjoint:', rel_err_adj, dtype) print('Backprop:', rel_err_bp, dtype) fd = open(file_path, 'a') fd.write('{},{},adjoint,{},{},{},{},{}\n'.format(dtype, rtol, rel_err_adj, t3-t2, t4-t3, odemodel.nfe.numpy(),
# Compute the reference gradient x_0_64 = tf.random.uniform( [16, 14, 14, 8], dtype=tf.float64) #tf.constant([[1., 10.]], dtype=tf.float64) t = tf.cast(tf.linspace(0., 2., 2), tf.float64) odemodel_exact = ODE(dtype=tf.float64) with tf.device('/gpu:0'): with tf.GradientTape() as g: y_sol = odeint(odemodel_exact, x_0_64, t, rtol=1e-15, atol=1e-15)[-1] dYdX_exact = g.gradient(y_sol, odemodel_exact.trainable_variables) with tf.GradientTape() as g: y_sol = odeint_adjoint(odemodel_exact, x_0_64, t, rtol=1e-15, atol=1e-15)[-1] dYdX_exact_adj = g.gradient(y_sol, odemodel_exact.trainable_variables) for x, x_ex in zip(dYdX_exact_adj, dYdX_exact): print((tf.norm(x - x_ex) / tf.norm(x_ex)).numpy()) print(odemodel_exact.summary()) for dtype in dtypes: if dtype == tf.float32: tf.keras.backend.set_floatx('float32') else: tf.keras.backend.set_floatx('float64') x_0 = tf.cast(x_0_64, dtype) t = tf.cast(tf.linspace(0., 2., 2), dtype)