def integrate(self, t):
    """Solve the ODE at every time in `t`, starting from `self.y0`.

    `t` must be increasing. The times are cast to float64 and moved to the
    device of the first state tensor, `self.before_integrate` is called once,
    and every later time is then reached with `self.advance`.

    Returns:
        A tuple with one tensor per state component, each stacking the
        solution (passed through `cast_double`) at every entry of `t`.
    """
    _assert_increasing(t)
    trajectory = [cast_double(self.y0)]
    times = move_to_device(tf.cast(t, tf.float64), self.y0[0].device)
    self.before_integrate(times)

    num_times = times.shape[0]
    for idx in range(1, num_times):
        trajectory.append(cast_double(self.advance(times[idx])))

    # Regroup from per-time tuples into per-component sequences, then stack.
    components = tuple(zip(*trajectory))
    return tuple(tf.stack(component) for component in components)
def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU):
    """Fit an interpolating polynomial to the results of a Runge-Kutta step.

    The midpoint value is built from the stage derivatives in `k` weighted by
    `DPS_C_MID`. The first and last stage of each component serve as the
    derivatives at the two ends of the step. `tableau` is accepted for
    interface compatibility but is not read here.
    """
    dt = cast_double(dt)
    y0 = cast_double(y0)

    mid_values = []
    for start, stages in zip(y0, k):
        mid_values.append(start + _scaled_dot_product(dt, DPS_C_MID, stages))

    first_derivs = tuple(stages[0] for stages in k)
    last_derivs = tuple(stages[-1] for stages in k)
    return _interp_fit(y0, y1, tuple(mid_values), first_derivs, last_derivs, dt)
def augmented_dynamics(t, y_aug):
    """Right-hand side of the augmented (state + adjoint) system.

    `y_aug` packs the state in its first `n_tensors` entries and the adjoint
    of the state in the next `n_tensors` entries. Any further entries (time
    and parameter adjoints) are not read. `n_tensors`, `func` and `f_params`
    are closure variables from the enclosing scope.

    Returns:
        `(*f(t, y), *vjp_y, vjp_t, vjp_params)`: the vjp terms are
        vector-Jacobian products of `f` with `-adj_y`, and missing gradients
        are replaced by zeros.
    """
    # Dynamics of the original system augmented with
    # the adjoint wrt y, and an integrator wrt t and args.
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

    with tf.GradientTape() as tape:
        tape.watch(t)
        tape.watch(y)
        func_eval = func(t, y)
        func_eval = cast_double(func_eval)

    # Weight the Jacobian by -adj_y so tape.gradient yields vector-Jacobian products.
    gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
    if len(gradys.shape) < len(func_eval.shape):
        gradys = tf.expand_dims(gradys, axis=0)
    vjp_t, *vjp_y_and_params = tape.gradient(
        func_eval, (t, ) + y + f_params, output_gradients=gradys)
    vjp_y = vjp_y_and_params[:n_tensors]
    vjp_params = vjp_y_and_params[n_tensors:]

    # tape.gradient returns None for unconnected sources; substitute zeros.
    vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
    vjp_y = tuple(
        tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
        for vjp_y_, y_ in zip(vjp_y, y))
    vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

    if _check_len(f_params) == 0:
        # No parameters: emit a scalar zero with the state's dtype and device.
        vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
        vjp_params = move_to_device(vjp_params, vjp_y[0].device)

    return (*func_eval, *vjp_y, vjp_t, vjp_params)
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take an arbitrary Runge-Kutta step and estimate error.

    Args:
        func: Function to evaluate like `func(t, y)` to compute the time
            derivative of `y`.
        y0: Sequence of tensors, one per state component, giving the initial
            value of the state.
        f0: Sequence of tensors giving the initial derivative, computed from
            `func(t0, y0)`.
        t0: Scalar initial time; converted to the dtype/device of `y0[0]`.
        dt: Scalar step size; converted to the dtype/device of `y0[0]`.
        tableau: _ButcherTableau describing how to take the Runge-Kutta step.

    Returns:
        Tuple `(y1, f1, y1_error, k)` giving the estimated function value after
        the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at
        `t1`, estimated error at `t1`, and a list of Runge-Kutta coefficients
        `k` used for calculating these terms.
    """
    y0 = cast_double(y0)
    f0 = cast_double(f0)
    dtype = y0[0].dtype
    device = y0[0].device
    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # k[c] accumulates the stage derivatives of state component c.
    k = tuple(map(lambda x: [x], f0))
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = tuple(y0_ + _scaled_dot_product(dt, cast_double(beta_i), k_) for y0_, k_ in zip(y0, k))
        for k_, f_ in zip(k, func(ti, yi)):
            k_.append(cast_double(f_))

    # When c_sol equals the last beta row with a zero final weight (true for
    # Dormand-Prince), the last stage's yi already is the solution, so the
    # recomputation is skipped to save a few FLOPs.
    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)
def _interp_coeff_tsit5(t0, dt, eval_t):
    """Return the seven Tsit5 dense-output weights at `eval_t`.

    `eval_t` is first mapped to the normalised step position
    `(eval_t - t0) / dt` and passed through `cast_double`. Each weight is a
    fixed polynomial in that position.
    """
    theta = cast_double((eval_t - t0) / dt)
    theta_sq = theta ** 2
    return [
        -1.0530884977290216 * theta * (theta - 1.3299890189751412) * (theta_sq - 1.4364028541716351 * theta + 0.7139816917074209),
        0.1017 * theta_sq * (theta_sq - 2.1966568338249754 * theta + 1.2949852507374631),
        2.490627285651252793 * theta_sq * (theta_sq - 2.38535645472061657 * theta + 1.57803468208092486),
        -16.54810288924490272 * (theta - 1.21712927295533244) * (theta - 0.61620406037800089) * theta_sq,
        47.37952196281928122 * (theta - 1.203071208372362603) * (theta - 0.658047292653547382) * theta_sq,
        -34.87065786149660974 * (theta - 1.2) * (theta - 0.666666666666666667) * theta_sq,
        2.5 * (theta - 1) * (theta - 0.6) * theta_sq,
    ]
def rk4_step_func(func, t, dt, y, k1=None):
    """Return the state increment of one classical fourth-order Runge-Kutta step.

    Args:
        func: Callable `func(t, y)` returning the time derivative of each
            state component.
        t: Start time of the step.
        dt: Step size.
        y: Sequence of state tensors at `t`.
        k1: Optional precomputed `func(t, y)` to avoid re-evaluating it.

    Returns:
        Tuple of increments `dt * (k1 + 2*k2 + 2*k3 + k4) / 6`, one per
        state component.
    """
    if k1 is None:
        k1 = func(t, y)
    # Every stage goes through cast_double (not only k1). TensorFlow does not
    # auto-promote dtypes, so an uncast stage could mismatch `y` in the sums below.
    k1 = cast_double(k1)
    k2 = cast_double(func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1))))
    k3 = cast_double(func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2))))
    k4 = cast_double(func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3))))
    return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4))
def before_integrate(self, t):
    """Initialise `self.rk_state` before integrating over the times `t`.

    If `self.first_step` is None the initial step size is chosen by
    `_select_initial_step`; otherwise the caller-supplied `self.first_step`
    is used (converted to the dtype/device of `t`).
    """
    if self.first_step is None:
        # NOTE(review): the literal 4 is presumably the solver order — confirm
        # against _select_initial_step's signature.
        first_step = _convert_to_tensor(
            _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol),
            device=t.device)
    else:
        # Honour the user-supplied initial step instead of a hard-coded constant.
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    # NOTE(review): field order presumably (y, f, t0, t1, dt, interp_coeff);
    # the 7-element placeholder lists stand in for interpolation coefficients
    # until the first step is taken — confirm against _RungeKuttaState.
    self.rk_state = _RungeKuttaState(
        self.y0, cast_double(self.func(t[0], self.y0)), t[0], t[0], first_step,
        tuple(map(lambda x: [x] * 7, self.y0))
    )
def rk4_alt_step_func(func, t, dt, y, k1=None):
    """Return the state increment of one step of Kutta's 3/8-rule RK4 scheme.

    This gives a smaller error than the classical step for slightly more
    compute. `k1` may be supplied to reuse an already computed `func(t, y)`.
    Every stage derivative is passed through `cast_double`.
    """
    slope1 = cast_double(func(t, y) if k1 is None else k1)

    state_a = tuple(state + dt * s1 / 3 for state, s1 in zip(y, slope1))
    slope2 = cast_double(func(t + dt / 3, state_a))

    state_b = tuple(state + dt * (s1 / -3 + s2) for state, s1, s2 in zip(y, slope1, slope2))
    slope3 = cast_double(func(t + dt * 2 / 3, state_b))

    state_c = tuple(
        state + dt * (s1 - s2 + s3)
        for state, s1, s2, s3 in zip(y, slope1, slope2, slope3)
    )
    slope4 = cast_double(func(t + dt, state_c))

    weight = dt / 8
    return tuple(
        (s1 + 3 * s2 + 3 * s3 + s4) * weight
        for s1, s2, s3, s4 in zip(slope1, slope2, slope3, slope4)
    )
def integrate(self, t):
    """Integrate with a fixed-step method and report the solution at each time in `t`.

    Steps are taken on the grid from `self.grid_constructor`, which must start
    at `t[0]` and end at `t[-1]`. Each requested time that falls inside a
    step `(t0, t1]` is produced by `self._linear_interp` between the states
    at the start and end of that step.

    Returns:
        A tuple with one tensor per state component, stacking the solution at
        every entry of `t`.
    """
    _assert_increasing(t)
    t = tf.cast(t, self.y0[0].dtype)
    time_grid = self.grid_constructor(self.func, self.y0, t)
    assert tf.equal(time_grid[0], t[0]) and tf.equal(time_grid[-1], t[-1])
    time_grid = move_to_device(time_grid, self.y0[0].device)

    solution = [cast_double(self.y0)]

    j = 1
    y0 = cast_double(self.y0)
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        dy = self.step_func(self.func, t0, t1 - t0, y0)
        y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))

        while j < t.shape[0] and t1 >= t[j]:
            solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
            j += 1

        # Advance only after interpolating, so y0 still holds the start-of-step
        # state while the outputs inside this step are computed.
        y0 = y1

    return tuple(map(tf.stack, tuple(zip(*solution))))
def augmented_dynamics(t, y_aug):
    """Right-hand side of the augmented (state + adjoint) system.

    `y_aug` packs the state in its first `n_tensors` entries and the adjoint
    of the state in the next `n_tensors` entries. Any further entries (time
    and parameter adjoints) are not read. `n_tensors`, `func` and `f_params`
    are closure variables from the enclosing scope.

    Returns:
        `(*f(t, y), *vjp_y, vjp_t, vjp_params)`: the vjp terms are
        vector-Jacobian products of `f` with `-adj_y`, and missing gradients
        are replaced by zeros.
    """
    # Dynamics of the original system augmented with
    # the adjoint wrt y, and an integrator wrt t and args.
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

    with tf.GradientTape() as tape:
        tape.watch(t)
        tape.watch(y)
        func_eval = func(t, y)
        func_eval = cast_double(func_eval)

    # Weight the Jacobian by -adj_y. Without output_gradients the tape would
    # differentiate sum(f) and adj_y would be ignored entirely.
    gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
    if len(gradys.shape) < len(func_eval.shape):
        gradys = tf.expand_dims(gradys, axis=0)
    vjp_t, *vjp_y_and_params = tape.gradient(
        func_eval, (t, ) + y + f_params, output_gradients=gradys)
    vjp_y = vjp_y_and_params[:n_tensors]
    vjp_params = vjp_y_and_params[n_tensors:]

    # tape.gradient returns None for unconnected sources; substitute zeros.
    vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
    vjp_y = tuple(
        tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
        for vjp_y_, y_ in zip(vjp_y, y))
    vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

    if _check_len(f_params) == 0:
        # No parameters: emit a scalar zero with the state's dtype and device.
        vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
        vjp_params = move_to_device(vjp_params, vjp_y[0].device)

    return (*func_eval, *vjp_y, vjp_t, vjp_params)
def _interp_eval_tsit5(t0, t1, k, eval_t):
    """Evaluate the Tsit5 dense-output interpolant at `eval_t`.

    For each state component, the first entry of its `k` list is the base
    value. The step's stage values, weighted by `_interp_coeff_tsit5`
    over the step length `t1 - t0`, are added to it.
    """
    step = cast_double(t1) - cast_double(t0)
    weights = _interp_coeff_tsit5(t0, step, eval_t)
    values = []
    for stages in k:
        values.append(stages[0] + _scaled_dot_product(step, weights, stages))
    return tuple(values)
def step_func(self, func, t, dt, y):
    """Return the explicit-Euler increment `dt * f(t, y)` for each state component."""
    derivatives = cast_double(func(t, y))
    increments = []
    for deriv in derivatives:
        increments.append(dt * deriv)
    return tuple(increments)
def step_func(self, func, t, dt, y):
    """Return the increment of one explicit trapezoidal (Heun) step.

    An Euler predictor gives the state at `t + dt`. The increment is
    `dt / 2` times the sum of the slopes at the start and at the predicted
    end point.
    """
    slope_start = cast_double(func(t, y))
    predicted = tuple(state + dt * slope for state, slope in zip(y, slope_start))
    slope_end = cast_double(func(t + dt, predicted))
    half_step = dt / 2.
    return tuple(half_step * (a + b) for a, b in zip(slope_start, slope_end))
def step_func(self, func, t, dt, y):
    """Return the increment of one explicit midpoint step.

    A half Euler step gives the midpoint state. The increment is `dt` times
    the slope evaluated there at `t + dt / 2`.
    """
    slope_here = cast_double(func(t, y))
    midpoint_state = []
    for state, slope in zip(y, slope_here):
        midpoint_state.append(state + slope * dt / 2)
    slope_mid = cast_double(func(t + dt / 2, tuple(midpoint_state)))
    return tuple(dt * s for s in slope_mid)
def grad(*grad_output, variables=None):
    """Debug variant of the adjoint-method gradient function.

    Integrates the augmented (state + adjoint) system backwards between
    consecutive output times and prints intermediate results. `ans` (the
    forward solution) and `t` (the output times) are closure variables of the
    enclosing scope. Solver settings are read from the module-level
    `_arguments`.

    Args:
        *grad_output: Upstream gradients, one per state component, each
            indexed along its first axis by output time.
        variables: The trainable variables `func` depends on.

    Returns:
        `(adj_y, model_vars)`: the state adjoints and one gradient per entry of
        `variables`, reshaped to that variable's shape.
    """
    global _arguments
    # Previously only assigned in a commented-out line; derive it from
    # `variables`.
    flat_params = _flatten(variables)
    func = _arguments.func
    method = _arguments.method
    options = _arguments.options
    rtol = _arguments.rtol
    atol = _arguments.atol

    # Debug output.
    print("Gradient Output : ", grad_output)
    print("Variables : ", variables)

    n_tensors = len(ans)
    f_params = tuple(variables)

    # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
    def augmented_dynamics(t, y_aug):
        # Dynamics of the original system augmented with
        # the adjoint wrt y, and an integrator wrt t and args.
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

        with tf.GradientTape() as tape:
            tape.watch(t)
            tape.watch(y)
            func_eval = func(t, y)
            func_eval = cast_double(func_eval)

        # Weight the Jacobian by -adj_y so tape.gradient yields vector-Jacobian products.
        gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
        if len(gradys.shape) < len(func_eval.shape):
            gradys = tf.expand_dims(gradys, axis=0)
        vjp_t, *vjp_y_and_params = tape.gradient(
            func_eval, (t, ) + y + f_params, output_gradients=gradys)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_params = vjp_y_and_params[n_tensors:]

        # tape.gradient returns None for unconnected sources; substitute zeros.
        vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
        vjp_y = tuple(
            tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
            for vjp_y_, y_ in zip(vjp_y, y))
        vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

        if _check_len(f_params) == 0:
            vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
            vjp_params = move_to_device(vjp_params, vjp_y[0].device)

        return (*func_eval, *vjp_y, vjp_t, vjp_params)

    T = ans[0].shape[0]
    # `*grad_output` always binds a tuple (one entry per state component).
    adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
    adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
    adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
    time_vjps = []
    for i in range(T - 1, 0, -1):
        ans_i = tuple(ans_[i] for ans_ in ans)
        grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

        func_i = func(t[i], ans_i)
        func_i = cast_double(func_i)
        if not isinstance(func_i, Iterable):
            func_i = [func_i]

        # Compute the effect of moving the current time measurement point.
        dLd_cur_t = sum(
            tf.reshape(
                tf.matmul(tf.reshape(func_i_, [1, -1]), tf.reshape(grad_output_i_, [-1, 1])),
                [1])
            for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
        adj_time = cast_double(adj_time)
        adj_time = adj_time - dLd_cur_t
        time_vjps.append(dLd_cur_t)

        # Run the augmented system backwards in time.
        if isinstance(adj_params, Iterable):
            count = _numel(adj_params)
            if count == 0:
                adj_params = move_to_device(
                    tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)

        aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
        aug_ans = odeint(
            augmented_dynamics,
            aug_y0,
            tf.convert_to_tensor([t[i], t[i - 1]]),
            rtol=rtol, atol=atol, method=method, options=options)

        # Unpack aug_ans.
        adj_y = aug_ans[n_tensors:2 * n_tensors]
        adj_time = aug_ans[2 * n_tensors]
        adj_params = aug_ans[2 * n_tensors + 1]

        # Keep the values at t[i - 1] (index 1 of the two-point solve).
        adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
        if _check_len(adj_time) > 0:
            adj_time = adj_time[1]
        if _check_len(adj_params) > 0:
            adj_params = adj_params[1]

        adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

        del aug_y0, aug_ans

    time_vjps.append(adj_time)
    time_vjps = tf.concat(time_vjps[::-1], 0)

    # Debug output.
    print()
    print('adj y', len(adj_y))
    print('time vjps', time_vjps.shape)
    print('adj params', adj_params.shape)
    print()

    # reshape the parameters back into the correct variable shapes
    var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
    var_shapes = [v.shape for v in variables]

    adj_params_splits = tf.split(adj_params, var_flat_lens)
    adj_params_list = [
        tf.reshape(p, v_shape) for p, v_shape in zip(adj_params_splits, var_shapes)
    ]

    model_vars = list(adj_params_list)
    grad_vars = model_vars

    # Debug output.
    print('model grad', len(model_vars))
    print('model grad values', [v for v in grad_vars])
    print()

    return (adj_y, model_vars)
def grad(*grad_output, variables=None):
    """Gradient function for the adjoint-method ODE solve.

    Integrates the augmented (state + adjoint) system backwards between
    consecutive output times. This accumulates adjoints of the state, of the
    output times and of the trainable `variables`. `ans` (the forward
    solution) and `t` (the output times) are closure variables of the
    enclosing scope. Solver settings are read from the module-level
    `_arguments`.

    Args:
        *grad_output: Upstream gradients, one per state component, each
            indexed along its first axis by output time.
        variables: The trainable variables `func` depends on.

    Returns:
        `((*adj_y, time_vjps), adj_params_list)`: the state adjoints at the
        first output time, the gradient with respect to every output time,
        and one gradient per entry of `variables` reshaped to that variable's
        shape.
    """
    global _arguments
    flat_params = _flatten(variables)
    func = _arguments.func
    method = _arguments.method
    options = _arguments.options
    rtol = _arguments.rtol
    atol = _arguments.atol

    n_tensors = len(ans)
    f_params = tuple(variables)

    # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
    def augmented_dynamics(t, y_aug):
        # Dynamics of the original system augmented with
        # the adjoint wrt y, and an integrator wrt t and args.
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

        with tf.GradientTape() as tape:
            tape.watch(t)
            tape.watch(y)
            func_eval = func(t, y)
            func_eval = cast_double(func_eval)

        # Weight the Jacobian by -adj_y so tape.gradient yields vector-Jacobian products.
        gradys = tf.stack(list(-adj_y_ for adj_y_ in adj_y))
        if len(gradys.shape) < len(func_eval.shape):
            gradys = tf.expand_dims(gradys, axis=0)
        vjp_t, *vjp_y_and_params = tape.gradient(
            func_eval, (t, ) + y + f_params, output_gradients=gradys)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_params = vjp_y_and_params[n_tensors:]

        # tape.gradient returns None for unconnected sources; substitute zeros.
        vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
        vjp_y = tuple(
            tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
            for vjp_y_, y_ in zip(vjp_y, y))
        vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

        if _check_len(f_params) == 0:
            # No parameters: emit a scalar zero with the state's dtype and device.
            vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
            vjp_params = move_to_device(vjp_params, vjp_y[0].device)

        return (*func_eval, *vjp_y, vjp_t, vjp_params)

    T = ans[0].shape[0]
    # `*grad_output` always binds a tuple (one entry per state component), so
    # no single-tensor special case is needed.
    adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
    adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
    adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
    time_vjps = []
    for i in range(T - 1, 0, -1):
        ans_i = tuple(ans_[i] for ans_ in ans)
        grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

        func_i = func(t[i], ans_i)
        func_i = cast_double(func_i)
        if not isinstance(func_i, Iterable):
            func_i = [func_i]

        # Compute the effect of moving the current time measurement point.
        dLd_cur_t = sum(
            tf.reshape(
                tf.matmul(tf.reshape(func_i_, [1, -1]), tf.reshape(grad_output_i_, [-1, 1])),
                [1])
            for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
        adj_time = cast_double(adj_time)
        adj_time = adj_time - dLd_cur_t
        time_vjps.append(dLd_cur_t)

        # Run the augmented system backwards in time.
        if isinstance(adj_params, Iterable):
            count = _numel(adj_params)
            if count == 0:
                adj_params = move_to_device(
                    tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)

        aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
        aug_ans = odeint(
            augmented_dynamics,
            aug_y0,
            tf.convert_to_tensor([t[i], t[i - 1]]),
            rtol=rtol, atol=atol, method=method, options=options)

        # Unpack aug_ans.
        adj_y = aug_ans[n_tensors:2 * n_tensors]
        adj_time = aug_ans[2 * n_tensors]
        adj_params = aug_ans[2 * n_tensors + 1]

        # Keep the values at t[i - 1] (index 1 of the two-point solve).
        adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
        if _check_len(adj_time) > 0:
            adj_time = adj_time[1]
        if _check_len(adj_params) > 0:
            adj_params = adj_params[1]

        adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

        del aug_y0, aug_ans

    time_vjps.append(adj_time)
    time_vjps = tf.concat(time_vjps[::-1], 0)

    # reshape the parameters back into the correct variable shapes
    var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
    var_shapes = [v.shape for v in variables]

    adj_params_splits = tf.split(adj_params, var_flat_lens)
    adj_params_list = [
        tf.reshape(p, v_shape) for p, v_shape in zip(adj_params_splits, var_shapes)
    ]
    return (*adj_y, time_vjps), adj_params_list