def test_grad_no_grad(self):
    """Check an ODE whose RHS uses an inner GradientTape against a closed form.

    ``CosGradModel`` obtains ``t * cos(t * x)`` by differentiating
    ``sin(t * x)`` with respect to ``x`` inside its ``call``; ``CosModel``
    writes the same expression explicitly. The test asserts that:

    * forward states, reconstructed initial states and adjoint gradients of
      both models agree;
    * the adjoint gradient matches the gradient obtained by differentiating
      straight through ``ode.forward`` with ``loss = 0.5 * xN ** 2``;
    * the same still holds after wrapping the gradient-based ODE with
      ``neural_ode.defun_neural_ode``.
    """

    class CosGradModel(tf.keras.Model):
        def call(self, inputs, **kwargs):
            t, x = inputs
            with tf.GradientTape() as g:
                g.watch(x)
                y = tf.sin(t * x)
            # d/dx sin(t * x) == t * cos(t * x), i.e. exactly CosModel's output.
            dy = g.gradient(y, x)
            return dy

    class CosModel(tf.keras.Model):
        def call(self, inputs, **kwargs):
            t, x = inputs
            y = t * tf.cos(t * x)
            return y

    t_vec = np.linspace(0, 1.0, 40)

    # Reference run with the closed-form right-hand side.
    ode = NeuralODE(CosModel(), t=t_vec, solver=neural_ode.rk4_step)
    x0 = tf.to_float([1.0])
    cos_xN = ode.forward(x0)
    # Using the final state as dL/dxN corresponds to loss = 0.5 * xN ** 2,
    # which is the loss differentiated directly further below.
    x0_rec, dLdx0, _ = ode.backward(cos_xN, cos_xN)

    # Same ODE, but the RHS is produced through an inner gradient tape.
    ode = NeuralODE(CosGradModel(), t=t_vec, solver=neural_ode.rk4_step)
    cos_grad_xN = ode.forward(x0)
    x0_grad_rec, dLdx0_grad, _ = ode.backward(cos_grad_xN, cos_grad_xN)
    self.assertAllClose(cos_grad_xN, cos_xN)
    self.assertAllClose(x0_grad_rec, x0_rec)
    self.assertAllClose(x0, x0_rec)
    self.assertAllClose(dLdx0, dLdx0_grad)

    # Ground truth: differentiate through the forward integration itself.
    ode = NeuralODE(CosModel(), t=t_vec, solver=neural_ode.rk4_step)
    with tf.GradientTape() as g:
        g.watch(x0)
        cos_xN = ode.forward(x0)
        loss = 0.5 * cos_xN ** 2
    # Differentiate w.r.t. the tensor (not ``[x0]``) so the result has x0's
    # shape, like the adjoint gradients it is compared with; a one-element
    # list source would add a leading length-1 axis.
    dLdx0_exact = g.gradient(loss, x0)
    self.assertAllClose(dLdx0_exact, dLdx0_grad)

    # Build a static graph for the gradient-based ODE and re-check.
    ode = NeuralODE(CosGradModel(), t=t_vec, solver=neural_ode.rk4_step)
    ode = neural_ode.defun_neural_ode(ode)
    cos_grad_xN = ode.forward(x0)
    x0_grad_rec, dLdx0_grad, _ = ode.backward(cos_grad_xN, cos_grad_xN)
    self.assertAllClose(dLdx0_grad, dLdx0_exact)
    self.assertAllClose(x0_grad_rec, x0)
def test_net_with_inner_gradient(self):
    """Adjoint gradients of ``NNGradientModule`` match direct autodiff.

    Integrates a random batch forward with RK4 and seeds the adjoint pass
    with d(xN ** 2)/d(xN). It then checks that the initial state is
    reconstructed and that the input and weight gradients equal those from
    taping the whole forward integration.
    """
    tf.set_random_seed(1234)
    grid = np.linspace(0, 1.0, 20)
    net = NNGradientModule()
    solver_ode = NeuralODE(net, t=grid, solver=neural_ode.rk4_step)

    # NOTE(review): width 2 * 3 presumably means two stacked halves of size 3;
    # confirm against NNGradientModule.
    state_start = tf.random_normal(shape=[12, 2 * 3])
    state_end = solver_ode.forward(state_start)

    # Upstream gradient for the adjoint pass: d(state_end ** 2)/d(state_end).
    with tf.GradientTape() as seed_tape:
        seed_tape.watch(state_end)
        squared_end = state_end ** 2
    seed = seed_tape.gradient(squared_end, state_end)

    start_rec, grad_start, grad_weights = solver_ode.backward(state_end, seed)
    self.assertAllClose(start_rec, state_start)

    # Reference gradients from recording the forward integration on a tape.
    with tf.GradientTape() as ref_tape:
        ref_tape.watch(state_start)
        squared_end_ref = solver_ode.forward(state_start) ** 2
    sources = [state_start] + list(net.weights)
    ref_grad_start, *ref_grad_weights = ref_tape.gradient(squared_end_ref, sources)

    self.assertAllClose(ref_grad_start, grad_start)
    self.assertAllClose(ref_grad_weights, grad_weights)
def test_backward_none(self):
    """``backward`` called with only the final state reconstructs the start.

    NOTE(review): presumably ``backward``'s upstream-gradient argument
    defaults to ``None`` (hence the test name); confirm against
    ``NeuralODE.backward``.
    """
    tf.set_random_seed(1234)
    grid = np.linspace(0, 1.0, 15)
    start = tf.random_normal(shape=[7, 3])
    # No ``solver=`` given: exercises NeuralODE's default solver.
    solver_ode = NeuralODE(NNModuleTimeDependent(), t=grid)
    end = solver_ode.forward(start)
    start_rec, *_unused = solver_ode.backward(end)
    self.assertAllClose(start_rec, start)
def test_nn_forward_backward(self):
    """Adjoint pass of ``NNModule`` with a hand-written upstream gradient.

    Seeds ``backward`` with 2 * xN, the analytic gradient of xN ** 2. It
    then checks that the initial state is recovered and that the input and
    weight gradients match direct differentiation of the forward pass.
    """
    tf.set_random_seed(1234)
    grid = np.linspace(0, 1.0, 20)
    net = NNModule()
    solver_ode = NeuralODE(net, t=grid, solver=neural_ode.rk4_step)
    start = tf.random_normal(shape=[12, 3])
    end = solver_ode.forward(start)

    # Explicit gradient of x ** 2 with respect to x.
    upstream = 2 * end
    start_rec, grad_start, grad_weights = solver_ode.backward(end, upstream)
    self.assertAllClose(start_rec.numpy(), start.numpy())

    # Reference gradients from taping the forward integration.
    with tf.GradientTape() as ref_tape:
        ref_tape.watch(start)
        squared_end_ref = solver_ode.forward(start) ** 2
    sources = [start] + list(net.weights)
    ref_grad_start, *ref_grad_weights = ref_tape.gradient(squared_end_ref, sources)

    self.assertAllClose(ref_grad_start, grad_start)
    self.assertAllClose(ref_grad_weights, grad_weights)
def test_backprop(self):
    """Adjoint gradient for ``SineDumpingModel`` matches direct autodiff.

    Starting from x0 = [0], integrates over [0, 1] with RK4. It checks that
    x0 is reconstructed, that ``backward`` returns an empty weight-gradient
    list, and that d(xN ** 2)/d(x0) equals the taped reference.
    """
    horizon = 1
    grid = np.linspace(0, horizon, 40)
    solver_ode = NeuralODE(SineDumpingModel(), t=grid, solver=neural_ode.rk4_step)
    start = tf.to_float([0])
    end = solver_ode.forward(start)

    # Upstream gradient for the adjoint pass: d(end ** 2)/d(end).
    with tf.GradientTape() as seed_tape:
        seed_tape.watch(end)
        squared_end = end ** 2
    seed = seed_tape.gradient(squared_end, end)

    start_rec, grad_start, grad_weights = solver_ode.backward(end, seed)
    self.assertAllClose(start_rec.numpy(), start)
    # The test expects this model to contribute no weight gradients.
    self.assertEqual(grad_weights, [])

    # Reference gradient from taping the forward integration.
    with tf.GradientTape() as ref_tape:
        ref_tape.watch(start)
        squared_end_ref = solver_ode.forward(start) ** 2
    ref_grad_start = ref_tape.gradient(squared_end_ref, start)
    self.assertAllClose(ref_grad_start, grad_start)