    def test_grad_no_grad(self):
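        # The adjoint backward pass should behave identically whether the
        # dynamics are written analytically (CosModel) or produced by an inner
        # GradientTape (CosGradModel), since d/dx sin(t * x) = t * cos(t * x).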
        class CosGradModel(tf.keras.Model):
            def call(self, inputs, **kwargs):
                t, x = inputs
                with tf.GradientTape() as g:
                    g.watch(x)
                    y = tf.sin(t * x)
                dy = g.gradient(y, x)
                return dy

        class CosModel(tf.keras.Model):
            def call(self, inputs, **kwargs):
                t, x = inputs
                y = t * tf.cos(t * x)
                return y

        t_vec = np.linspace(0, 1.0, 40)
        ode = NeuralODE(CosModel(), t=t_vec, solver=neural_ode.rk4_step)
        x0 = tf.to_float([1.0])

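        # Integrate forward, then seed the adjoint with dL/dxN = xN, the
        # gradient of the loss 0.5 * xN ** 2 checked below.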
        cos_xN = ode.forward(x0)
        x0_rec, dLdx0, _ = ode.backward(cos_xN, cos_xN)

        ode = NeuralODE(CosGradModel(), t=t_vec, solver=neural_ode.rk4_step)
        cos_grad_xN = ode.forward(x0)
        x0_grad_rec, dLdx0_grad, _ = ode.backward(cos_grad_xN, cos_grad_xN)

        self.assertAllClose(cos_grad_xN, cos_xN)
        self.assertAllClose(x0_grad_rec, x0_rec)
        self.assertAllClose(x0, x0_rec)
        self.assertAllClose(dLdx0, dLdx0_grad)

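        # Reference gradient: ordinary backprop through the unrolled solver.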
        ode = NeuralODE(CosModel(), t=t_vec, solver=neural_ode.rk4_step)
        with tf.GradientTape() as g:
            g.watch(x0)
            cos_xN = ode.forward(x0)
            loss = 0.5 * cos_xN ** 2

        dLdx0_exact = g.gradient(loss, x0)

        self.assertAllClose(dLdx0_exact, dLdx0_grad)

        # build static graph: compiling with defun_neural_ode should not change the results
        ode = NeuralODE(CosGradModel(), t=t_vec, solver=neural_ode.rk4_step)
        ode = neural_ode.defun_neural_ode(ode)
        cos_grad_xN = ode.forward(x0)
        x0_grad_rec, dLdx0_grad, _ = ode.backward(cos_grad_xN, cos_grad_xN)

        self.assertAllClose(dLdx0_grad, dLdx0_exact)
        self.assertAllClose(x0_grad_rec, x0)

    def test_net_with_inner_gradient(self):
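        # As its name suggests, NNGradientModule presumably differentiates
        # internally; adjoint gradients w.r.t. the state and the weights should
        # nevertheless match ordinary backprop.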
        tf.set_random_seed(1234)
        t_vec = np.linspace(0, 1.0, 20)
        model = NNGradientModule()
        ode = NeuralODE(model, t=t_vec, solver=neural_ode.rk4_step)

        xy0 = tf.random_normal(shape=[12, 2 * 3])
        xyN = ode.forward(xy0)

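        # Gradient of the loss w.r.t. the final state, used to seed the adjoint.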
        with tf.GradientTape() as g:
            g.watch(xyN)
            loss = xyN ** 2
        dLoss = g.gradient(loss, xyN)

        xy0_rec, dLdxy0, dLdW = ode.backward(xyN, dLoss)
        self.assertAllClose(xy0_rec, xy0)

        with tf.GradientTape() as g:
            g.watch(xy0)
            xyN = ode.forward(xy0)
            loss = xyN ** 2

        dLdxy0_exact, *dLdW_exact = g.gradient(loss, [xy0, *model.weights])
        self.assertAllClose(dLdxy0_exact, dLdxy0)
        self.assertAllClose(dLdW_exact, dLdW)

    def test_backward_none(self):
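        # backward() called without an output gradient should still
        # reconstruct the initial state from the final one.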
        tf.set_random_seed(1234)
        t_grid = np.linspace(0, 1.0, 15)

        x0 = tf.random_normal(shape=[7, 3])

        ode = NeuralODE(NNModuleTimeDependent(), t=t_grid)
        x0_rec, *_ = ode.backward(ode.forward(x0))
        self.assertAllClose(x0_rec, x0)

    def test_nn_forward_backward(self):
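        # Adjoint gradients of a plain NN module must match backprop through
        # the unrolled solver, both w.r.t. the initial state and the weights.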
        tf.set_random_seed(1234)
        t_vec = np.linspace(0, 1.0, 20)
        model = NNModule()
        ode = NeuralODE(model, t=t_vec, solver=neural_ode.rk4_step)

        x0 = tf.random_normal(shape=[12, 3])
        xN = ode.forward(x0)
        dLoss = 2 * xN  # explicit gradient of loss = xN ** 2
        x0_rec, dLdx0, dLdW = ode.backward(xN, dLoss)
        self.assertAllClose(x0_rec.numpy(), x0.numpy())

        with tf.GradientTape() as g:
            g.watch(x0)
            xN = ode.forward(x0)
            loss = xN ** 2

        dLdx0_exact, *dLdW_exact = g.gradient(loss, [x0, *model.weights])

        self.assertAllClose(dLdx0_exact, dLdx0)
        self.assertAllClose(dLdW_exact, dLdW)

    def test_backprop(self):
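        # End-to-end adjoint check on a parameter-free model: the returned
        # weight-gradient list should be empty.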
        t_max = 1
        t_grid = np.linspace(0, t_max, 40)

        ode = NeuralODE(SineDumpingModel(), t=t_grid,
                        solver=neural_ode.rk4_step)
        x0 = tf.to_float([0])
        xN = ode.forward(x0)
        with tf.GradientTape() as g:
            g.watch(xN)
            loss = xN ** 2

        dLoss = g.gradient(loss, xN)
        x0_rec, dLdx0, dLdW = ode.backward(xN, dLoss)
        self.assertAllClose(x0_rec.numpy(), x0)
        self.assertEqual(dLdW, [])

        with tf.GradientTape() as g:
            g.watch(x0)
            xN = ode.forward(x0)
            loss = xN ** 2

        dLdx0_exact = g.gradient(loss, x0)
        self.assertAllClose(dLdx0_exact, dLdx0)