from collections.abc import Iterable

import tensorflow as tf

# Helpers referenced below (_flatten, _check_len, _numel, move_to_device,
# eager_mode, odeint, _Arguments, OdeintAdjointMethod) are defined elsewhere
# in this module.


def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None):
    # `func` must be a tf.keras.Model so that its variables can be collected;
    # we have no other way of getting the variables along the execution path.
    if not isinstance(func, tf.keras.Model):
        raise ValueError('func is required to be an instance of tf.keras.Model')

    with eager_mode():
        tensor_input = False
        if tf.debugging.is_numeric_tensor(y0):

            class TupleFunc(tf.keras.Model):
                # Wraps a single-tensor state so the rest of the code can
                # assume tuple-valued states.

                def __init__(self, base_func, **kwargs):
                    super(TupleFunc, self).__init__(**kwargs)
                    self.base_func = base_func

                def call(self, t, y):
                    return (self.base_func(t, y[0]),)

            tensor_input = True
            y0 = (y0,)
            func = TupleFunc(func)

        # Call the function once so that its variables are built.
        if not func.built:
            _ = func(t, y0)

        flat_params = _flatten(func.variables)

        global _arguments
        _arguments = _Arguments(func, method, options, rtol, atol)

        ys = OdeintAdjointMethod(*y0, t, flat_params)

        if tensor_input or isinstance(ys, (tuple, list)):
            ys = ys[0]

        return ys
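# A minimal usage sketch. `_example_usage`, the `Linear` dynamics model, and
# the quadratic loss are illustrative assumptions, not part of this module.
def _example_usage():

    class Linear(tf.keras.Model):
        # Toy dynamics dy/dt = A y with a trainable matrix A.

        def __init__(self, **kwargs):
            super(Linear, self).__init__(**kwargs)
            self.A = tf.Variable(-0.1 * tf.eye(2))

        def call(self, t, y):
            return tf.linalg.matvec(self.A, y)

    func = Linear()
    y0 = tf.constant([1.0, 0.0])
    t = tf.linspace(0.0, 1.0, 10)

    with tf.GradientTape() as tape:
        ys = odeint_adjoint(func, y0, t)
        loss = tf.reduce_sum(tf.square(ys[-1]))

    # Gradients are produced by the adjoint backward pass (`grad` below),
    # not by backpropagating through the solver's internal operations.
    return tape.gradient(loss, func.trainable_variables)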
def grad(*grad_output, variables=None):
    # Backward pass of OdeintAdjointMethod; `ans` (the forward solution) and
    # `t` are closed over from the enclosing custom-gradient wrapper.
    global _arguments
    flat_params = _flatten(variables)

    func = _arguments.func
    adjoint_method = _arguments.adjoint_method
    adjoint_rtol = _arguments.rtol
    adjoint_atol = _arguments.atol
    adjoint_options = _arguments.adjoint_options

    n_tensors = len(ans)
    f_params = tuple(variables)

    # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher
    # order derivatives.
    def augmented_dynamics(t, y_aug):
        # Dynamics of the original system augmented with
        # the adjoint wrt y, and an integrator wrt t and args.
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

        with tf.GradientTape() as tape:
            tape.watch(t)
            tape.watch(y)
            func_eval = func(t, y)

        # The VJP cotangent is the negated adjoint state, broadcast up to the
        # rank of the corresponding output where necessary.
        gradys = -tf.stack(adj_y)
        if isinstance(func_eval, (list, tuple)):
            # Unstack so the per-output cotangents can be reshaped individually.
            gradys = list(tf.unstack(gradys))
            for eval_ix in range(len(func_eval)):
                if len(gradys[eval_ix].shape) < len(func_eval[eval_ix].shape):
                    gradys[eval_ix] = tf.expand_dims(gradys[eval_ix], axis=0)
        else:
            func_eval = tf.convert_to_tensor(func_eval)
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)

        vjp_t, *vjp_y_and_params = tape.gradient(
            func_eval, (t,) + y + f_params,
            output_gradients=gradys,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)

        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = _flatten(vjp_params)

        if _check_len(f_params) == 0:
            vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
            vjp_params = move_to_device(vjp_params, vjp_y[0].device)

        return (*func_eval, *vjp_y, vjp_t, vjp_params)

    T = ans[0].shape[0]
    if isinstance(grad_output, (tf.Tensor, tf.Variable)):
        adj_y = [grad_output[-1]]
    else:
        adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
    adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
    adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
    time_vjps = []

    for i in range(T - 1, 0, -1):
        ans_i = tuple(ans_[i] for ans_ in ans)

        if isinstance(grad_output, (tf.Tensor, tf.Variable)):
            grad_output_i = [grad_output[i]]
        else:
            grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

        func_i = func(t[i], ans_i)
        if not isinstance(func_i, Iterable):
            func_i = [func_i]

        # Compute the effect of moving the current time measurement point:
        # dL/dt_i = <f(t_i, y_i), dL/dy_i>.
        dLd_cur_t = sum(
            tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]),
                                 tf.reshape(grad_output_i_, [-1, 1])), [1])
            for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
        adj_time = adj_time - dLd_cur_t
        time_vjps.append(dLd_cur_t)

        # Run the augmented system backwards in time.
        if isinstance(adj_params, Iterable):
            if _numel(adj_params) == 0:
                adj_params = move_to_device(
                    tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)

        aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
        aug_ans = odeint(
            augmented_dynamics, aug_y0,
            tf.convert_to_tensor([t[i], t[i - 1]]),
            rtol=adjoint_rtol, atol=adjoint_atol,
            method=adjoint_method, options=adjoint_options)

        # Unpack aug_ans.
        adj_y = aug_ans[n_tensors:2 * n_tensors]
        adj_time = aug_ans[2 * n_tensors]
        adj_params = aug_ans[2 * n_tensors + 1]

        # Keep only the value at t[i - 1] (index 1 of the two-point solve).
        adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_
                      for adj_y_ in adj_y)
        if _check_len(adj_time) > 0:
            adj_time = adj_time[1]
        if _check_len(adj_params) > 0:
            adj_params = adj_params[1]

        # Add the direct gradient contribution of the observation at t[i - 1].
        adj_y = tuple(adj_y_ + grad_output_[i - 1]
                      for adj_y_, grad_output_ in zip(adj_y, grad_output))

        del aug_y0, aug_ans

    time_vjps.append(adj_time)
    time_vjps = tf.concat(time_vjps[::-1], 0)

    # Reshape the flat parameter gradient back into the variables' shapes.
    var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
    var_shapes = [v.shape for v in variables]

    adj_params_splits = tf.split(adj_params, var_flat_lens)
    adj_params_list = [tf.reshape(p, v_shape)
                       for p, v_shape in zip(adj_params_splits, var_shapes)]

    return (*adj_y, time_vjps), adj_params_list
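# For context, `grad` above is the backward half of a custom-gradient
# function: `odeint_adjoint` calls `OdeintAdjointMethod(*y0, t, flat_params)`,
# whose forward pass solves the ODE and whose backward pass is `grad`
# (closing over `ans` and `t`). The commented sketch below shows the assumed
# shape of that wrapper; the exact `_arguments` attribute names and signature
# are assumptions, not this module's verbatim code.
#
# @tf.custom_gradient
# def OdeintAdjointMethod(*args):
#     global _arguments
#     func, method, options = _arguments.func, _arguments.method, _arguments.options
#     rtol, atol = _arguments.rtol, _arguments.atol
#
#     # args was assembled in odeint_adjoint as (*y0, t, flat_params).
#     y0, t = args[:-2], args[-2]
#
#     ans = odeint(func, y0, t, rtol=rtol, atol=atol,
#                  method=method, options=options)
#
#     def grad(*grad_output, variables=None):
#         ...  # the backward pass shown above
#
#     return ans, grad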