Example #3

The grad_fn below implements the continuous adjoint method: gradients of odeint are obtained by integrating an augmented state [y, adj_y, adj_time, adj_params] backwards in time, one interval [t[i - 1], t[i]] at a time.
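For orientation, the quantities integrated backward satisfy the standard adjoint sensitivity ODEs, summarized here as comments rather than code from this file (a = dL/dy is the adjoint state):

# da/dt        = -a . df/dy       (adjoint state, integrated backward)
# d(adj_p)/dt  = -a . df/dparams  (accumulates dL/dparams)
# d(adj_t)/dt  = -a . df/dt       (accumulates dL/dt)

augmented_dynamics evaluates these right-hand sides with a single vector-Jacobian product (tape.gradient with output_gradients set to -adj_y).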
    def grad_fn(*grad_output, **kwargs):
        variables = kwargs.get('variables', None)
        global _arguments
        f_params = tuple(variables)
        flat_params = flatten(variables)
        func = _arguments.func
        method = _arguments.method
        options = _arguments.options
        rtol = _arguments.rtol
        atol = _arguments.atol
        # Infer the state dimension from the solution tensor `ans`.
        n_states = ans.shape[0] if len(ans.shape) == 1 else ans.shape[1]

        def augmented_dynamics(y_aug, t):
            # The interval is integrated in negated time, so flip the sign
            # of t to recover the original time before calling func.
            tpos = -t
            y, adj_y = y_aug[:n_states], y_aug[n_states:2 * n_states]

            with tf.GradientTape() as tape:
                tape.watch(tpos)
                tape.watch(y)
                func_eval = func(y, tpos)

            gradys = -adj_y
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)

            # Vector-Jacobian products of -adj_y with df/dt, df/dy and
            # df/dparams; the gradient must be taken w.r.t. the watched
            # tensor tpos, not the unwatched argument t.
            vjp_t, vjp_y, vjp_params = tape.gradient(
                func_eval, (tpos, y) + f_params,
                output_gradients=gradys,
                unconnected_gradients=tf.UnconnectedGradients.ZERO)
            return -flatten((func_eval, vjp_y, vjp_t, vjp_params))

        # Backward integration using the augmented state
        # [y, adj_y, adj_time, adj_params].
        T = ans.shape[0]
        adj_y = grad_output[-1][-1]
        adj_params = tf.zeros_like(flat_params)
        adj_time = tf.convert_to_tensor(0., dtype=t.dtype)
        time_vjps = []
        for i in range(T - 1, 0, -1):
            func_i = func(ans[i], t[i])
            grad_output_i = grad_output[-1][i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = tf.tensordot(tf.squeeze(func_i),
                                     tf.squeeze(grad_output_i), 1)
            adj_time = adj_time - dLd_cur_t
            time_vjps.append(dLd_cur_t)

            # Pack the augmented state at t[i] and integrate back to t[i - 1].
            aug_y0 = flatten((ans[i], adj_y, adj_time, adj_params))

            aug_ans = odeint(augmented_dynamics,
                             aug_y0,
                             tf.convert_to_tensor([-t[i], -t[i - 1]]),
                             rtol=rtol,
                             atol=atol,
                             method=method,
                             options=options)

            # Unpack aug_ans.
            adj_y = aug_ans[1, n_states:2 * n_states]
            adj_time = aug_ans[1, 2 * n_states]
            adj_params = aug_ans[1, 2 * n_states + 1:]

            # Inject the loss gradient supplied at the earlier time point.
            adj_y += grad_output[-1][i - 1]

        time_vjps.append(adj_time)
        # The loop ran from t[-1] backwards, so reverse the per-point
        # gradients to match the ordering of t.
        time_vjps = flatten(time_vjps[::-1])

        # Unflatten the parameter gradients back to each variable's shape.
        adj_params_list = []
        beg = 0
        for v in variables:
            shape = v.shape
            size = tf.size(v)
            end = beg + size
            adj_params_list.append(tf.reshape(adj_params[beg:end], shape))
            beg = end

        return (adj_y, time_vjps), adj_params_list
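
As a sanity check of the backward sweep, here is a minimal, self-contained sketch with our own illustrative names and a fixed-step Euler integrator in place of the library's odeint. It applies the same adjoint recursion to dy/dt = theta * y, whose gradients are known in closed form:

import tensorflow as tf

# Toy problem: dy/dt = theta * y, y(0) = y0, loss L = y(T).
# Closed form: dL/dy0 = exp(theta * T), dL/dtheta = y0 * T * exp(theta * T).
theta = tf.constant(0.5)
y0 = tf.constant(1.5)
T, n_steps = 1.0, 1000
h = T / n_steps

# Forward pass with explicit Euler, storing the trajectory.
ys = [y0]
for _ in range(n_steps):
    ys.append(ys[-1] + h * theta * ys[-1])

# Backward (adjoint) sweep; a = dL/dy starts at 1 because L = y(T).
a = tf.constant(1.0)
adj_theta = tf.constant(0.0)
for k in range(n_steps, 0, -1):
    # Stepping backward: a(t - h) = a(t) + h * a * theta,
    # from da/dt = -a * df/dy with df/dy = theta.
    a = a + h * a * theta
    # d(adj_theta)/dt = -a * df/dtheta with df/dtheta = y.
    adj_theta = adj_theta + h * a * ys[k - 1]

print(float(a), float(tf.exp(theta * T)))                   # dL/dy0
print(float(adj_theta), float(y0 * T * tf.exp(theta * T)))  # dL/dtheta

Both printed pairs agree up to the Euler discretization error; grad_fn above computes the same thing with an adaptive solver and arbitrary parameter shapes.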
This variant computes the same gradients with forward sensitivity analysis: the state is integrated together with its Jacobians with respect to the parameters and the initial values, and the chain rule is applied at the end.

    def grad_fn(*grad_output, **kwargs):
        variables = kwargs.get('variables', None)
        f_params = tuple(tf.squeeze(v, -1) for v in variables)

        # Augmented forward integration
        # Jacobian of the dynamics w.r.t. the state, df/dx.
        def dfdx(x, t):
            with tf.GradientTape() as tape:
                tape.watch(x)
                func_out = func(x, t)
            return tape.jacobian(func_out, x, unconnected_gradients='zero')

        # Jacobian of the dynamics w.r.t. the parameters (y0 is appended so
        # the result has one column block per sensitivity; its direct
        # contribution df/dy0 is zero).
        def dfdp(x, t):
            with tf.GradientTape() as tape:
                tape.watch(f_params)
                func_out = func(x, t)
            jac_list = tape.jacobian(func_out,
                                     f_params + (y0, ),
                                     unconnected_gradients='zero')
            return tf.concat(jac_list, axis=1)

        def aug_func(x_aug, t):
            # Augmented state: x followed by the flattened sensitivity
            # matrix dx/dp of shape [n_states, n_theta + n_ivs].
            x = x_aug[:n_states]
            dxdp = tf.reshape(x_aug[n_states:], [n_states, n_theta + n_ivs])

            # Forward sensitivity ODE: d(dx/dp)/dt = (df/dx)(dx/dp) + df/dp.
            dxdt = func(x, t)
            d_dxdp_dt = tf.matmul(dfdx(x, t), dxdp) + dfdp(x, t)
            return flatten([dxdt, d_dxdp_dt])

        # Initial sensitivities: dx0/dtheta = 0 and dx0/dy0 = identity,
        # written into the flattened [n_states, n_theta + n_ivs] layout.
        aug_rest = np.zeros(n_states * (n_theta + n_ivs))
        for i in range(n_ivs):
            offset = n_theta * (i + 1) + n_ivs * i + i
            aug_rest[offset] = 1.0
        y0_aug = tf.cast(flatten([y0, aug_rest]), tf.float64)

        result = odeint(aug_func,
                        y0_aug,
                        t,
                        rtol=rtol,
                        atol=atol,
                        method=method,
                        options=options)
        y = result[:, :n_states]
        # Split the integrated sensitivities into dy/dtheta and dy/dy0.
        dydp = tf.reshape(result[:, n_states:], [T, n_states, n_theta + n_ivs])
        dydtheta = dydp[:, :, :n_theta]
        dydy0 = dydp[:, :, n_theta:]

        # Contract the loss gradient with the sensitivities over all time
        # points: dL/dp = sum_t (dy_t/dp)^T dL/dy_t.
        def vec_jac_prod(dydp, dLdy):
            dydp_T = tf.transpose(dydp, [0, 2, 1])
            if len(dLdy.shape) < len(dydp_T.shape):
                dLdy = tf.expand_dims(dLdy, -1)
            return tf.squeeze(
                tf.math.reduce_sum(tf.matmul(dydp_T, dLdy), axis=0), -1)

        grad_output = grad_output[-1]
        dLdtheta = vec_jac_prod(dydtheta, grad_output)
        dLdy0 = vec_jac_prod(dydy0, grad_output)

        # Unflatten the parameter gradients back to each variable's shape.
        dLdtheta_list = []
        beg = 0
        for v in variables:
            shape = v.shape
            size = tf.size(v)
            end = beg + size
            dLdtheta_list.append(tf.reshape(dLdtheta[beg:end], shape))
            beg = end

        return dLdy0, dLdtheta_list
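
The same toy problem makes the forward sensitivity system concrete. The sketch below (again with illustrative names, and plain Euler in place of odeint) integrates the state together with dy/dtheta and dy/dy0, which is exactly the layout aug_func uses with n_states = n_theta = n_ivs = 1:

import math

# Augmented state [y, dy/dtheta, dy/dy0] for dy/dt = theta * y.
theta, y0 = 0.5, 1.5
T, n_steps = 1.0, 1000
h = T / n_steps

y, s_theta, s_y0 = y0, 0.0, 1.0  # dy0/dtheta = 0, dy0/dy0 = 1 (identity)
for _ in range(n_steps):
    # Sensitivity ODE: d(dy/dp)/dt = (df/dy) * (dy/dp) + df/dp,
    # with df/dy = theta, df/dtheta = y and df/dy0 = 0.
    y, s_theta, s_y0 = (y + h * theta * y,
                        s_theta + h * (theta * s_theta + y),
                        s_y0 + h * theta * s_y0)

print(s_theta, y0 * T * math.exp(theta * T))  # dy(T)/dtheta
print(s_y0, math.exp(theta * T))              # dy(T)/dy0

Contracting these sensitivities with dL/dy, as vec_jac_prod does above, yields dL/dtheta and dL/dy0 without any backward integration; the trade-off is that the augmented system grows with the number of parameters.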