Code example #1
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.

            y, adj_y = y_aug[:n_tensors], y_aug[
                n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = tf.convert_to_tensor(func_eval)

            gradys = -tf.stack(adj_y)
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval, (t, ) + y + f_params,
                output_gradients=gradys,
                unconnected_gradients=tf.UnconnectedGradients.ZERO)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = _flatten(vjp_params)

            if _check_len(f_params) == 0:
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)
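
The snippet above leans on a few small utilities defined elsewhere in the module (_flatten, _check_len, move_to_device). Their real implementations are not shown here; the stand-ins below are only a minimal sketch inferred from how they are used, not the library's actual code.

import tensorflow as tf

def _flatten(tensors):
    # Flatten each tensor and concatenate the results into one 1-D tensor.
    if len(tensors) == 0:
        return tf.convert_to_tensor([], dtype=tf.float32)
    return tf.concat([tf.reshape(v, [-1]) for v in tensors], axis=0)

def _check_len(x):
    # Best-effort length: leading dimension for tensors, len() for sequences,
    # 0 for scalars.
    if hasattr(x, 'shape'):
        return int(x.shape[0]) if len(x.shape) > 0 else 0
    return len(x)

def move_to_device(x, device):
    # Place (or copy) a tensor onto the given device.
    with tf.device(device):
        return tf.identity(x)
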
Code example #2
def odeint_adjoint(func,
                   y0,
                   t,
                   rtol=1e-6,
                   atol=1e-12,
                   method=None,
                   options=None):
    # We need this in order to access the variables inside this module,
    # since we have no other way of getting variables along the execution path.
    if not isinstance(func, tf.keras.Model):
        raise ValueError(
            'func is required to be an instance of tf.keras.Model')

    with eager_mode():
        tensor_input = False
        if tf.debugging.is_numeric_tensor(y0):

            class TupleFunc(tf.keras.Model):
                def __init__(self, base_func, **kwargs):
                    super(TupleFunc, self).__init__(**kwargs)
                    self.base_func = base_func

                def call(self, t, y):
                    return (self.base_func(t, y[0]), )

            tensor_input = True
            y0 = (y0, )
            func = TupleFunc(func)

        # build the function to get its variables
        if not func.built:
            _ = func(t, y0)

        flat_params = _flatten(func.variables)

        global _arguments
        _arguments = _Arguments(func, method, options, rtol, atol)

        ys = OdeintAdjointMethod(*y0, t, flat_params)

        if tensor_input or isinstance(ys, (tuple, list)):
            ys = ys[0]

        return ys
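
A minimal usage sketch for odeint_adjoint above, assuming eager execution and that the function is importable from the surrounding module; the Dynamics model, state size, and time grid are illustrative, not taken from the source.

import tensorflow as tf

class Dynamics(tf.keras.Model):
    # Simple learnable dynamics dy/dt = tanh(W y + b); t is unused here.
    def __init__(self, **kwargs):
        super(Dynamics, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(2, activation='tanh')

    def call(self, t, y):
        return self.dense(y)

func = Dynamics()
y0 = tf.constant([[0.5, -0.5]], dtype=tf.float32)  # initial state, shape [1, 2]
t = tf.linspace(0.0, 1.0, 10)                      # 10 time points in [0, 1]

with tf.GradientTape() as tape:
    ys = odeint_adjoint(func, y0, t)               # solution at every time point
    loss = tf.reduce_sum(tf.square(ys[-1]))        # loss on the final state

grads = tape.gradient(loss, func.variables)        # gradients via the adjoint pass
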
Code example #3
    def grad(*grad_output, variables=None):
        global _arguments
        flat_params = _flatten(variables)

        func = _arguments.func
        adjoint_method = _arguments.adjoint_method
        adjoint_rtol = _arguments.rtol
        adjoint_atol = _arguments.atol
        adjoint_options = _arguments.adjoint_options

        n_tensors = len(ans)
        f_params = tuple(variables)

        # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.

            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = tf.convert_to_tensor(func_eval)

            gradys = -tf.stack(adj_y)
            if type(func_eval) in [list, tuple]:
                for eval_ix in range(len(func_eval)):
                    if len(gradys[eval_ix].shape) < len(func_eval[eval_ix].shape):
                        gradys[eval_ix] = tf.expand_dims(gradys[eval_ix], axis=0)

            else:
                if len(gradys.shape) < len(func_eval.shape):
                    gradys = tf.expand_dims(gradys, axis=0)
            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval,
                (t,) + y + f_params,
                output_gradients=gradys,
                unconnected_gradients=tf.UnconnectedGradients.ZERO
            )

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = _flatten(vjp_params)

            if _check_len(f_params) == 0:
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)

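        # ans holds the forward solution at each requested time point; the loop
        # below walks those points in reverse, integrating the augmented system
        # across each interval and accumulating the adjoints.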
        T = ans[0].shape[0]
        if isinstance(grad_output, (tf.Tensor, tf.Variable)):
            adj_y = [grad_output[-1]]
        else:
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)

        adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
        adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
        time_vjps = []
        for i in range(T - 1, 0, -1):

            ans_i = tuple(ans_[i] for ans_ in ans)

            if isinstance(grad_output, (tf.Tensor, tf.Variable)):
                grad_output_i = [grad_output[i]]
            else:
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

            func_i = func(t[i], ans_i)

            if not isinstance(func_i, Iterable):
                func_i = [func_i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = sum(
                tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]), tf.reshape(grad_output_i_, [-1, 1])), [1])
                for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
            )

            adj_time = adj_time - dLd_cur_t
            time_vjps.append(dLd_cur_t)

            # Run the augmented system backwards in time.
            if isinstance(adj_params, Iterable):
                if _numel(adj_params) == 0:
                    adj_params = move_to_device(tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)

            aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)

            aug_ans = odeint(
                augmented_dynamics,
                aug_y0,
                tf.convert_to_tensor([t[i], t[i - 1]]),
                rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
            )

            # Unpack aug_ans.
            adj_y = aug_ans[n_tensors:2 * n_tensors]
            adj_time = aug_ans[2 * n_tensors]
            adj_params = aug_ans[2 * n_tensors + 1]

            adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
            if _check_len(adj_time) > 0: adj_time = adj_time[1]
            if _check_len(adj_params) > 0: adj_params = adj_params[1]

            adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

            del aug_y0, aug_ans

        time_vjps.append(adj_time)
        time_vjps = tf.concat(time_vjps[::-1], 0)

        # reshape the parameters back into the correct variable shapes
        var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
        var_shapes = [v.shape for v in variables]

        adj_params_splits = tf.split(adj_params, var_flat_lens)
        adj_params_list = [tf.reshape(p, v_shape)
                           for p, v_shape in zip(adj_params_splits, var_shapes)]
        return (*adj_y, time_vjps), adj_params_list
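
For reference, the backward loop above implements the continuous adjoint method for Neural ODEs. In the usual notation (not taken from this code), with a(t) = \partial L / \partial y(t) the adjoint state, the quantities integrated backwards in time are

    \frac{da(t)}{dt} = -\,a(t)^{\top} \frac{\partial f(t, y(t), \theta)}{\partial y},
    \qquad
    \frac{dL}{d\theta} = -\int_{t_N}^{t_0} a(t)^{\top} \frac{\partial f(t, y(t), \theta)}{\partial \theta} \, dt,

where t_N is the final time. vjp_y and vjp_params in augmented_dynamics supply the two vector-Jacobian products, and each dLd_cur_t term collected in time_vjps is a(t_i)^{\top} f(t_i, y(t_i), \theta), the sensitivity of the loss to the i-th measurement time.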