Beispiel #1
0
def g_and_explicit_phi(prev_t, next_t, implicit_phi, k):
    """Build the step coefficients ``g`` and the explicit ``phi`` terms.

    Args:
        prev_t: indexable history of earlier time points; ``prev_t[0]`` is
            treated as the current time.
        next_t: time point the step is heading to.
        implicit_phi: indexable collection whose ``j``-th entry is a tuple of
            tensors.
        k: number of ``phi`` terms to produce (presumably the current order
            of the method -- confirm against the caller).

    Returns:
        A pair ``(g, explicit_phi)``. ``g`` is a non-trainable
        ``tf.Variable`` with ``k + 1`` entries and ``explicit_phi`` is a
        ``collections.deque`` (``maxlen=k``) of tuples of tensors.
    """
    newest_t = prev_t[0]
    target_device = newest_t.device
    step = next_t - newest_t

    # Allocate the coefficient buffer on the same device as the time points.
    with tf.device(target_device):
        g = tf.Variable(tf.zeros([k + 1]), trainable=False)

    explicit_phi = collections.deque(maxlen=k)
    scale = move_to_device(tf.convert_to_tensor(1.), target_device)

    compat.assign(g[0], 1)

    # Starts as 1/1, 1/2, ..., 1/(k + 1); every update below drops one entry.
    coeffs = 1 / move_to_device(tf.range(1, k + 2), target_device)
    explicit_phi.append(implicit_phi[0])

    scale = tf.cast(scale, next_t.dtype)
    for j in range(1, k):
        scale = (next_t - prev_t[j - 1]) / (newest_t - prev_t[j]) * scale
        scale_cast = tf.cast(
            move_to_device(scale, implicit_phi[j][0].device),
            implicit_phi[0][0].dtype)
        explicit_phi.append(
            tuple(term * scale_cast for term in implicit_phi[j]))

        head, tail = coeffs[:-1], coeffs[1:]
        if j == 1:
            coeffs = head - tail
        else:
            coeffs = head - tail * step / (next_t - prev_t[j - 1])
        compat.assign(g[j], tf.cast(coeffs[0], g[j].dtype))

    coeffs = coeffs[:-1] - coeffs[1:] * step / (next_t - prev_t[k - 1])
    compat.assign(g[k], tf.cast(coeffs[0], g[k].dtype))

    return g, explicit_phi
Beispiel #2
0
def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate the interpolating polynomial at a single time point.

    Args:
        coefficients: list of Tensor coefficients as created by `interp_fit`.
        t0: start of the interval; converted to the dtype and device of the
            first coefficient.
        t1: end of the interval; converted the same way.
        t: desired interpolation point; converted the same way.

    Returns:
        Tuple holding the interpolated value for each coefficient column
        (one entry per element of ``zip(*coefficients)``).
    """
    reference = coefficients[0][0]
    dtype, device = reference.dtype, reference.device

    t0, t1, t = (_convert_to_tensor(value, dtype=dtype, device=device)
                 for value in (t0, t1, t))

    assert (t0 <= t) & (t <= t1), (
        'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(
            t0, t, t1))

    # Position of `t` inside the interval, rescaled to [0, 1].
    x = move_to_device(tf.cast((t - t0) / (t1 - t0), dtype), device)

    # Ascending powers of x: 1, x, x**2, ... (at least the first two).
    powers = [move_to_device(tf.convert_to_tensor(1, dtype=dtype), device), x]
    while len(powers) < len(coefficients):
        powers.append(powers[-1] * x)

    # Coefficients are paired with the powers in descending order.
    return tuple(
        _dot_product(column, reversed(powers))
        for column in zip(*coefficients))
Beispiel #3
0
 def _linear_interp(self, t0, t1, y0, y1, t):
     """Linearly interpolate the state between ``(t0, y0)`` and ``(t1, y1)``.

     ``y0`` and ``y1`` are sequences of tensors that are interpolated
     element by element; the result is a tuple unless ``t`` coincides with
     one of the end points, in which case that end state is returned as is.
     """
     # Exact hits on an end point hand back the stored state untouched.
     if t == t0:
         return y0
     if t == t1:
         return y1
     device = y0[0].device
     t0, t1, t = (move_to_device(value, device) for value in (t0, t1, t))
     span = t1 - t0
     offset = t - t0
     return tuple(
         start + (stop - start) / span * offset
         for start, stop in zip(y0, y1))
Beispiel #4
0
    def before_integrate(self, t):
        """Prepare ``self.vcabm_state`` before the first integration step.

        Evaluates ``self.func`` at the first time point, seeds the bounded
        history deques with that evaluation, picks an initial step size via
        ``_select_initial_step`` and stores the resulting order-1 state.

        Args:
            t: tensor of output times; only ``t[0]``, ``t.device`` and the
                dtype of ``t[0]`` are used here.
        """
        history_len = self.max_order + 1
        prev_f = collections.deque(maxlen=history_len)
        prev_t = collections.deque(maxlen=history_len)
        phi = collections.deque(maxlen=self.max_order)

        t0 = t[0]
        f0 = self.func(tf.cast(t0, self.y0[0].dtype), self.y0)
        for history, item in ((prev_t, t0), (prev_f, f0), (phi, f0)):
            history.appendleft(item)

        first_step = _select_initial_step(
            self.func, t0, self.y0, 2, self.rtol[0], self.atol[0], f0=f0)
        # Bring the step size onto the device and dtype of the time grid.
        first_step = tf.cast(move_to_device(first_step, t.device), t0.dtype)

        self.vcabm_state = _VCABMState(
            self.y0, prev_f, prev_t, next_t=t0 + first_step, phi=phi, order=1)
Beispiel #5
0
        def augmented_dynamics(t, y_aug):
            """Right-hand side of the augmented (state + adjoint) system.

            Dynamics of the original system augmented with the adjoint wrt
            ``y``, and an integrator wrt ``t`` and the parameters.

            Args:
                t: time tensor at which the dynamics are evaluated.
                y_aug: tuple whose first ``n_tensors`` entries are the state
                    and whose next ``n_tensors`` entries are the adjoint of
                    the state; the remaining entries (adjoint of time and of
                    the parameters) are ignored.

            Returns:
                Tuple ``(*func_eval, *vjp_y, vjp_t, vjp_params)``.
            """
            y = y_aug[:n_tensors]
            adj_y = y_aug[n_tensors:2 * n_tensors]

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = tf.convert_to_tensor(func_eval)

            # Vector-Jacobian products are taken against the negated adjoint.
            gradys = -tf.stack(adj_y)
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval, (t, ) + y + f_params,
                output_gradients=gradys,
                unconnected_gradients=tf.UnconnectedGradients.ZERO)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = _flatten(vjp_y_and_params[n_tensors:])

            if _check_len(f_params) == 0:
                # No parameters: use a scalar zero placeholder. The attribute
                # is `dtype`; the previous `dype` spelling raised
                # AttributeError on this path.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)
Beispiel #6
0
        def augmented_dynamics(t, y_aug):
            """Right-hand side of the augmented (state + adjoint) system.

            Dynamics of the original system augmented with the adjoint wrt
            ``y``, and an integrator wrt ``t`` and the parameters.

            Args:
                t: time tensor at which the dynamics are evaluated.
                y_aug: tuple whose first ``n_tensors`` entries are the state
                    and whose next ``n_tensors`` entries are the adjoint of
                    the state; the remaining entries (adjoint of time and of
                    the parameters) are ignored.

            Returns:
                Tuple ``(*func_eval, *vjp_y, vjp_t, vjp_params)``.
            """
            y = y_aug[:n_tensors]
            adj_y = y_aug[n_tensors:2 * n_tensors]

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # Vector-Jacobian products are taken against the negated adjoint.
            gradys = tf.stack([-adj_y_ for adj_y_ in adj_y])
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(func_eval,
                                                     (t, ) + y + f_params,
                                                     output_gradients=gradys)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient yields None for unconnected sources; use zeros.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: use a scalar zero placeholder. The attribute
                # is `dtype`; the previous `dype` spelling raised
                # AttributeError on this path.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)
 def before_integrate(self, t):
     """Initialise ``self.rk_state`` before the first integration step.

     The initial step size is ``self.first_step`` when one was supplied,
     otherwise it is chosen by ``_select_initial_step``; either way it ends
     up with the dtype and device of ``t``.
     """
     t0 = t[0]
     f0 = self.func(tf.cast(t0, self.y0[0].dtype), self.y0)
     if self.first_step is not None:
         first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
     else:
         guess = _select_initial_step(self.func, t0, self.y0, 1, self.rtol[0], self.atol[0], f0=f0)
         first_step = move_to_device(tf.cast(guess, t.dtype), t.device)
     # Both time slots of the state start at t0; the interpolation
     # coefficients start as five references to the initial state.
     self.rk_state = _RungeKuttaState(self.y0, f0, t0, t0, first_step, interp_coeff=[self.y0] * 5)
Beispiel #8
0
 def integrate(self, t):
     """Solve the system and return the state at every time in ``t``.

     ``t`` is checked with ``_assert_increasing``, cast to float64 and moved
     to the device of the initial state. The result is a tuple with one
     stacked tensor per state component; its first entry along the stacking
     axis is the initial state ``self.y0``.
     """
     _assert_increasing(t)
     solution = [self.y0]
     t = move_to_device(tf.cast(t, tf.float64), self.y0[0].device)
     self.before_integrate(t)
     for index in range(1, t.shape[0]):
         solution.append(self.advance(t[index]))
     # Transpose [time][component] into one stacked tensor per component.
     return tuple(tf.stack(component) for component in zip(*solution))
Beispiel #9
0
        def _grid_constructor(func, y0, t):
            """Build an evenly spaced time grid that covers ``t``.

            The grid starts at ``t[0]`` and advances by the enclosing
            scope's ``step_size``; a last point that overshoots ``t[-1]``
            is replaced by ``t[-1]``.

            Args:
                func: unused; part of the grid-constructor call signature.
                y0: unused; part of the grid-constructor call signature.
                t: tensor of requested output times.

            Returns:
                1-D tensor of grid times with the dtype of ``t``.
            """
            start_time = t[0]
            end_time = t[-1]

            # `tf.ceil` / `.item()` are not TensorFlow 2 APIs; use
            # `tf.math.ceil` and read the eager value through `.numpy()`.
            niters = int(
                tf.math.ceil((end_time - start_time) / step_size + 1).numpy())

            # Match the dtype and device of `t` so the arithmetic below and
            # the later comparison against `t` operate on like tensors.
            steps = move_to_device(
                tf.cast(tf.range(0, niters), t.dtype), t.device)
            t_infer = steps * step_size + start_time

            if t_infer[-1] > t[-1]:
                # Tensors are immutable, so rebuild the grid with the final
                # point clamped to t[-1] rather than assigning to an index.
                t_infer = tf.concat(
                    [t_infer[:-1], tf.cast(t[-1:], t_infer.dtype)], axis=0)

            return t_infer
Beispiel #10
0
    def before_integrate(self, t):
        """Initialise ``self.rk_state`` before the first integration step.

        The initial step size is ``self.first_step`` when one was supplied,
        otherwise it is chosen by ``_select_initial_step``; either way it
        ends up with the dtype and device of ``t``.
        """
        if self.first_step is not None:
            first_step = _convert_to_tensor(
                self.first_step, dtype=t.dtype, device=t.device)
        else:
            guess = _select_initial_step(
                self.func, t[0], self.y0, 4, self.rtol, self.atol)
            first_step = move_to_device(tf.cast(guess, t.dtype), t.device)

        t0 = t[0]
        f0 = self.func(t0, self.y0)
        # Both time slots of the state start at t0; the interpolation
        # coefficients start as seven references to the initial state.
        self.rk_state = _RungeKuttaState(
            self.y0, f0, t0, t0, first_step, [self.y0] * 7)
Beispiel #11
0
        def augmented_dynamics(t, y_aug):
            """Right-hand side of the augmented (state + adjoint) system.

            Dynamics of the original system augmented with the adjoint wrt
            ``y``, and an integrator wrt ``t`` and the parameters.

            Args:
                t: time tensor at which the dynamics are evaluated.
                y_aug: tuple whose first ``n_tensors`` entries are the state
                    and whose next ``n_tensors`` entries are the adjoint of
                    the state; the remaining entries (adjoint of time and of
                    the parameters) are ignored.

            Returns:
                Tuple ``(*func_eval, *vjp_y, vjp_t, vjp_params)``.
            """
            y, adj_y = y_aug[:n_tensors], y_aug[
                n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # NOTE(review): `adj_y` is unpacked above but never passed as
            # `output_gradients`, so these are plain gradients of
            # `func_eval` and not products with the negated adjoint --
            # confirm whether that is intended.
            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval,
                (t, ) + y + f_params,
            )

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient yields None for unconnected sources; use zeros.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: use a scalar zero placeholder. The attribute
                # is `dtype`; the previous `dype` spelling raised
                # AttributeError on this path.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)
Beispiel #12
0
    def integrate(self, t):
        """Integrate on a fixed grid and report the state at each time in ``t``.

        The grid comes from ``self.grid_constructor`` and must start at
        ``t[0]`` and end at ``t[-1]``. The state is advanced one grid cell
        at a time with ``self.step_func``; requested times that fall inside
        a cell are obtained by linear interpolation between the cell's two
        end states.

        Args:
            t: increasing tensor of output times.

        Returns:
            Tuple with one stacked tensor per state component; the first
            entry along the stacking axis is the initial state.
        """
        _assert_increasing(t)
        t = tf.cast(t, self.y0[0].dtype)
        time_grid = self.grid_constructor(self.func, self.y0, t)
        assert tf.equal(time_grid[0], t[0]) and tf.equal(time_grid[-1], t[-1])
        time_grid = move_to_device(time_grid, self.y0[0].device)

        solution = [cast_double(self.y0)]

        j = 1  # index of the next requested output time
        y0 = cast_double(self.y0)
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dy = self.step_func(self.func, t0, t1 - t0, y0)
            y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))

            # Emit every requested time that lies inside [t0, t1]. `y0` must
            # still be the state at the START of the cell here; advancing it
            # first would make both interpolation end points identical.
            while j < t.shape[0] and t1 >= t[j]:
                y = self._linear_interp(t0, t1, y0, y1, t[j])
                solution.append(y)
                j += 1

            # Only now move on to the next cell.
            y0 = y1

        return tuple(map(tf.stack, tuple(zip(*solution))))
Beispiel #13
0
    def grad(*grad_output, variables=None):
        """Compute gradients of the ODE solution with the adjoint method.

        Walks the output times backwards; for each interval it solves an
        augmented ODE (state, adjoint of the state, adjoint of time, adjoint
        of the parameters) with ``odeint`` and accumulates the results.
        ``ans`` and ``t`` are read from the enclosing scope.

        Args:
            *grad_output: upstream gradients, one per tensor in ``ans``,
                each indexed along its first axis by output time.
            variables: iterable of variables the dynamics depend on.

        Returns:
            ``((*adj_y, time_vjps), adj_params_list)``: the gradients wrt
            the initial state followed by the gradient wrt the times, and a
            list with one gradient per entry of ``variables``, reshaped to
            that variable's shape.
        """
        global _arguments
        flat_params = _flatten(variables)

        func = _arguments.func
        adjoint_method = _arguments.adjoint_method
        adjoint_rtol = _arguments.rtol
        adjoint_atol = _arguments.atol
        adjoint_options = _arguments.adjoint_options

        n_tensors = len(ans)
        f_params = tuple(variables)

        # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.

            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = tf.convert_to_tensor(func_eval)

            # `func_eval` is always a single tensor after the conversion
            # above, so only the tensor case needs handling here (the old
            # list/tuple branch could never run).
            gradys = -tf.stack(adj_y)
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval,
                (t,) + y + f_params,
                output_gradients=gradys,
                unconnected_gradients=tf.UnconnectedGradients.ZERO
            )

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = _flatten(vjp_params)

            if _check_len(f_params) == 0:
                # No parameters: use a scalar zero placeholder. The attribute
                # is `dtype`; the previous `dype` spelling raised
                # AttributeError on this path.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        if isinstance(grad_output, (tf.Tensor, tf.Variable)):
            adj_y = [grad_output[-1]]
        else:
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)

        adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
        adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
        time_vjps = []
        for i in range(T - 1, 0, -1):

            ans_i = tuple(ans_[i] for ans_ in ans)

            if isinstance(grad_output, (tf.Tensor, tf.Variable)):
                grad_output_i = [grad_output[i]]
            else:
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

            func_i = func(t[i], ans_i)

            if not isinstance(func_i, Iterable):
                func_i = [func_i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = sum(
                tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]), tf.reshape(grad_output_i_, [-1, 1])), [1])
                for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
            )

            adj_time = adj_time - dLd_cur_t
            time_vjps.append(dLd_cur_t)

            # Run the augmented system backwards in time.
            if isinstance(adj_params, Iterable):
                if _numel(adj_params) == 0:
                    adj_params = move_to_device(tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)

            aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)

            aug_ans = odeint(
                augmented_dynamics,
                aug_y0,
                tf.convert_to_tensor([t[i], t[i - 1]]),
                rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
            )

            # Unpack aug_ans.
            adj_y = aug_ans[n_tensors:2 * n_tensors]
            adj_time = aug_ans[2 * n_tensors]
            adj_params = aug_ans[2 * n_tensors + 1]

            # Each solution holds the values at t[i] and t[i - 1]; keep the
            # latter (index 1).
            adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
            if _check_len(adj_time) > 0: adj_time = adj_time[1]
            if _check_len(adj_params) > 0: adj_params = adj_params[1]

            adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

            del aug_y0, aug_ans

        time_vjps.append(adj_time)
        time_vjps = tf.concat(time_vjps[::-1], 0)

        # reshape the parameters back into the correct variable shapes
        var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
        var_shapes = [v.shape for v in variables]

        adj_params_splits = tf.split(adj_params, var_flat_lens)
        adj_params_list = [tf.reshape(p, v_shape)
                           for p, v_shape in zip(adj_params_splits, var_shapes)]
        return (*adj_y, time_vjps), adj_params_list
Beispiel #14
0
    def _adaptive_adams_step(self, vcabm_state, final_t):
        """Attempt one adaptive predictor-corrector step.

        Args:
            vcabm_state: state that unpacks into
                ``(y0, prev_f, prev_t, next_t, prev_phi, order)``.
            final_t: time the step must not go beyond.

        Returns:
            A new ``_VCABMState``. If the error test fails, the state keeps
            ``y0``, the histories and the order, and only proposes a new
            ``next_t``. If it passes, the histories are extended with the
            new time and function evaluation and the order may change.
        """
        y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
        # Never step past the requested final time.
        if next_t > final_t:
            next_t = final_t
        dt = (next_t - prev_t[0])
        # Copy of the step size on the device / dtype of the state tensors.
        dt_cast = move_to_device(dt, y0[0].device)
        dt_cast = tf.cast(dt_cast, y0[0].dtype)

        # Explicit predictor step.
        g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
        # g = move_to_device(g, y0[0].device)

        g = tf.cast(g, dt_cast.dtype)
        # phi = [tf.cast(phi_, dt_cast.dtype) for phi_ in phi]

        # zip(*phi) regroups phi per state component, so each y0_ is paired
        # with its own sequence of phi terms.
        p_next = tuple(
            y0_ +
            _scaled_dot_product(dt_cast, g[:max(1, order -
                                                1)], phi_[:max(1, order - 1)])
            for y0_, phi_ in zip(y0, tuple(zip(*phi))))

        # Update phi to implicit.
        next_t = move_to_device(next_t, p_next[0].device)
        next_f0 = self.func(tf.cast(next_t, self.y0[0].dtype), p_next)
        implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)

        # Implicit corrector step.
        y_next = tuple(
            p_next_ + dt_cast * g[order - 1] * tf.cast(iphi_, dt_cast.dtype)
            for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]))

        # Error estimation.
        tolerance = tuple(
            atol_ +
            rtol_ * tf.reduce_max([tf.abs(y0_), tf.abs(y1_)])
            for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0,
                                              y_next))
        local_error = tuple(dt_cast * (g[order] - g[order - 1]) *
                            tf.cast(iphi_, dt_cast.dtype)
                            for iphi_ in implicit_phi_p[order])
        error_k = _compute_error_ratio(local_error, tolerance)
        # The step is accepted only if every error ratio is at most 1.
        accept_step = tf.reduce_all((tf.convert_to_tensor(error_k) <= 1))

        if not accept_step:
            # Retry with adjusted step size if step is rejected.
            dt_next = _optimal_step_size(dt,
                                         error_k,
                                         self.safety,
                                         self.ifactor,
                                         self.dfactor,
                                         order=order)
            return _VCABMState(y0,
                               prev_f,
                               prev_t,
                               prev_t[0] + dt_next,
                               prev_phi,
                               order=order)

        # We accept the step. Evaluate f and update phi.
        next_t = move_to_device(next_t, p_next[0].device)
        next_f0 = self.func(tf.cast(next_t, self.y0[0].dtype), y_next)
        implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)
        next_order = order

        # With a short history or a low order, only raise the order (capped
        # at 3 and at self.max_order); otherwise compare error estimates of
        # neighbouring orders.
        if len(prev_t) <= 4 or order < 3:
            next_order = min(order + 1, 3, self.max_order)
        else:
            error_km1 = _compute_error_ratio(
                tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_
                      for iphi_ in implicit_phi_p[order - 1]), tolerance)
            error_km2 = _compute_error_ratio(
                tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_
                      for iphi_ in implicit_phi_p[order - 2]), tolerance)
            # NOTE(review): presumably `_compute_error_ratio` returns a tuple,
            # so `+` concatenates the two results before `min` -- confirm.
            if min(error_km1 + error_km2) < max(error_k):
                next_order = order - 1
            elif order < self.max_order:
                # NOTE(review): `gamma_star` is not defined in this block;
                # presumably a module-level coefficient table -- confirm.
                error_kp1 = _compute_error_ratio(
                    tuple(dt_cast * gamma_star[order] * iphi_
                          for iphi_ in implicit_phi_p[order]), tolerance)
                if max(error_kp1) < max(error_k):
                    next_order = order + 1

        # Keep step size constant if increasing order. Else use adaptive step size.
        dt_next = dt if next_order > order else _optimal_step_size(
            dt,
            error_k,
            self.safety,
            self.ifactor,
            self.dfactor,
            order=order + 1)

        # The history deques are extended in place, newest entry first.
        prev_f.appendleft(next_f0)
        prev_t.appendleft(next_t)
        # NOTE(review): the new state stores the predictor `p_next`, while
        # `next_f0` above was evaluated at the corrected `y_next` -- confirm
        # which of the two is intended.
        return _VCABMState(p_next,
                           prev_f,
                           prev_t,
                           next_t + dt_next,
                           implicit_phi,
                           order=next_order)
Beispiel #15
0
    def grad(*grad_output, variables=None):
        """Compute parameter gradients of the ODE solution via an adjoint solve.

        Walks the output times backwards; for each interval it solves an
        augmented ODE (state, adjoint of the state, adjoint of time, adjoint
        of the parameters) with ``odeint`` and accumulates the results.
        ``ans``, ``t`` and ``flat_params`` are read from the enclosing
        scope. Progress information is printed to stdout.

        Args:
            *grad_output: upstream gradients, one per tensor in ``ans``,
                each indexed along its first axis by output time.
            variables: iterable of variables the dynamics depend on.

        Returns:
            ``(adj_y, model_vars)``: the tuple of gradients wrt the initial
            state and a list with one gradient per entry of ``variables``,
            reshaped to that variable's shape.
        """
        global _arguments
        func = _arguments.func
        method = _arguments.method
        options = _arguments.options
        rtol = _arguments.rtol
        atol = _arguments.atol

        print("Gradient Output : ", grad_output)
        print("Variables : ", variables)

        n_tensors = len(ans)
        f_params = tuple(variables)

        # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.

            y, adj_y = y_aug[:n_tensors], y_aug[
                n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = cast_double(func_eval)

            # NOTE(review): `adj_y` is unpacked above but never passed as
            # `output_gradients`, so these are plain gradients of
            # `func_eval` and not products with the negated adjoint --
            # confirm whether that is intended.
            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval,
                (t, ) + y + f_params,
            )

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # tape.gradient yields None for unconnected sources; use zeros.
            vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
            vjp_y = tuple(
                tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
                for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            if _check_len(f_params) == 0:
                # No parameters: use a scalar zero placeholder. The attribute
                # is `dtype`; the previous `dype` spelling raised
                # AttributeError on this path.
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        if isinstance(grad_output, (tf.Tensor, tf.Variable)):
            adj_y = [grad_output[-1]]
        else:
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
        adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
        adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype),
                                  t.device)
        time_vjps = []
        for i in range(T - 1, 0, -1):

            ans_i = tuple(ans_[i] for ans_ in ans)

            if isinstance(grad_output, (tf.Tensor, tf.Variable)):
                grad_output_i = [grad_output[i]]
            else:
                grad_output_i = tuple(grad_output_[i]
                                      for grad_output_ in grad_output)

            func_i = func(t[i], ans_i)
            func_i = cast_double(func_i)

            if not isinstance(func_i, Iterable):
                func_i = [func_i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = sum(
                tf.reshape(
                    tf.matmul(tf.reshape(func_i_, [1, -1]),
                              tf.reshape(grad_output_i_, [-1, 1])), [1])
                for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
            adj_time = cast_double(adj_time)
            adj_time = adj_time - dLd_cur_t
            time_vjps.append(dLd_cur_t)

            # Run the augmented system backwards in time.
            if isinstance(adj_params, Iterable):
                count = _numel(adj_params)

                if count == 0:
                    adj_params = move_to_device(
                        tf.convert_to_tensor(0., dtype=adj_y[0].dtype),
                        adj_y[0].device)

            aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)

            aug_ans = odeint(augmented_dynamics,
                             aug_y0,
                             tf.convert_to_tensor([t[i], t[i - 1]]),
                             rtol=rtol,
                             atol=atol,
                             method=method,
                             options=options)

            # Unpack aug_ans.
            adj_y = aug_ans[n_tensors:2 * n_tensors]
            adj_time = aug_ans[2 * n_tensors]
            adj_params = aug_ans[2 * n_tensors + 1]

            # Each solution holds the values at t[i] and t[i - 1]; keep the
            # latter (index 1).
            adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_
                          for adj_y_ in adj_y)
            if _check_len(adj_time) > 0: adj_time = adj_time[1]
            if _check_len(adj_params) > 0: adj_params = adj_params[1]

            adj_y = tuple(adj_y_ + grad_output_[i - 1]
                          for adj_y_, grad_output_ in zip(adj_y, grad_output))

            del aug_y0, aug_ans

        time_vjps.append(adj_time)
        time_vjps = tf.concat(time_vjps[::-1], 0)

        print()
        print('adj y', len(adj_y))
        print('time vjps', time_vjps.shape)
        print('adj params', adj_params.shape)
        print()

        # reshape the parameters back into the correct variable shapes
        var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
        var_shapes = [v.shape for v in variables]

        adj_params_splits = tf.split(adj_params, var_flat_lens)
        adj_params_list = [
            tf.reshape(p, v_shape)
            for p, v_shape in zip(adj_params_splits, var_shapes)
        ]

        # Only the parameter gradients go into the second return slot;
        # `time_vjps` is computed and printed above but is not returned.
        model_vars = list(adj_params_list)

        grad_vars = model_vars
        print('model grad', len(model_vars))
        print('model grad values', [v for v in grad_vars])
        print()

        return (adj_y, model_vars)