def g_and_explicit_phi(prev_t, next_t, implicit_phi, k):
    """Compute variable-step Adams coefficients ``g`` and the explicit phi terms.

    Args:
        prev_t: indexable sequence of previous time points. ``prev_t[0]`` is
            treated as the current time; ``prev_t[1] .. prev_t[k - 1]`` are read
            as progressively older points.
        next_t: scalar Tensor, the time being stepped to.
        implicit_phi: indexable sequence whose entry ``j`` is a tuple of
            Tensors (one per state tensor); entries ``0 .. k - 1`` are read.
        k: current method order.

    Returns:
        ``(g, explicit_phi)``:
          - ``g``: a non-trainable ``tf.Variable`` of length ``k + 1`` with the
            dtype of ``prev_t[0]``.
          - ``explicit_phi``: a deque (``maxlen=k``) whose entry 0 is
            ``implicit_phi[0]`` unchanged and whose entry ``j >= 1`` is
            ``implicit_phi[j]`` scaled by the running product ``beta``.
    """
    curr_t = prev_t[0]
    dt = next_t - prev_t[0]

    with tf.device(prev_t[0].device):
        # Build g in the time dtype rather than TF's float32 default, so the
        # coefficients keep full precision until the caller casts them.
        g = tf.Variable(tf.zeros([k + 1], dtype=prev_t[0].dtype), trainable=False)
    explicit_phi = collections.deque(maxlen=k)
    beta = move_to_device(tf.convert_to_tensor(1.), prev_t[0].device)

    compat.assign(g[0], 1)
    # c starts as [1, 1/2, ..., 1/(k + 1)] and shrinks by one entry per update.
    c = 1 / move_to_device(tf.range(1, k + 2), prev_t[0].device)
    explicit_phi.append(implicit_phi[0])

    beta = tf.cast(beta, next_t.dtype)
    for j in range(1, k):
        beta = (next_t - prev_t[j - 1]) / (curr_t - prev_t[j]) * beta
        beta_cast = move_to_device(beta, implicit_phi[j][0].device)
        beta_cast = tf.cast(beta_cast, implicit_phi[0][0].dtype)
        explicit_phi.append(tuple(iphi_ * beta_cast for iphi_ in implicit_phi[j]))

        c = c[:-1] - c[1:] if j == 1 else c[:-1] - c[1:] * dt / (next_t - prev_t[j - 1])
        compat.assign(g[j], tf.cast(c[0], g[j].dtype))

    c = c[:-1] - c[1:] * dt / (next_t - prev_t[k - 1])
    compat.assign(g[k], tf.cast(c[0], g[k].dtype))

    return g, explicit_phi
def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate the fitted interpolating polynomial at time ``t``.

    Args:
      coefficients: list of per-power tuples of Tensor coefficients, as built
        by `interp_fit`. Their dtype/device determine those of the result.
      t0: scalar giving the start of the interval.
      t1: scalar giving the end of the interval.
      t: scalar giving the point to evaluate; must satisfy ``t0 <= t <= t1``.

    Returns:
      Tuple of Tensors, one per state tensor: the interpolant evaluated at `t`.
    """
    reference = coefficients[0][0]
    dtype, device = reference.dtype, reference.device

    t0, t1, t = (_convert_to_tensor(v, dtype=dtype, device=device) for v in (t0, t1, t))

    assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)

    # Normalised position of t inside [t0, t1].
    x = move_to_device(tf.cast((t - t0) / (t1 - t0), dtype), device)

    # powers = [1, x, x**2, ...], at least two entries long.
    powers = [move_to_device(tf.convert_to_tensor(1, dtype=dtype), device), x]
    while len(powers) < len(coefficients):
        powers.append(powers[-1] * x)

    # The coefficients are ordered from highest to lowest power.
    descending = powers[::-1]
    return tuple(_dot_product(per_state, descending) for per_state in zip(*coefficients))
def _linear_interp(self, t0, t1, y0, y1, t): if t == t0: return y0 if t == t1: return y1 t0 = move_to_device(t0, y0[0].device) t1 = move_to_device(t1, y0[0].device) t = move_to_device(t, y0[0].device) slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1)) return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope))
def before_integrate(self, t):
    """Seed the variable-coefficient Adams state before integrating over ``t``.

    Evaluates ``self.func`` once at ``t[0]``. That value seeds both history
    deques and the phi deque, and is also passed to ``_select_initial_step``
    (order 2, first rtol/atol entries). The result is stored in
    ``self.vcabm_state`` with order 1.
    """
    history_len = self.max_order + 1
    prev_f = collections.deque(maxlen=history_len)
    prev_t = collections.deque(maxlen=history_len)
    phi = collections.deque(maxlen=self.max_order)

    start = t[0]
    f_start = self.func(tf.cast(start, self.y0[0].dtype), self.y0)
    for history, value in ((prev_t, start), (prev_f, f_start), (phi, f_start)):
        history.appendleft(value)

    step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f_start)
    step = tf.cast(move_to_device(step, t.device), t[0].dtype)

    self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + step, phi=phi, order=1)
def augmented_dynamics(t, y_aug):
    """Right-hand side of the adjoint system, integrated backwards in time.

    Reads ``n_tensors``, ``func`` and ``f_params`` from the enclosing scope
    (not visible here).

    The augmented state is ``(y..., adj_y..., adj_time, adj_params)``. Only
    ``y`` and ``adj_y`` feed the dynamics; the last two entries are
    accumulators.

    Returns:
        ``(*f(t, y), *vjp_y, vjp_t, vjp_params)`` where every VJP is taken
        against the cotangent ``-adj_y``.
    """
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

    with tf.GradientTape() as tape:
        tape.watch(t)
        tape.watch(y)
        func_eval = func(t, y)
        # Converting inside the tape keeps the packing op differentiable.
        func_eval = tf.convert_to_tensor(func_eval)

    gradys = -tf.stack(adj_y)
    if len(gradys.shape) < len(func_eval.shape):
        gradys = tf.expand_dims(gradys, axis=0)

    vjp_t, *vjp_y_and_params = tape.gradient(
        func_eval, (t, ) + y + f_params,
        output_gradients=gradys,
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    vjp_y = vjp_y_and_params[:n_tensors]
    vjp_params = _flatten(vjp_y_and_params[n_tensors:])

    if _check_len(f_params) == 0:
        # No parameters: emit a scalar zero placeholder in the state's dtype.
        vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
        vjp_params = move_to_device(vjp_params, vjp_y[0].device)

    return (*func_eval, *vjp_y, vjp_t, vjp_params)
def augmented_dynamics(t, y_aug):
    """Right-hand side of the adjoint system, integrated backwards in time.

    Reads ``n_tensors``, ``func`` and ``f_params`` from the enclosing scope
    (not visible here).

    The augmented state is ``(y..., adj_y..., adj_time, adj_params)``. Only
    ``y`` and ``adj_y`` feed the dynamics.

    Returns:
        ``(*f(t, y), *vjp_y, vjp_t, vjp_params)``. Gradients that
        ``tape.gradient`` reports as ``None`` are replaced with zeros.
    """
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

    with tf.GradientTape() as tape:
        tape.watch(t)
        tape.watch(y)
        func_eval = func(t, y)
        func_eval = cast_double(func_eval)

    gradys = tf.stack([-adj_y_ for adj_y_ in adj_y])
    if len(gradys.shape) < len(func_eval.shape):
        gradys = tf.expand_dims(gradys, axis=0)

    vjp_t, *vjp_y_and_params = tape.gradient(
        func_eval, (t, ) + y + f_params, output_gradients=gradys)
    vjp_y = vjp_y_and_params[:n_tensors]
    vjp_params = vjp_y_and_params[n_tensors:]

    # tape.gradient returns None for unconnected sources; substitute zeros.
    vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
    vjp_y = tuple(
        tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
        for vjp_y_, y_ in zip(vjp_y, y))
    vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

    if _check_len(f_params) == 0:
        # No parameters: emit a scalar zero placeholder in the state's dtype.
        vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
        vjp_params = move_to_device(vjp_params, vjp_y[0].device)

    return (*func_eval, *vjp_y, vjp_t, vjp_params)
def before_integrate(self, t):
    """Initialise the Runge-Kutta state before integrating over ``t``.

    The derivative at ``t[0]`` is evaluated once (with the time cast to the
    state dtype) and reused for the initial-step heuristic. A user-supplied
    ``self.first_step`` takes precedence over that heuristic. The interpolation
    coefficients start as five copies of ``self.y0``.
    """
    f0 = self.func(tf.cast(t[0], self.y0[0].dtype), self.y0)

    if self.first_step is not None:
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    else:
        estimated = _select_initial_step(self.func, t[0], self.y0, 1, self.rtol[0], self.atol[0], f0=f0)
        first_step = move_to_device(tf.cast(estimated, t.dtype), t.device)

    self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5)
def integrate(self, t):
    """Solve the ODE and return the state at every time in ``t``.

    ``t`` must be increasing. It is cast to float64 and moved to the device of
    ``self.y0[0]``. ``self.advance(t[i])`` is called for each ``i >= 1``
    (presumably advancing the solver to that time; defined elsewhere).

    Returns:
        Tuple with one Tensor per state tensor, each stacked along a new
        leading time axis (``self.y0`` first).
    """
    _assert_increasing(t)
    solution = [self.y0]
    t = move_to_device(tf.cast(t, tf.float64), self.y0[0].device)
    self.before_integrate(t)

    solution.extend(self.advance(t[i]) for i in range(1, t.shape[0]))

    per_tensor_series = tuple(zip(*solution))
    return tuple(tf.stack(series) for series in per_tensor_series)
def _grid_constructor(func, y0, t):
    """Build a uniform time grid from ``t[0]`` with spacing ``step_size``.

    ``step_size`` is read from the enclosing scope (not visible here).
    ``func`` and ``y0`` are unused; they keep the signature shared with other
    grid constructors.

    Returns:
        1-D Tensor of ``t.dtype`` starting at ``t[0]``. The last entry is
        clamped so it never exceeds ``t[-1]``.
    """
    start_time = t[0]
    end_time = t[-1]

    # Use tf.math.ceil (tf.ceil is absent in TF2). Keep the count as a tensor:
    # EagerTensor has no ``.item()``.
    niters = tf.cast(tf.math.ceil((end_time - start_time) / step_size + 1), tf.int64)
    t_infer = tf.cast(tf.range(niters), t.dtype) * step_size + start_time
    t_infer = move_to_device(t_infer, t.device)

    # Tensors are immutable, so rebuild the tail instead of item-assigning:
    # an overshooting last point is replaced by t[-1].
    t_infer = tf.concat([t_infer[:-1], tf.minimum(t_infer[-1:], t[-1:])], axis=0)
    return t_infer
def before_integrate(self, t):
    """Initialise the Runge-Kutta state before integrating over ``t``.

    The derivative at ``t[0]`` is evaluated once. It is reused both for the
    initial-step heuristic (order 4, full rtol/atol) and as the ``f0`` of the
    initial state; previously the function was evaluated twice. A user-supplied
    ``self.first_step`` takes precedence over the heuristic. The interpolation
    coefficients start as seven copies of ``self.y0``.
    """
    f0 = self.func(t[0], self.y0)
    if self.first_step is None:
        first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol, f0=f0)
        first_step = move_to_device(tf.cast(first_step, t.dtype), t.device)
    else:
        first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
    self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, [self.y0] * 7)
def augmented_dynamics(t, y_aug):
    """Right-hand side of the adjoint system, integrated backwards in time.

    Reads ``n_tensors``, ``func`` and ``f_params`` from the enclosing scope
    (not visible here).

    The augmented state is ``(y..., adj_y..., adj_time, adj_params)``. Only
    ``y`` and ``adj_y`` feed the dynamics.

    Returns:
        ``(*f(t, y), *vjp_y, vjp_t, vjp_params)``, with every VJP taken
        against the cotangent ``-adj_y``.
    """
    y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

    with tf.GradientTape() as tape:
        tape.watch(t)
        tape.watch(y)
        func_eval = func(t, y)
        func_eval = cast_double(func_eval)

    # The adjoint needs -adj_y as the output cotangent. Without
    # output_gradients, tape.gradient uses ones and ignores adj_y entirely.
    # Assumes each adj_y_ has the shape of the matching f(t, y) entry
    # (TODO confirm).
    flat_evals = func_eval if isinstance(func_eval, (tuple, list)) else (func_eval, )
    output_gradients = [-tf.cast(adj_y_, f_.dtype) for adj_y_, f_ in zip(adj_y, flat_evals)]

    vjp_t, *vjp_y_and_params = tape.gradient(
        func_eval, (t, ) + y + f_params, output_gradients=output_gradients)
    vjp_y = vjp_y_and_params[:n_tensors]
    vjp_params = vjp_y_and_params[n_tensors:]

    # tape.gradient returns None for unconnected sources; substitute zeros.
    vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
    vjp_y = tuple(
        tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
        for vjp_y_, y_ in zip(vjp_y, y))
    vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

    if _check_len(f_params) == 0:
        # No parameters: emit a scalar zero placeholder in the state's dtype.
        vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
        vjp_params = move_to_device(vjp_params, vjp_y[0].device)

    return (*func_eval, *vjp_y, vjp_t, vjp_params)
def integrate(self, t):
    """Integrate on a fixed grid and linearly interpolate to the times in ``t``.

    The grid comes from ``self.grid_constructor`` and must start at ``t[0]``
    and end at ``t[-1]``. Each grid interval is advanced with
    ``self.step_func``. Every requested time falling in that interval is then
    interpolated between the states at BOTH ends of the interval.

    Returns:
        Tuple with one Tensor per state tensor, stacked along a new leading
        time axis (the initial state first).
    """
    _assert_increasing(t)
    t = tf.cast(t, self.y0[0].dtype)
    time_grid = self.grid_constructor(self.func, self.y0, t)
    assert tf.equal(time_grid[0], t[0]) and tf.equal(time_grid[-1], t[-1])
    time_grid = move_to_device(time_grid, self.y0[0].device)

    y0 = cast_double(self.y0)
    solution = [y0]

    j = 1
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        dy = self.step_func(self.func, t0, t1 - t0, y0)
        y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy))

        # Interpolate with the state at the start of the step (y0) and the
        # end (y1). Only advance y0 afterwards; advancing first would make
        # every intermediate output equal to y1.
        while j < t.shape[0] and t1 >= t[j]:
            solution.append(self._linear_interp(t0, t1, y0, y1, t[j]))
            j += 1

        y0 = y1

    return tuple(map(tf.stack, tuple(zip(*solution))))
def grad(*grad_output, variables=None):
    """Adjoint-method backward pass.

    Reads ``ans`` (tuple of solution Tensors indexed by time along axis 0),
    ``t`` (the time grid) and ``odeint`` from the enclosing scope (not visible
    here). Solver settings come from the module global ``_arguments``.

    Args:
        *grad_output: upstream gradients, one per entry of ``ans``, indexed by
            time along axis 0.
        variables: variables of ``func`` whose gradients are returned.

    Returns:
        ``((*adj_y, time_vjps), adj_params_list)``:
          - ``adj_y``: gradients w.r.t. the initial state.
          - ``time_vjps``: one entry per time point.
          - ``adj_params_list``: one gradient per variable, in its original shape.
    """
    global _arguments
    flat_params = _flatten(variables)

    func = _arguments.func
    adjoint_method = _arguments.adjoint_method
    adjoint_rtol = _arguments.rtol
    adjoint_atol = _arguments.atol
    adjoint_options = _arguments.adjoint_options

    n_tensors = len(ans)
    f_params = tuple(variables)

    # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
    def augmented_dynamics(t, y_aug):
        # Dynamics of the original system augmented with the adjoint wrt y,
        # and an integrator wrt t and args. adj_time and adj_params are
        # accumulators and do not feed back into the dynamics.
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

        with tf.GradientTape() as tape:
            tape.watch(t)
            tape.watch(y)
            func_eval = func(t, y)
            func_eval = tf.convert_to_tensor(func_eval)

        # func_eval is always a single Tensor here, so only the rank of the
        # cotangent needs reconciling.
        gradys = -tf.stack(adj_y)
        if len(gradys.shape) < len(func_eval.shape):
            gradys = tf.expand_dims(gradys, axis=0)

        vjp_t, *vjp_y_and_params = tape.gradient(
            func_eval, (t, ) + y + f_params,
            output_gradients=gradys,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_params = _flatten(vjp_y_and_params[n_tensors:])

        if _check_len(f_params) == 0:
            # No parameters: emit a scalar zero placeholder in the state's dtype.
            vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
            vjp_params = move_to_device(vjp_params, vjp_y[0].device)

        return (*func_eval, *vjp_y, vjp_t, vjp_params)

    T = ans[0].shape[0]
    # grad_output always arrives as a tuple because of *grad_output.
    adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
    adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
    adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
    time_vjps = []

    for i in range(T - 1, 0, -1):
        ans_i = tuple(ans_[i] for ans_ in ans)
        grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
        func_i = func(t[i], ans_i)
        if not isinstance(func_i, Iterable):
            func_i = [func_i]

        # Compute the effect of moving the current time measurement point.
        dLd_cur_t = sum(
            tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]), tf.reshape(grad_output_i_, [-1, 1])), [1])
            for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
        adj_time = adj_time - dLd_cur_t
        time_vjps.append(dLd_cur_t)

        # Run the augmented system backwards in time.
        if isinstance(adj_params, Iterable) and _numel(adj_params) == 0:
            adj_params = move_to_device(tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)
        aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
        aug_ans = odeint(
            augmented_dynamics, aug_y0, tf.convert_to_tensor([t[i], t[i - 1]]),
            rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options)

        # Unpack aug_ans, keeping the value at the end of the backward solve.
        adj_y = aug_ans[n_tensors:2 * n_tensors]
        adj_time = aug_ans[2 * n_tensors]
        adj_params = aug_ans[2 * n_tensors + 1]

        adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
        if _check_len(adj_time) > 0:
            adj_time = adj_time[1]
        if _check_len(adj_params) > 0:
            adj_params = adj_params[1]

        adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

        del aug_y0, aug_ans

    time_vjps.append(adj_time)
    time_vjps = tf.concat(time_vjps[::-1], 0)

    # Reshape the flat parameter gradient back into the variable shapes.
    var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
    var_shapes = [v.shape for v in variables]
    adj_params_splits = tf.split(adj_params, var_flat_lens)
    adj_params_list = [tf.reshape(p, v_shape) for p, v_shape in zip(adj_params_splits, var_shapes)]
    return (*adj_y, time_vjps), adj_params_list
def _adaptive_adams_step(self, vcabm_state, final_t):
    """Attempt one variable-coefficient Adams predictor-corrector step.

    The target time is clamped to ``final_t``. An explicit predictor
    ``p_next`` is followed by an implicit correction ``y_next``, and the step
    is accepted when every error ratio is <= 1.

    Returns:
        A new ``_VCABMState``:
          - Rejected step: same history, with a shrunken ``next_t``.
          - Accepted step: updated history, a possibly changed order, and the
            next target time.
    """
    y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
    if next_t > final_t:
        next_t = final_t
    dt = (next_t - prev_t[0])
    dt_cast = move_to_device(dt, y0[0].device)
    dt_cast = tf.cast(dt_cast, y0[0].dtype)

    # Explicit predictor step.
    g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
    g = tf.cast(g, dt_cast.dtype)
    p_next = tuple(
        y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)])
        for y0_, phi_ in zip(y0, tuple(zip(*phi))))

    # Update phi to implicit.
    next_t = move_to_device(next_t, p_next[0].device)
    next_f0 = self.func(tf.cast(next_t, self.y0[0].dtype), p_next)
    implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)

    # Implicit corrector step.
    y_next = tuple(
        p_next_ + dt_cast * g[order - 1] * tf.cast(iphi_, dt_cast.dtype)
        for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]))

    # Error estimation. The tolerance is elementwise: atol + rtol * max(|y0|, |y_next|).
    # tf.reduce_max over the stacked pair would collapse it to one global scalar.
    tolerance = tuple(
        atol_ + rtol_ * tf.maximum(tf.abs(y0_), tf.abs(y1_))
        for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next))
    local_error = tuple(
        dt_cast * (g[order] - g[order - 1]) * tf.cast(iphi_, dt_cast.dtype)
        for iphi_ in implicit_phi_p[order])
    error_k = _compute_error_ratio(local_error, tolerance)
    accept_step = tf.reduce_all((tf.convert_to_tensor(error_k) <= 1))

    if not accept_step:
        # Retry with adjusted step size if step is rejected.
        dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order)
        return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order)

    # We accept the step. Evaluate f and update phi.
    next_t = move_to_device(next_t, p_next[0].device)
    next_f0 = self.func(tf.cast(next_t, self.y0[0].dtype), y_next)
    implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)

    next_order = order

    if len(prev_t) <= 4 or order < 3:
        next_order = min(order + 1, 3, self.max_order)
    else:
        error_km1 = _compute_error_ratio(
            tuple(dt_cast * (g[order - 1] - g[order - 2]) * tf.cast(iphi_, dt_cast.dtype)
                  for iphi_ in implicit_phi_p[order - 1]), tolerance)
        error_km2 = _compute_error_ratio(
            tuple(dt_cast * (g[order - 2] - g[order - 3]) * tf.cast(iphi_, dt_cast.dtype)
                  for iphi_ in implicit_phi_p[order - 2]), tolerance)
        if min(error_km1 + error_km2) < max(error_k):
            next_order = order - 1
        elif order < self.max_order:
            error_kp1 = _compute_error_ratio(
                tuple(dt_cast * gamma_star[order] * tf.cast(iphi_, dt_cast.dtype)
                      for iphi_ in implicit_phi_p[order]), tolerance)
            if max(error_kp1) < max(error_k):
                next_order = order + 1

    # Keep step size constant if increasing order. Else use adaptive step size.
    dt_next = dt if next_order > order else _optimal_step_size(
        dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1)

    prev_f.appendleft(next_f0)
    prev_t.appendleft(next_t)
    # NOTE(review): the predictor p_next (not the corrected y_next) is carried
    # forward as the new state, mirroring the reference implementation; confirm
    # this is intended.
    return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order)
def grad(*grad_output, variables=None):
    """Adjoint-method backward pass returning state and variable gradients.

    Reads ``ans`` (tuple of solution Tensors indexed by time along axis 0),
    ``t`` (the time grid), ``flat_params`` and ``odeint`` from the enclosing
    scope (not visible here). Solver settings come from the module global
    ``_arguments``.

    Args:
        *grad_output: upstream gradients, one per entry of ``ans``, indexed by
            time along axis 0.
        variables: variables of ``func`` whose gradients are returned.

    Returns:
        ``(adj_y, model_vars)``:
          - ``adj_y``: tuple of gradients w.r.t. the initial state.
          - ``model_vars``: list of per-variable gradients in the variables'
            original shapes.
        The time gradient is accumulated in ``adj_time`` during the backward
        solve but is not returned by this variant.
    """
    global _arguments
    func = _arguments.func
    method = _arguments.method
    options = _arguments.options
    rtol = _arguments.rtol
    atol = _arguments.atol

    n_tensors = len(ans)
    f_params = tuple(variables)

    # TODO: use a tf.keras.Model and call odeint_adjoint to implement higher order derivatives.
    def augmented_dynamics(t, y_aug):
        # Dynamics of the original system augmented with the adjoint wrt y,
        # and an integrator wrt t and args. adj_time and adj_params are
        # accumulators and do not feed back into the dynamics.
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]

        with tf.GradientTape() as tape:
            tape.watch(t)
            tape.watch(y)
            func_eval = func(t, y)
            func_eval = cast_double(func_eval)

        # The VJP must use -adj_y as the output cotangent; without it the
        # adjoint ignores adj_y. Assumes each adj_y_ has the shape of the
        # matching f(t, y) entry (TODO confirm).
        flat_evals = func_eval if isinstance(func_eval, (tuple, list)) else (func_eval, )
        output_gradients = [-tf.cast(adj_y_, f_.dtype) for adj_y_, f_ in zip(adj_y, flat_evals)]

        vjp_t, *vjp_y_and_params = tape.gradient(
            func_eval, (t, ) + y + f_params, output_gradients=output_gradients)
        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_params = vjp_y_and_params[n_tensors:]

        # tape.gradient returns None for unconnected sources; substitute zeros.
        vjp_t = tf.zeros_like(t, dtype=t.dtype) if vjp_t is None else vjp_t
        vjp_y = tuple(
            tf.zeros_like(y_, dtype=y_.dtype) if vjp_y_ is None else vjp_y_
            for vjp_y_, y_ in zip(vjp_y, y))
        vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

        if _check_len(f_params) == 0:
            # No parameters: emit a scalar zero placeholder in the state's dtype.
            vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
            vjp_params = move_to_device(vjp_params, vjp_y[0].device)

        return (*func_eval, *vjp_y, vjp_t, vjp_params)

    T = ans[0].shape[0]
    # grad_output always arrives as a tuple because of *grad_output.
    adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
    adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
    adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)

    for i in range(T - 1, 0, -1):
        ans_i = tuple(ans_[i] for ans_ in ans)
        grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
        func_i = func(t[i], ans_i)
        func_i = cast_double(func_i)
        if not isinstance(func_i, Iterable):
            func_i = [func_i]

        # Compute the effect of moving the current time measurement point.
        dLd_cur_t = sum(
            tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]), tf.reshape(grad_output_i_, [-1, 1])), [1])
            for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
        adj_time = cast_double(adj_time)
        adj_time = adj_time - dLd_cur_t

        # Run the augmented system backwards in time.
        if isinstance(adj_params, Iterable) and _numel(adj_params) == 0:
            adj_params = move_to_device(tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)
        aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
        aug_ans = odeint(
            augmented_dynamics, aug_y0, tf.convert_to_tensor([t[i], t[i - 1]]),
            rtol=rtol, atol=atol, method=method, options=options)

        # Unpack aug_ans, keeping the value at the end of the backward solve.
        adj_y = aug_ans[n_tensors:2 * n_tensors]
        adj_time = aug_ans[2 * n_tensors]
        adj_params = aug_ans[2 * n_tensors + 1]

        adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
        if _check_len(adj_time) > 0:
            adj_time = adj_time[1]
        if _check_len(adj_params) > 0:
            adj_params = adj_params[1]

        adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

        del aug_y0, aug_ans

    # Reshape the flat parameter gradient back into the variable shapes.
    var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
    var_shapes = [v.shape for v in variables]
    adj_params_splits = tf.split(adj_params, var_flat_lens)
    model_vars = [tf.reshape(p, v_shape) for p, v_shape in zip(adj_params_splits, var_shapes)]

    return (adj_y, model_vars)