def main():
    sol = dict()
    for method in ['dopri5', 'adams']:
        for tol in [1e-3, 1e-6, 1e-9]:
            print('======= {} | tol={:e} ======='.format(method, tol))
            nfes = []
            times = []
            errs = []
            for c in ['A', 'B', 'C', 'D', 'E']:
                for i in ['1', '2', '3', '4', '5']:
                    diffeq, init, _ = getattr(detest, c + i)()
                    t0, y0 = init()
                    diffeq = NFEDiffEq(diffeq)

                    if c + i not in sol:
                        sol[c + i] = odeint(
                            diffeq, y0,
                            tf.stack([t0, tf.convert_to_tensor(20., dtype=tf.float64)]),
                            atol=1e-12, rtol=1e-12, method='dopri5')[1]
                    diffeq.nfe = 0

                    start_time = time.time()
                    est = odeint(
                        diffeq, y0,
                        tf.stack([t0, tf.convert_to_tensor(20., dtype=tf.float64)]),
                        atol=tol, rtol=tol, method=method)
                    time_spent = time.time() - start_time

                    error = tf.sqrt(tf.reduce_mean((sol[c + i] - est[1]) ** 2))

                    errs.append(error.numpy())
                    nfes.append(diffeq.nfe)
                    times.append(time_spent)

                    print('{}: NFE {} | Time {} | Err {:e}'.format(
                        c + i, diffeq.nfe, time_spent, error.numpy()))

            print('Total NFE {} | Total Time {} | GeomAvg Error {:e}'.format(
                np.sum(nfes), np.sum(times), gmean(errs)))
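
# A minimal sketch of the `NFEDiffEq` wrapper used above -- the assumption
# here is that it only counts function evaluations before delegating to the
# wrapped system; the actual class may differ.
class NFEDiffEq:

    def __init__(self, diffeq):
        self.diffeq = diffeq
        self.nfe = 0  # number of function evaluations

    def __call__(self, t, y):
        self.nfe += 1
        return self.diffeq(t, y)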
def test_dopri5_adjoint_against_dopri5(self):
    tf.keras.backend.set_floatx('float64')
    tf.compat.v1.set_random_seed(0)
    with tf.GradientTape(persistent=True) as tape:
        func, y0, t_points = self.problem()
        tape.watch(t_points)
        tape.watch(y0)
        ys = tfdiffeq.odeint_adjoint(func, y0, t_points, method='dopri5')

    gradys = 0.1 * tf.random.uniform(shape=ys.shape, dtype=tf.float64)
    adj_y0_grad, adj_t_grad, adj_A_grad = tape.gradient(
        ys, [y0, t_points, func.A], output_gradients=gradys)

    w_grad, b_grad = tape.gradient(ys, func.unused_module.variables)
    self.assertIsNone(w_grad)
    self.assertIsNone(b_grad)

    with tf.GradientTape() as tape:
        func, y0, t_points = self.problem()
        tape.watch(y0)
        tape.watch(t_points)
        ys = tfdiffeq.odeint(func, y0, t_points, method='dopri5')

    y_grad, t_grad, a_grad = tape.gradient(
        ys, [y0, t_points, func.A], output_gradients=gradys)

    self.assertLess(max_abs(y_grad - adj_y0_grad), 3e-4)
    self.assertLess(max_abs(t_grad - adj_t_grad), 1e-4)
    self.assertLess(max_abs(a_grad - adj_A_grad), 2e-3)
def test_dopri5(self):
    for ode in problems.PROBLEMS.keys():
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
        y = tfdiffeq.odeint(f, y0, t_points, method='dopri5')
        with self.subTest(ode=ode):
            self.assertLess(rel_error(sol, y), error_tol)
def call(self, x, training=None, eval_times=None, **kwargs):
    """Solves the ODE starting from x.

    # Arguments:
        x: Tensor. Shape (batch_size, self.odefunc.data_dim)
        eval_times: None or tf.Tensor. If None, returns only the solution
            at the final time t=1; otherwise returns the trajectory
            evaluated at the points in eval_times.

    # Returns:
        Output tensor of the forward pass.
    """
    # The forward pass corresponds to solving the ODE, so reset the counter
    # of function evaluations.
    self.odefunc.nfe.assign(0.)

    if eval_times is None:
        integration_time = tf.cast(tf.linspace(0., 1., 2), dtype=x.dtype)
    else:
        integration_time = tf.cast(eval_times, x.dtype)

    if self.odefunc.augment_dim > 0:
        # Add augmentation: shape (batch_size, data_dim + augment_dim)
        aug = tf.zeros([x.shape[0], self.odefunc.augment_dim], dtype=x.dtype)
        x_aug = tf.concat([x, aug], axis=-1)
    else:
        x_aug = x

    out = odeint(self.odefunc, x_aug, integration_time,
                 rtol=self.tol, atol=self.tol,
                 method=self.method, options=self.options)

    if eval_times is None:
        return out[1]  # Return only the final time
    return out
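
# Standalone illustration of the zero-augmentation step above (the ANODE
# trick): extra zero channels give the ODE flow room to avoid trajectory
# crossings. The tensor shapes are stand-in values, not from the original.
import tensorflow as tf

x = tf.random.uniform([32, 2])                       # (batch_size, data_dim)
augment_dim = 3
aug = tf.zeros([x.shape[0], augment_dim], dtype=x.dtype)
x_aug = tf.concat([x, aug], axis=-1)                 # (batch_size, data_dim + augment_dim)
assert x_aug.shape == (32, 2 + augment_dim)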
def test_adjoint(self):
    """Test against dopri5."""
    tf.compat.v1.set_random_seed(0)
    f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
    y0 = tf.cast(y0, tf.float64)
    t_points = tf.cast(t_points, tf.float64)

    func = lambda y0, t_points: tfdiffeq.odeint(f, y0, t_points, method='dopri5')

    with tf.GradientTape() as tape:
        tape.watch(t_points)
        ys = func(y0, t_points)
    reg_t_grad, reg_a_grad, reg_b_grad = tape.gradient(ys, [t_points, f.a, f.b])

    f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
    y0 = tf.cast(y0, tf.float64)
    t_points = tf.cast(t_points, tf.float64)
    y0 = (y0,)

    func = lambda y0, t_points: tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')

    with tf.GradientTape() as tape:
        tape.watch(t_points)
        ys = func(y0, t_points)
    grads = tape.gradient(ys, [t_points, f.a, f.b])
    adj_t_grad, adj_a_grad, adj_b_grad = grads

    self.assertLess(max_abs(reg_t_grad - adj_t_grad), 1.2e-7)
    self.assertLess(max_abs(reg_a_grad - adj_a_grad), 1.2e-7)
    self.assertLess(max_abs(reg_b_grad - adj_b_grad), 1.2e-7)
def test_adams_gradient(self):
    f, y0, t_points, sol = construct_problem(TEST_DEVICE)

    tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))

    for i in range(2):
        func = lambda y0, t_points: tfdiffeq.odeint(
            tuple_f, (y0, y0), t_points, method='adams')[i]
        self.assertTrue(gradcheck(func, (y0, t_points)))
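
# `gradcheck` above is assumed to compare autodiff gradients against a
# finite-difference estimate; a minimal sketch of that idea for a single
# tensor input (not the real test helper, which also handles tuples):
def fd_gradcheck(func, x, eps=1e-6, tol=1e-4):
    v = tf.random.normal(x.shape, dtype=x.dtype)  # random probe direction
    with tf.GradientTape() as tape:
        tape.watch(x)
        out = tf.reduce_sum(func(x))
    analytic = tf.reduce_sum(tape.gradient(out, x) * v)
    numeric = (tf.reduce_sum(func(x + eps * v))
               - tf.reduce_sum(func(x - eps * v))) / (2. * eps)
    return abs(float(analytic - numeric)) < tol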
def call(self, x):
    # self.integration_time = tf.cast(self.integration_time, x.dtype)
    out = odeint(self.odefunc, x, self.integration_time,
                 rtol=args.tol, atol=args.tol, method=args.method)
    return tf.cast(out[1], tf.float32)  # necessary cast
def test_adams(self):
    f, y0, t_points, sol = construct_problem(TEST_DEVICE)

    tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
    tuple_y0 = (y0, y0)

    tuple_y = tfdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adams')
    max_error0 = tf.reduce_max(sol - tuple_y[0])
    max_error1 = tf.reduce_max(sol - tuple_y[1])
    self.assertLess(max_error0, eps)
    self.assertLess(max_error1, eps)
def test_adaptive_heun(self):
    for ode in problems.PROBLEMS.keys():
        if ode == 'sine':
            # The sine test never finishes.
            continue
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
        y = tfdiffeq.odeint(f, y0, t_points, method='adaptive_heun')
        with self.subTest(ode=ode):
            self.assertLess(rel_error(sol, y), error_tol)
def step(self, dt=0.01, n_steps=10, *args, **kwargs):
    """Steps the system forward by dt. Uses tfdiffeq.odeint for integration.

    # Arguments:
        dt: Float - time step
        n_steps: Int - number of sub-steps to return values for.
            The integrator may decide to use more steps to achieve the
            set tolerance.

    # Returns:
        x: tf.Tensor, shape=(n_steps, 8) - trajectory of the system over dt
    """
    t = tf.linspace(0., dt, n_steps)
    states = odeint(self, self.x, t, *args, **kwargs)
    # Keep only the final state so the next call starts from it.
    self.x = states[-1]
    return states
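
# Hedged usage sketch for the `step` method above -- `DampedSystem` is a
# hypothetical stand-in for a class that defines both `call(t, x)` and `step`:
#
#   system = DampedSystem()                  # holds its state in system.x
#   for _ in range(100):
#       states = system.step(dt=0.01, n_steps=10, method='dopri5')
#   # `states` holds the 10 sub-steps of the last interval; system.x is the
#   # final state, so consecutive calls chain correctly.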
def call(self, x, training=None, eval_times=None, **kwargs):
    """Solves the ODE starting from x.

    # Arguments:
        x: Tensor. Shape (batch_size, self.odefunc.data_dim)
        eval_times: None or tf.Tensor. If None, returns the solution of the
            ODE at the final time t=1. If a tf.Tensor, returns the full ODE
            trajectory evaluated at the points in eval_times.

    # Returns:
        Output tensor of the forward pass.
    """
    # The forward pass corresponds to solving the ODE, so reset the counter
    # of function evaluations.
    self.odefunc.nfe.assign(0.)

    if eval_times is None:
        integration_time = tf.cast(tf.linspace(0., 1., 2), dtype=x.dtype)
    else:
        integration_time = tf.cast(eval_times, x.dtype)

    if self.odefunc.augment_dim > 0:
        if self.is_conv:
            # Add augmentation along the channel axis.
            batch_size, height, width, channels = x.shape
            if self.channel_axis == 1:
                aug = tf.zeros([batch_size, self.odefunc.augment_dim,
                                height, width], dtype=x.dtype)
            else:
                aug = tf.zeros([batch_size, height, width,
                                self.odefunc.augment_dim], dtype=x.dtype)
            # Shape (batch_size, channels + augment_dim, height, width)
            x_aug = tf.concat([x, aug], axis=self.channel_axis)
        else:
            # Add augmentation: shape (batch_size, data_dim + augment_dim)
            aug = tf.zeros([x.shape[0], self.odefunc.augment_dim], dtype=x.dtype)
            x_aug = tf.concat([x, aug], axis=-1)
    else:
        x_aug = x

    out = odeint(self.odefunc, x_aug, integration_time,
                 rtol=self.tol, atol=self.tol,
                 method=self.method, options=self.options)

    if eval_times is None:
        return out[1]  # Return only the final time
    return out
def test_explicit_adams(self):
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE)
    y = tfdiffeq.odeint(f, y0, t_points, method='explicit_adams')
    self.assertLess(rel_error(sol, y), error_tol)
def test_dopri5(self):
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    y = tfdiffeq.odeint(f, y0, t_points[0:1], method='dopri5')
    self.assertLess(max_abs(sol[0] - y), error_tol)
latest_checkpoint = checkpoint_manager.restore_or_initialize()
if latest_checkpoint is not None:
    print('Loaded ckpt from {}'.format(ckpt_path))

if args.mode == 'train':
    try:
        for itr in range(1, args.niters + 1):
            with tf.GradientTape() as tape:
                x, logp_diff_t1 = get_batch(args.num_samples)

                z_t, logp_diff_t = odeint(
                    func,
                    (x, logp_diff_t1),
                    tf.convert_to_tensor([t1, t0], dtype=float_dtype),
                    atol=1e-5,
                    rtol=1e-5,
                    method='dopri5',
                )

                z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]

                # float32 is required for log_prob()
                z_t0 = tf.cast(z_t0, tf.float32)
                logp_diff_t0 = tf.cast(logp_diff_t0, tf.float32)

                logp_x = p_z0.log_prob(z_t0) - tf.reshape(logp_diff_t0, [-1])

                # Recast
                logp_x = tf.cast(logp_x, float_dtype)
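
# Background note on the block above (standard continuous-normalizing-flow
# objective, cf. FFJORD / Chen et al. 2018): the second state integrates the
# instantaneous change of variables,
#   d log p(z(t)) / dt = -tr(df/dz),
# so after integrating from t1 back to t0,
#   log p(x) = log p_z0(z(t0)) - logp_diff(t0),
# which is exactly the `logp_x` computed above.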
def test_adams(self):
    f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
    func = lambda y0, t_points: tfdiffeq.odeint(f, y0, t_points, method='adams')
    self.assertTrue(gradcheck(func, (y0, t_points)))
args = parser.parse_args()

device = 'gpu:' + str(args.gpu) if tf.test.is_gpu_available() else 'cpu:0'

true_y0 = tf.convert_to_tensor(1., dtype=tf.float64)
time = np.linspace(0., 1., num=args.data_size)
t = tf.convert_to_tensor(time, dtype=tf.float64)


def true_val(t, y0):
    # Closed-form solution of dy/dt = 5y - 3 with y(0) = y0.
    return (y0 - 0.6) * np.exp(5. * t) + 0.6


class Lambda(tf.keras.Model):

    def call(self, t, y):
        dydt = 5. * y - 3.
        return dydt


real_y = [true_val(t, true_y0.numpy()) for t in time]
pred_y = odeint(Lambda(), true_y0, t, method=args.method)

mse = np.mean(np.square(real_y - pred_y.numpy()))
print('MSE : ', mse)
print("Number of solutions : ", pred_y.shape)

plt.plot(time, real_y, label='real')
plt.plot(time, pred_y.numpy(), label='pred')
plt.legend()
plt.show()
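
# Added sanity check of the closed form above: separating variables in
# dy/dt = 5y - 3 gives ln|5y - 3| / 5 = t + C, hence
# y(t) = (y0 - 3/5) * exp(5 t) + 3/5, which reduces to y0 at t = 0.
assert abs(true_val(0., 1.) - 1.) < 1e-12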
### ----- Predict using trained model ---------------
if scale_time:
    times_predict = times_predict / scale_time

if adjoint:
    predicted_states = adjoint_odeint(model, tf.expand_dims(init_state, axis=0),
                                      tf.convert_to_tensor(times_predict),
                                      method=solver)
    predicted_states = tf.squeeze(predicted_states)
    if augmented:
        predicted_states = np.delete(predicted_states,
                                     slice(state_len, state_len + aug_dims),
                                     axis=1)
else:
    predicted_states = odeint(model, tf.expand_dims(init_state, axis=0),
                              tf.convert_to_tensor(times_predict),
                              method=solver)
    predicted_states = tf.squeeze(predicted_states)
    if augmented:
        predicted_states = np.delete(predicted_states,
                                     slice(state_len, state_len + aug_dims),
                                     axis=1)

### ---- Post-process predicted states ---------------
if scale_states:
    inverse_scaler = lambda z: (z + 1) * (max_g - min_g) / 2 + min_g
    predicted_states = inverse_scaler(predicted_states)
    true_state_array = inverse_scaler(true_state_array)
    # predicted_states = scale_mm.inverse_transform(predicted_states)

if scale_time:
    times_predict = times_predict * scale_time
def visualize(model, x_val, PLOT_DIR, TIME_OF_RUN, args, ode_model=True, epoch=0,
              is_mdn=False):
    """Visualize a tf.keras.Model for an aircraft model.

    # Arguments:
        model: A Keras model that accepts t and x when called
        x_val: np.ndarray, shape=(1, samples_per_series, 4) or (samples_per_series, 4)
            The reference time series against which the model will be compared
        PLOT_DIR: Directory to plot in
        TIME_OF_RUN: Time at which the run began
        ode_model: whether the model outputs the derivative of the current step
            (True) or the value of the next step (False)
        args: input arguments from the main script
    """
    x_val = x_val.reshape(2, -1, 4)
    dt = 0.1
    t = tf.linspace(0., 100., int(100. / dt) + 1)

    # Compute the predicted trajectories
    if ode_model:
        x0 = tf.convert_to_tensor(x_val[:, 0])
        x_t = odeint(model, x0, t, rtol=1e-5, atol=1e-5).numpy()
        x_t_extrap = x_t[:, 0]
        x_t_interp = x_t[:, 1]
    else:  # LSTM model
        x_t_extrap = np.zeros_like(x_val[0])
        x_t_extrap[0] = x_val[0, 0]
        x_t_interp = np.zeros_like(x_val[1])
        x_t_interp[0] = x_val[1, 0]
        # Always inject the entire time series, because Keras is slow with
        # varying series lengths and future timesteps don't affect the
        # predictions before them anyway.
        for i in range(1, len(t)):
            x_t_extrap[i:i + 1] = model(0., np.expand_dims(x_t_extrap, axis=0))[0, i - 1:i]
            x_t_interp[i:i + 1] = model(0., np.expand_dims(x_t_interp, axis=0))[0, i - 1:i]
    x_t = np.stack([x_t_extrap, x_t_interp], axis=0)

    # Plot the generated trajectories
    fig = plt.figure(figsize=(12, 8), facecolor='white')
    ax_traj = fig.add_subplot(231, frameon=False)
    ax_phase = fig.add_subplot(232, frameon=False)
    ax_vecfield = fig.add_subplot(233, frameon=False)
    ax_vec_error_abs = fig.add_subplot(234, frameon=False)
    ax_vec_error_rel = fig.add_subplot(235, frameon=False)
    ax_3d = fig.add_subplot(236, projection='3d')

    ax_traj.cla()
    ax_traj.set_title('Trajectories')
    ax_traj.set_xlabel('t')
    ax_traj.set_ylabel('V,gamma')
    for i in range(4):
        ax_traj.plot(t.numpy(), x_val[0, :, i], 'g-')
        ax_traj.plot(t.numpy(), x_t[0, :, i], 'b--')
    ax_traj.set_xlim(min(t.numpy()), max(t.numpy()))
    ax_traj.set_ylim(-2, 2)
    ax_traj.legend()

    ax_phase.cla()
    ax_phase.set_title('Phase Portrait phugoid')
    ax_phase.set_xlabel('V')
    ax_phase.set_ylabel('gamma')
    ax_phase.plot(x_val[0, :, 0], x_val[0, :, 1], 'g-')
    ax_phase.plot(x_t[0, :, 0], x_t[0, :, 1], 'b--')
    ax_phase.plot(x_val[1, :, 0], x_val[1, :, 1], 'g-')
    ax_phase.plot(x_t[1, :, 0], x_t[1, :, 1], 'b--')
    ax_phase.set_xlim(-6, 6)
    ax_phase.set_ylim(-2, 2)

    ax_vecfield.cla()
    ax_vecfield.set_title('Learned Vector Field')
    ax_vecfield.set_xlabel('V')
    ax_vecfield.set_ylabel('gamma')

    steps = 61
    y, x = np.mgrid[-6:6:complex(0, steps), -6:6:complex(0, steps)]
    zeros = tf.zeros_like(x)
    input_grid = np.stack([x, y, zeros, zeros], -1)
    ref_func = Lambda()
    dydt_ref = ref_func(0., input_grid.reshape(steps * steps, 4)).numpy()
    mag_ref = 1e-8 + np.linalg.norm(dydt_ref, axis=-1).reshape(steps, steps)
    dydt_ref = dydt_ref.reshape(steps, steps, 4)

    if ode_model:  # is Dense-Net, NODE-Net or NODE-e2e
        dydt = model(0., input_grid.reshape(steps * steps, 4)).numpy()
    else:  # is LSTM
        # Compute an artificial x_dot by numerically differentiating:
        # x_dot \approx (x_{t+1} - x_t) / dt
        yt_1 = model(0., input_grid.reshape(steps * steps, 1, 4))[:, 0]
        dydt = (np.array(yt_1) - input_grid.reshape(steps * steps, 4)) / dt

    dydt_abs = dydt.reshape(steps, steps, 4)
    dydt_unit = dydt_abs / np.linalg.norm(dydt_abs, axis=-1, keepdims=True)

    ax_vecfield.streamplot(x, y, dydt_unit[:, :, 0], dydt_unit[:, :, 1], color="black")
    ax_vecfield.set_xlim(-4, 4)
    ax_vecfield.set_ylim(-2, 2)

    ax_vec_error_abs.cla()
    ax_vec_error_abs.set_title('Abs. error of V\', gamma\'')
    ax_vec_error_abs.set_xlabel('V')
    ax_vec_error_abs.set_ylabel('gamma')
    abs_dif = np.clip(np.linalg.norm(dydt_abs - dydt_ref, axis=-1), 0., 3.)
    c1 = ax_vec_error_abs.contourf(x, y, abs_dif, 100)
    plt.colorbar(c1, ax=ax_vec_error_abs)
    ax_vec_error_abs.set_xlim(-6, 6)
    ax_vec_error_abs.set_ylim(-6, 6)

    ax_vec_error_rel.cla()
    ax_vec_error_rel.set_title('Rel. error of V\', gamma\'')
    ax_vec_error_rel.set_xlabel('V')
    ax_vec_error_rel.set_ylabel('gamma')
    rel_dif = np.clip(abs_dif / mag_ref, 0., 1.)
    c2 = ax_vec_error_rel.contourf(x, y, rel_dif, 100)
    plt.colorbar(c2, ax=ax_vec_error_rel)
    ax_vec_error_rel.set_xlim(-6, 6)
    ax_vec_error_rel.set_ylim(-6, 6)

    ax_3d.cla()
    ax_3d.set_title('3D Trajectory')
    ax_3d.set_xlabel('V')
    ax_3d.set_ylabel('gamma')
    ax_3d.set_zlabel('alpha')
    ax_3d.scatter(x_val[0, :, 0], x_val[0, :, 1], x_val[0, :, 2],
                  c='g', s=4, marker='^')
    ax_3d.scatter(x_t[0, :, 0], x_t[0, :, 1], x_t[0, :, 2],
                  c='b', s=4, marker='o')
    ax_3d.view_init(elev=40., azim=60.)

    fig.tight_layout()
    plt.savefig(PLOT_DIR + '/{:03d}'.format(epoch))
    plt.close()

    # Compute metrics
    phase_error_extrap_lp, phase_error_extrap_sp = relative_phase_error(x_t[0], x_val[0])
    traj_error_extrap = trajectory_error(x_t[0], x_val[0])
    phase_error_interp_lp, phase_error_interp_sp = relative_phase_error(x_t[1], x_val[1])
    traj_error_interp = trajectory_error(x_t[1], x_val[1])

    wall_time = (datetime.datetime.now()
                 - datetime.datetime.strptime(TIME_OF_RUN, "%Y%m%d-%H%M%S")).total_seconds()

    string = "{},{},{:.7f},{:.7f},{:.7f},{:.7f},{:.7f},{:.7f}\n".format(
        wall_time, epoch,
        phase_error_interp_lp, phase_error_interp_sp,
        phase_error_extrap_lp, phase_error_extrap_sp,
        traj_error_interp, traj_error_extrap)

    file_path = (PLOT_DIR + TIME_OF_RUN + "results" + str(args.lr)
                 + str(args.dataset_size) + str(args.batch_size) + ".csv")
    if not os.path.isfile(file_path):
        title_string = ("wall_time,epoch,"
                        + "phase_error_interp_lp,phase_error_interp_sp,"
                        + "phase_error_extrap_lp,phase_error_extrap_sp,"
                        + "traj_err_interp,traj_err_extrap\n")
        fd = open(file_path, 'a')
        fd.write(title_string)
        fd.close()
    fd = open(file_path, 'a')
    fd.write(string)
    fd.close()

    # Print the Jacobian
    if ode_model:
        np.set_printoptions(suppress=True, precision=4, linewidth=150)
        # The first Jacobian is averaged over 100 points sampled from U(-1, 1)
        jac = tf.zeros((4, 4))
        for i in range(100):
            with tf.GradientTape(persistent=True) as g:
                x = 2 * tf.random.uniform((1, 4)) - 1
                g.watch(x)
                y = model(0, x)
            jac = jac + g.jacobian(y, x)[0, :, 0]
        print(jac.numpy() / 100)

        with tf.GradientTape(persistent=True) as g:
            x = tf.zeros([1, 4])
            g.watch(x)
            y = model(0, x)
        print(g.jacobian(y, x)[0, :, 0])
t = tf.convert_to_tensor(t_n, dtype=tf.float32)
true_A = tf.convert_to_tensor([[1, -0.2], [-0.2, 1]], dtype=tf.float64)


class Lambda(tf.keras.Model):

    def call(self, t, y):
        dydt = tf.matmul(y, true_A)
        return dydt


with tf.device(device):
    t1 = time.time()
    pred_y = odeint(Lambda(), true_y0, t,
                    rtol=args.rtol, atol=args.atol, method=args.method)
    t2 = time.time()

print("Number of solutions : ", pred_y.shape)
print("Time taken : ", t2 - t1)

pred_y = pred_y.numpy()
plt.plot(t_n, pred_y[:, 0, 0], t_n, pred_y[:, 0, 1], 'r-', label='trajectory')
# plt.plot(time, pred_y.numpy(), 'b--', label='y')
plt.legend()
plt.xlabel('time')
plt.ylabel('magnitude')
plt.show()
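
# Hedged verification sketch (not part of the original script): for the
# linear system dy/dt = y @ A the exact solution is y(t) = y0 @ expm(A t),
# so the solver output can be checked against scipy's matrix exponential.
from scipy.linalg import expm

y_exact = np.stack([true_y0.numpy() @ expm(true_A.numpy() * ti) for ti in t_n])
print('max abs error vs. expm:', np.max(np.abs(y_exact - pred_y)))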
def grad(*grad_output, variables=None):
    global _arguments
    flat_params = _flatten(variables)

    func = _arguments.func
    adjoint_method = _arguments.adjoint_method
    adjoint_rtol = _arguments.rtol
    adjoint_atol = _arguments.atol
    adjoint_options = _arguments.adjoint_options

    n_tensors = len(ans)
    f_params = tuple(variables)

    # TODO: use a tf.keras.Model and call odeint_adjoint to implement
    # higher-order derivatives.
    def augmented_dynamics(t, y_aug):
        # Dynamics of the original system augmented with the adjoint wrt y,
        # and an integrator wrt t and args.
        y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]
        # Ignore adj_time and adj_params.

        with tf.GradientTape() as tape:
            tape.watch(t)
            tape.watch(y)
            func_eval = func(t, y)
            func_eval = tf.convert_to_tensor(func_eval)

        gradys = -tf.stack(adj_y)
        if type(func_eval) in [list, tuple]:
            for eval_ix in range(len(func_eval)):
                if len(gradys[eval_ix].shape) < len(func_eval[eval_ix].shape):
                    gradys[eval_ix] = tf.expand_dims(gradys[eval_ix], axis=0)
        else:
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)

        vjp_t, *vjp_y_and_params = tape.gradient(
            func_eval, (t,) + y + f_params,
            output_gradients=gradys,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)

        vjp_y = vjp_y_and_params[:n_tensors]
        vjp_params = vjp_y_and_params[n_tensors:]
        vjp_params = _flatten(vjp_params)

        if _check_len(f_params) == 0:
            vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
            vjp_params = move_to_device(vjp_params, vjp_y[0].device)

        return (*func_eval, *vjp_y, vjp_t, vjp_params)

    T = ans[0].shape[0]
    if isinstance(grad_output, (tf.Tensor, tf.Variable)):
        adj_y = [grad_output[-1]]
    else:
        adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)

    adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
    adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
    time_vjps = []
    for i in range(T - 1, 0, -1):
        ans_i = tuple(ans_[i] for ans_ in ans)

        if isinstance(grad_output, (tf.Tensor, tf.Variable)):
            grad_output_i = [grad_output[i]]
        else:
            grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

        func_i = func(t[i], ans_i)
        if not isinstance(func_i, Iterable):
            func_i = [func_i]

        # Compute the effect of moving the current time measurement point.
        dLd_cur_t = sum(
            tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]),
                                 tf.reshape(grad_output_i_, [-1, 1])), [1])
            for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
        adj_time = adj_time - dLd_cur_t
        time_vjps.append(dLd_cur_t)

        # Run the augmented system backwards in time.
        if isinstance(adj_params, Iterable):
            if _numel(adj_params) == 0:
                adj_params = move_to_device(
                    tf.convert_to_tensor(0., dtype=adj_y[0].dtype), adj_y[0].device)

        aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
        aug_ans = odeint(
            augmented_dynamics, aug_y0,
            tf.convert_to_tensor([t[i], t[i - 1]]),
            rtol=adjoint_rtol, atol=adjoint_atol,
            method=adjoint_method, options=adjoint_options)

        # Unpack aug_ans.
        adj_y = aug_ans[n_tensors:2 * n_tensors]
        adj_time = aug_ans[2 * n_tensors]
        adj_params = aug_ans[2 * n_tensors + 1]

        adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_
                      for adj_y_ in adj_y)
        if _check_len(adj_time) > 0:
            adj_time = adj_time[1]
        if _check_len(adj_params) > 0:
            adj_params = adj_params[1]

        adj_y = tuple(adj_y_ + grad_output_[i - 1]
                      for adj_y_, grad_output_ in zip(adj_y, grad_output))

        del aug_y0, aug_ans

    time_vjps.append(adj_time)
    time_vjps = tf.concat(time_vjps[::-1], 0)

    # Reshape the parameters back into the correct variable shapes.
    var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
    var_shapes = [v.shape for v in variables]

    adj_params_splits = tf.split(adj_params, var_flat_lens)
    adj_params_list = [tf.reshape(p, v_shape)
                       for p, v_shape in zip(adj_params_splits, var_shapes)]

    return (*adj_y, time_vjps), adj_params_list
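
# Background note for the adjoint code above (standard adjoint-sensitivity
# method, cf. Chen et al., "Neural Ordinary Differential Equations", 2018):
# between observation times the backward pass integrates the augmented system
#
#   dy/dt     = f(t, y)
#   da/dt     = -a^T (df/dy)                    (adjoint state a = dL/dy)
#   dL/dtheta = -integral of a^T (df/dtheta) dt
#
# from t[i] back to t[i-1], while dLd_cur_t = a^T f(t_i, y_i) accounts for
# moving the measurement times. This is why aug_y0 stacks
# (y, adj_y, adj_time, adj_params) and aug_ans is unpacked into the same
# components.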
if __name__ == '__main__':
    end = time.time()

    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    with tf.device(device):
        func = ODEFunc()
        lr = tf.Variable(args.lr)
        optimizer = tf.keras.optimizers.Adam(lr, clipvalue=0.5)

        for itr in range(1, args.niters + 1):
            with tf.GradientTape() as tape:
                batch_x0, batch_t, batch_x = get_batch()
                pred_x = odeint(func, batch_x0, batch_t, method=args.method)  # (T, B, D)

                ex_loss = tf.reduce_sum(tf.math.square(pred_x - batch_x), axis=-1)
                loss = tf.reduce_mean(ex_loss)

                weights = [v for v in func.trainable_variables if 'bias' not in v.name]
                l2_loss = tf.add_n(
                    [tf.reduce_sum(tf.math.square(v)) for v in weights]) * 1e-6
                loss = loss + l2_loss

            nfe = func.nfe.numpy()
            func.nfe.assign(0.)

            grads = tape.gradient(loss, func.trainable_variables)

            nbe = func.nfe.numpy()
            func.nfe.assign(0.)
def visualize(model, x_val, PLOT_DIR, TIME_OF_RUN, args, ode_model=True,
              latent=False, epoch=0, is_mdn=False):
    """Visualize a tf.keras.Model for a single pendulum.

    # Arguments:
        model: A Keras model that accepts t and x when called
        x_val: np.ndarray, shape=(1, samples_per_series, 2) or (samples_per_series, 2)
            The reference time series against which the model will be compared
        PLOT_DIR: Directory to plot in
        TIME_OF_RUN: Time at which the run began
        ode_model: whether the model outputs the derivative of the current step
            (True) or the value of the next step (False)
        args: input arguments from the main script
    """
    x_val = x_val.reshape(2, -1, 2)
    dt = 0.01
    t = tf.linspace(0., 10., int(10. / dt) + 1)

    # Compute the predicted trajectories
    if ode_model:
        x0_extrap = tf.stack([x_val[0, 0]])
        x_t_extrap = odeint(model, x0_extrap, t, rtol=1e-5, atol=1e-5).numpy()[:, 0]
        x0_interp = tf.stack([x_val[1, 0]])
        x_t_interp = odeint(model, x0_interp, t, rtol=1e-5, atol=1e-5).numpy()[:, 0]
    else:  # LSTM model
        x_t_extrap = np.zeros_like(x_val[0])
        x_t_extrap[0] = x_val[0, 0]
        x_t_interp = np.zeros_like(x_val[1])
        x_t_interp[0] = x_val[1, 0]
        # Always inject the entire time series, because Keras is slow with
        # varying series lengths and future timesteps don't affect the
        # predictions before them anyway.
        if is_mdn:
            import mdn
            for i in range(1, len(t)):
                pred_extrap = model(0., np.expand_dims(x_t_extrap, axis=0))[0, i - 1:i]
                x_t_extrap[i:i + 1] = mdn.sample_from_output(
                    pred_extrap.numpy()[0], 2, 5, temp=1.)
                pred_interp = model(0., np.expand_dims(x_t_interp, axis=0))[0, i - 1:i]
                x_t_interp[i:i + 1] = mdn.sample_from_output(
                    pred_interp.numpy()[0], 2, 5, temp=1.)
        else:
            for i in range(1, len(t)):
                x_t_extrap[i:i + 1] = model(0., np.expand_dims(x_t_extrap, axis=0))[0, i - 1:i]
                x_t_interp[i:i + 1] = model(0., np.expand_dims(x_t_interp, axis=0))[0, i - 1:i]
    x_t = np.stack([x_t_extrap, x_t_interp], axis=0)

    # Plot the generated trajectories
    fig = plt.figure(figsize=(12, 8), facecolor='white')
    ax_traj = fig.add_subplot(231, frameon=False)
    ax_phase = fig.add_subplot(232, frameon=False)
    ax_vecfield = fig.add_subplot(233, frameon=False)
    ax_vec_error_abs = fig.add_subplot(234, frameon=False)
    ax_vec_error_rel = fig.add_subplot(235, frameon=False)
    ax_energy = fig.add_subplot(236, frameon=False)

    ax_traj.cla()
    ax_traj.set_title('Trajectories')
    ax_traj.set_xlabel('t')
    ax_traj.set_ylabel('x,y')
    ax_traj.plot(t.numpy(), x_val[0, :, 0], t.numpy(), x_val[0, :, 1], 'g-')
    ax_traj.plot(t.numpy(), x_t[0, :, 0], '--', t.numpy(), x_t[0, :, 1], 'b--')
    ax_traj.set_xlim(min(t.numpy()), max(t.numpy()))
    ax_traj.set_ylim(-6, 6)
    ax_traj.legend()

    ax_phase.cla()
    ax_phase.set_title('Phase Portrait')
    ax_phase.set_xlabel('x')
    ax_phase.set_ylabel('x_dt')
    ax_phase.plot(x_val[0, :, 0], x_val[0, :, 1], 'g--')
    ax_phase.plot(x_t[0, :, 0], x_t[0, :, 1], 'b--')
    ax_phase.plot(x_val[1, :, 0], x_val[1, :, 1], 'g--')
    ax_phase.plot(x_t[1, :, 0], x_t[1, :, 1], 'b--')
    ax_phase.set_xlim(-6, 6)
    ax_phase.set_ylim(-6, 6)

    ax_vecfield.cla()
    ax_vecfield.set_title('Learned Vector Field')
    ax_vecfield.set_xlabel('x')
    ax_vecfield.set_ylabel('x_dt')

    steps = 61
    y, x = np.mgrid[-6:6:complex(0, steps), -6:6:complex(0, steps)]
    ref_func = Lambda()
    dydt_ref = ref_func(0., np.stack([x, y], -1).reshape(steps * steps, 2)).numpy()
    mag_ref = 1e-8 + np.linalg.norm(dydt_ref, axis=-1).reshape(steps, steps)
    dydt_ref = dydt_ref.reshape(steps, steps, 2)

    if ode_model:  # is Dense-Net, NODE-Net or NODE-e2e
        dydt = model(0., np.stack([x, y], -1).reshape(steps * steps, 2)).numpy()
    else:  # is LSTM
        # Compute an artificial x_dot by numerically differentiating:
        # x_dot \approx (x_{t+1} - x_t) / dt
        yt_1 = model(0., np.stack([x, y], -1).reshape(steps * steps, 1, 2))[:, 0]
        if is_mdn:
            # Have to sample from the output Gaussians
            yt_1 = np.apply_along_axis(mdn.sample_from_output, 1, yt_1.numpy(),
                                       2, 5, temp=.1)[:, 0]
        dydt = (np.array(yt_1) - np.stack([x, y], -1).reshape(steps * steps, 2)) / dt

    dydt_abs = dydt.reshape(steps, steps, 2)
    dydt_unit = dydt_abs / np.linalg.norm(dydt_abs, axis=-1, keepdims=True)  # unit vectors

    ax_vecfield.streamplot(x, y, dydt_unit[:, :, 0], dydt_unit[:, :, 1], color="black")
    ax_vecfield.set_xlim(-6, 6)
    ax_vecfield.set_ylim(-6, 6)

    ax_vec_error_abs.cla()
    ax_vec_error_abs.set_title('Abs. error of xdot')
    ax_vec_error_abs.set_xlabel('x')
    ax_vec_error_abs.set_ylabel('x_dt')
    abs_dif = np.clip(np.linalg.norm(dydt_abs - dydt_ref, axis=-1), 0., 3.)
    c1 = ax_vec_error_abs.contourf(x, y, abs_dif, 100)
    plt.colorbar(c1, ax=ax_vec_error_abs)
    ax_vec_error_abs.set_xlim(-6, 6)
    ax_vec_error_abs.set_ylim(-6, 6)

    ax_vec_error_rel.cla()
    ax_vec_error_rel.set_title('Rel. error of xdot')
    ax_vec_error_rel.set_xlabel('x')
    ax_vec_error_rel.set_ylabel('x_dt')
    rel_dif = np.clip(abs_dif / mag_ref, 0., 1.)
    c2 = ax_vec_error_rel.contourf(x, y, rel_dif, 100)
    plt.colorbar(c2, ax=ax_vec_error_rel)
    ax_vec_error_rel.set_xlim(-6, 6)
    ax_vec_error_rel.set_ylim(-6, 6)

    ax_energy.cla()
    ax_energy.set_title('Total Energy')
    ax_energy.set_xlabel('t')
    ax_energy.plot(np.arange(1001) / 100.,
                   np.array([total_energy(x_) for x_ in x_t_interp]))

    fig.tight_layout()
    plt.savefig(PLOT_DIR + '/{:03d}'.format(epoch))
    plt.close()

    # Compute metrics
    energy_drift_extrap = relative_energy_drift(x_t[0], x_val[0])
    phase_error_extrap = relative_phase_error(x_t[0], x_val[0])
    traj_error_extrap = trajectory_error(x_t[0], x_val[0])
    energy_drift_interp = relative_energy_drift(x_t[1], x_val[1])
    phase_error_interp = relative_phase_error(x_t[1], x_val[1])
    traj_error_interp = trajectory_error(x_t[1], x_val[1])

    wall_time = (datetime.datetime.now()
                 - datetime.datetime.strptime(TIME_OF_RUN, "%Y%m%d-%H%M%S")).total_seconds()

    string = "{},{},{},{},{},{},{},{}\n".format(
        wall_time, epoch,
        energy_drift_interp, energy_drift_extrap,
        phase_error_interp, phase_error_extrap,
        traj_error_interp, traj_error_extrap)

    file_path = (PLOT_DIR + TIME_OF_RUN + "results" + str(args.lr)
                 + str(args.dataset_size) + str(args.batch_size) + ".csv")
    if not os.path.isfile(file_path):
        title_string = ("wall_time,epoch,"
                        + "energy_drift_interp,energy_drift_extrap,"
                        + "phase_error_interp,phase_error_extrap,"
                        + "traj_err_interp,traj_err_extrap\n")
        fd = open(file_path, 'a')
        fd.write(title_string)
        fd.close()
    fd = open(file_path, 'a')
    fd.write(string)
    fd.close()

    # Print the Jacobian
    if ode_model:
        np.set_printoptions(suppress=True, precision=4, linewidth=150)
        # The first Jacobian is averaged over 100 points sampled from U(-1, 1)
        jac = tf.zeros((2, 2))
        for i in range(100):
            with tf.GradientTape(persistent=True) as g:
                x = 2 * tf.random.uniform((1, 2)) - 1
                g.watch(x)
                y = model(0, x)
            jac = jac + g.jacobian(y, x)[0, :, 0]
        print(jac.numpy() / 100)

        with tf.GradientTape(persistent=True) as g:
            x = tf.zeros([1, 2])
            g.watch(x)
            y = model(0, x)
        print(g.jacobian(y, x)[0, :, 0])
def test_rk4(self):
    f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
    y = tfdiffeq.odeint(f, y0, t_points, method='rk4')
    self.assertLess(rel_error(sol, y), error_tol)
else:
    tf.keras.backend.set_floatx('float64')

x_0 = tf.constant(1., dtype=dtype)  # not important for the gradient
a = tf.constant(2., dtype=dtype)
b = tf.constant(2., dtype=dtype)
T = tf.constant(2., dtype=dtype)
t = tf.cast(tf.linspace(0., T, 2), dtype)

odemodel = ODE(a, b, dtype)

for rtol in np.logspace(-13, 0, 14)[::-1]:
    print('rtol:', rtol)
    # Run the forward and backward passes while tracking the wall time.
    with tf.device('/gpu:0'):
        t0 = time.time()
        with tf.GradientTape() as g:
            y_sol = odeint(odemodel, x_0, t, rtol=rtol, atol=1e-10)[-1]
        t1 = time.time()
        dYdX_backprop = g.gradient(y_sol, odemodel.b).numpy()
        t2 = time.time()

        with tf.GradientTape() as g:
            y_sol_adj = odeint_adjoint(odemodel, x_0, t, rtol=rtol, atol=1e-10)[-1]
        t3 = time.time()
        dYdX_adjoint = g.gradient(y_sol_adj, odemodel.b).numpy()
        t4 = time.time()

    dYdX_exact = exact_derivative(a, b, T).numpy()
    rel_err_adj = abs(dYdX_adjoint - dYdX_exact) / dYdX_exact
    rel_err_bp = abs(dYdX_backprop - dYdX_exact) / dYdX_exact
    print('Adjoint:', rel_err_adj, dtype)
    print('Backprop:', rel_err_bp, dtype)

    fd = open(file_path, 'a')
    fd.write('{},{},adjoint,{},{},{},{},{}\n'.format(dtype,
if __name__ == '__main__':
    end = time.time()

    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    with tf.device(device):
        func = ODEFunc()
        lr = tf.Variable(args.lr)
        optimizer = tf.keras.optimizers.Adam(lr, clipvalue=0.5)

        for itr in range(1, args.niters + 1):
            with tf.GradientTape() as tape:
                batch_x0, batch_t, batch_x = get_batch()
                pred_x = odeint(func, batch_x0, batch_t, method=args.method)  # (T, B, D)

                ex_loss = tf.reduce_sum(tf.math.square(pred_x - batch_x), axis=-1)
                loss = tf.reduce_mean(ex_loss)

                weights = [v for v in func.trainable_variables if 'bias' not in v.name]
                l2_loss = tf.add_n(
                    [tf.reduce_sum(tf.math.square(v)) for v in weights]) * 1e-5
                loss = loss + l2_loss

            nfe = func.nfe.numpy()
            func.nfe.assign(0.)

            grads = tape.gradient(loss, func.trainable_variables)

            nbe = func.nfe.numpy()
            func.nfe.assign(0.)
def visualize(model, x_val, PLOT_DIR, TIME_OF_RUN, args, ode_model=True,
              latent=False, epoch=0):
    """Visualize a tf.keras.Model for a single pendulum.

    # Arguments:
        model: a Keras model
        x_val: np.ndarray, shape=(1, samples_per_series, 2) or (samples_per_series, 2)
            The reference time series against which the model will be compared
        PLOT_DIR: directory to plot in
        TIME_OF_RUN: time of the run
        ode_model: whether the model outputs the derivative of the current step
        args: input arguments from the main script
    """
    x_val = x_val.reshape(-1, 2)
    dt = 0.01
    t = tf.linspace(0., 10., int(10. / dt) + 1)

    # Compute the predicted trajectories
    if ode_model:
        x0 = tf.stack([[1.5, .5]])
        x_t = odeint(model, x0, t, rtol=1e-5, atol=1e-5).numpy()[:, 0]
    else:  # is LSTM
        x_t = np.zeros_like(x_val)
        x_t[0] = x_val[0]
        for i in range(1, len(t)):
            x_t[1:i + 1] = model(0., np.expand_dims(x_t, axis=0))[0, :i]

    fig = plt.figure(figsize=(12, 8), facecolor='white')
    ax_traj = fig.add_subplot(231, frameon=False)
    ax_phase = fig.add_subplot(232, frameon=False)
    ax_vecfield = fig.add_subplot(233, frameon=False)
    ax_vec_error_abs = fig.add_subplot(234, frameon=False)
    ax_vec_error_rel = fig.add_subplot(235, frameon=False)
    ax_energy = fig.add_subplot(236, frameon=False)

    ax_traj.cla()
    ax_traj.set_title('Trajectories')
    ax_traj.set_xlabel('t')
    ax_traj.set_ylabel('x,y')
    ax_traj.plot(t.numpy(), x_val[:, 0], t.numpy(), x_val[:, 1], 'g-')
    ax_traj.plot(t.numpy(), x_t[:, 0], '--', t.numpy(), x_t[:, 1], 'b--')
    ax_traj.set_xlim(min(t.numpy()), max(t.numpy()))
    ax_traj.set_ylim(-6, 6)
    ax_traj.legend()

    ax_phase.cla()
    ax_phase.set_title('Phase Portrait')
    ax_phase.set_xlabel('theta')
    ax_phase.set_ylabel('theta_dt')
    ax_phase.plot(x_val[:, 0], x_val[:, 1], 'g--')
    ax_phase.plot(x_t[:, 0], x_t[:, 1], 'b--')
    ax_phase.set_xlim(-6, 6)
    ax_phase.set_ylim(-6, 6)

    ax_vecfield.cla()
    ax_vecfield.set_title('Learned Vector Field')
    ax_vecfield.set_xlabel('theta')
    ax_vecfield.set_ylabel('theta_dt')

    steps = 61
    y, x = np.mgrid[-6:6:complex(0, steps), -6:6:complex(0, steps)]
    ref_func = Lambda()
    dydt_ref = ref_func(0., np.stack([x, y], -1).reshape(steps * steps, 2)).numpy()
    mag_ref = 1e-8 + np.linalg.norm(dydt_ref, axis=-1).reshape(steps, steps)
    dydt_ref = dydt_ref.reshape(steps, steps, 2)

    if ode_model:  # is Dense-Net, NODE-Net or NODE-e2e
        dydt = model(0., np.stack([x, y], -1).reshape(steps * steps, 2)).numpy()
    else:  # is LSTM
        # Compute an artificial x_dot by numerically differentiating:
        # x_dot \approx (x_{t+1} - x_t) / dt
        yt_1 = model(0., np.stack([x, y], -1).reshape(steps * steps, 1, 2))[:, 0]
        dydt = (np.array(yt_1) - np.stack([x, y], -1).reshape(steps * steps, 2)) / dt

    dydt_abs = dydt.reshape(steps, steps, 2)
    dydt_unit = dydt_abs / np.linalg.norm(dydt_abs, axis=-1, keepdims=True)

    ax_vecfield.streamplot(x, y, dydt_unit[:, :, 0], dydt_unit[:, :, 1], color="black")
    ax_vecfield.set_xlim(-6, 6)
    ax_vecfield.set_ylim(-6, 6)

    ax_vec_error_abs.cla()
    ax_vec_error_abs.set_title('Abs. error of thetadot')
    ax_vec_error_abs.set_xlabel('theta')
    ax_vec_error_abs.set_ylabel('theta_dt')
    abs_dif = np.clip(np.linalg.norm(dydt_abs - dydt_ref, axis=-1), 0., 3.)
    c1 = ax_vec_error_abs.contourf(x, y, abs_dif, 100)
    plt.colorbar(c1, ax=ax_vec_error_abs)
    ax_vec_error_abs.set_xlim(-6, 6)
    ax_vec_error_abs.set_ylim(-6, 6)

    ax_vec_error_rel.cla()
    ax_vec_error_rel.set_title('Rel. error of thetadot')
    ax_vec_error_rel.set_xlabel('theta')
    ax_vec_error_rel.set_ylabel('theta_dt')
    rel_dif = np.clip(abs_dif / mag_ref, 0., 1.)
    c2 = ax_vec_error_rel.contourf(x, y, rel_dif, 100)
    plt.colorbar(c2, ax=ax_vec_error_rel)
    ax_vec_error_rel.set_xlim(-6, 6)
    ax_vec_error_rel.set_ylim(-6, 6)

    ax_energy.cla()
    ax_energy.set_title('Total Energy')
    ax_energy.set_xlabel('t')
    ax_energy.plot(np.arange(0., x_t.shape[0] * dt, dt),
                   np.array([total_energy(x_) for x_ in x_t]))

    fig.tight_layout()
    plt.savefig(PLOT_DIR + '/{:03d}'.format(epoch))
    plt.close()

    # Compute metrics
    energy_drift_interp = relative_energy_drift(x_t, x_val)
    phase_error_interp = relative_phase_error(x_t, x_val)
    traj_err_interp = trajectory_error(x_t, x_val)

    wall_time = (datetime.datetime.now()
                 - datetime.datetime.strptime(TIME_OF_RUN, "%Y%m%d-%H%M%S")).total_seconds()
    string = "{},{},{},{},{}\n".format(wall_time, epoch, energy_drift_interp,
                                       phase_error_interp, traj_err_interp)

    file_path = (PLOT_DIR + TIME_OF_RUN + "results" + str(args.lr)
                 + str(args.dataset_size) + str(args.batch_size) + ".csv")
    if not os.path.isfile(file_path):
        title_string = "wall_time,epoch,energy_interp,phase_interp,traj_err_interp\n"
        fd = open(file_path, 'a')
        fd.write(title_string)
        fd.close()
    fd = open(file_path, 'a')
    fd.write(string)
    fd.close()

    # Print the Jacobian
    if ode_model:
        np.set_printoptions(suppress=True, precision=4, linewidth=150)
        # The first Jacobian is averaged over 100 points sampled from U(-1, 1)
        jac = tf.zeros((2, 2))
        for i in range(100):
            with tf.GradientTape(persistent=True) as g:
                x = 2 * tf.random.uniform((1, 2)) - 1
                g.watch(x)
                y = model(0, x)
            jac = jac + g.jacobian(y, x)[0, :, 0]
        print(jac.numpy() / 100)

        with tf.GradientTape(persistent=True) as g:
            x = tf.zeros([1, 2])
            g.watch(x)
            y = model(0, x)
        print(g.jacobian(y, x)[0, :, 0])

    if args.create_video:
        x1 = np.sin(x_t[:, 0])
        y1 = -np.cos(x_t[:, 0])

        fig = plt.figure()
        ax = fig.add_subplot(111, autoscale_on=False, xlim=(-2, 2), ylim=(-2, 2))
        ax.set_aspect('equal')
        ax.grid()

        line, = ax.plot([], [], 'o-', lw=2)
        time_template = 'time = %.1fs'
        time_text = ax.text(0.05, 0.9, '', transform=ax.transAxes)

        def animate(i):
            thisx = [0, x1[i]]
            thisy = [0, y1[i]]
            line.set_data(thisx, thisy)
            time_text.set_text(time_template % (i * 0.01))
            return line, time_text

        def init():
            line.set_data([], [])
            time_text.set_text('')
            return line, time_text

        ani = animation.FuncAnimation(fig, animate, range(1, len(x1)),
                                      interval=dt * len(x1), blit=True,
                                      init_func=init)
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=60, metadata=dict(artist='Me'), bitrate=2400)
        ani.save(PLOT_DIR + 'sp{}.mp4'.format(epoch), writer=writer)

        x1 = np.sin(x_val[:, 0])
        y1 = -np.cos(x_val[:, 0])

        fig = plt.figure()
        ax = fig.add_subplot(111, autoscale_on=False, xlim=(-2, 2), ylim=(-2, 2))
        ax.set_aspect('equal')
        ax.grid()

        line, = ax.plot([], [], 'o-', lw=2)
        time_template = 'time = %.1fs'
        time_text = ax.text(0.05, 0.9, '', transform=ax.transAxes)

        ani = animation.FuncAnimation(fig, animate, range(1, len(x_t)),
                                      interval=dt * len(x_t), blit=True,
                                      init_func=init)
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=60, metadata=dict(artist='Me'), bitrate=2400)
        ani.save(PLOT_DIR + 'sp_ref.mp4', writer=writer)
        plt.close()
def OdeintAdjointMethod(*args):
    global _arguments
    # args = _arguments.args
    # kwargs = _arguments.kwargs
    func = _arguments.func
    method = _arguments.method
    options = _arguments.options
    rtol = _arguments.rtol
    atol = _arguments.atol

    y0, t = args[:-1], args[-1]

    # Registers `t` as a Variable that needs a grad, then resets it to a Tensor
    # for the `odeint` function to work. This is done to force tf to allow us
    # to pass the gradient of t as output.
    # t = tf.get_variable('t', initializer=t)
    # t = tf.convert_to_tensor(t, dtype=t.dtype)

    ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)

    def grad(*grad_output, variables=None):
        global _arguments
        flat_params = _flatten(variables)

        func = _arguments.func
        method = _arguments.method
        options = _arguments.options
        rtol = _arguments.rtol
        atol = _arguments.atol

        n_tensors = len(ans)
        f_params = tuple(variables)

        # TODO: use a tf.keras.Model and call odeint_adjoint to implement
        # higher-order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with the adjoint wrt y,
            # and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]
            # Ignore adj_time and adj_params.

            with tf.GradientTape() as tape:
                tape.watch(t)
                tape.watch(y)
                func_eval = func(t, y)
                func_eval = tf.convert_to_tensor(func_eval)

            gradys = -tf.stack(adj_y)
            if len(gradys.shape) < len(func_eval.shape):
                gradys = tf.expand_dims(gradys, axis=0)

            vjp_t, *vjp_y_and_params = tape.gradient(
                func_eval, (t,) + y + f_params,
                output_gradients=gradys,
                unconnected_gradients=tf.UnconnectedGradients.ZERO)

            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]
            vjp_params = _flatten(vjp_params)

            if _check_len(f_params) == 0:
                vjp_params = tf.convert_to_tensor(0., dtype=vjp_y[0].dtype)
                vjp_params = move_to_device(vjp_params, vjp_y[0].device)

            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        T = ans[0].shape[0]
        if isinstance(grad_output, (tf.Tensor, tf.Variable)):
            adj_y = [grad_output[-1]]
        else:
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)

        adj_params = tf.zeros_like(flat_params, dtype=flat_params.dtype)
        adj_time = move_to_device(tf.convert_to_tensor(0., dtype=t.dtype), t.device)
        time_vjps = []

        if hasattr(func, 'base_func') and hasattr(func.base_func, 'nfe'):
            nfe = func.base_func.nfe.numpy()

        for i in range(T - 1, 0, -1):
            ans_i = tuple(ans_[i] for ans_ in ans)

            if isinstance(grad_output, (tf.Tensor, tf.Variable)):
                grad_output_i = [grad_output[i]]
            else:
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)

            func_i = func(t[i], ans_i)
            if not isinstance(func_i, Iterable):
                func_i = [func_i]

            # Compute the effect of moving the current time measurement point.
            dLd_cur_t = sum(
                tf.reshape(tf.matmul(tf.reshape(func_i_, [1, -1]),
                                     tf.reshape(grad_output_i_, [-1, 1])), [1])
                for func_i_, grad_output_i_ in zip(func_i, grad_output_i))
            adj_time = adj_time - dLd_cur_t
            time_vjps.append(dLd_cur_t)

            # Run the augmented system backwards in time.
            if isinstance(adj_params, Iterable):
                if _numel(adj_params) == 0:
                    adj_params = move_to_device(
                        tf.convert_to_tensor(0., dtype=adj_y[0].dtype),
                        adj_y[0].device)

            aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             tf.convert_to_tensor([t[i], t[i - 1]]),
                             rtol=rtol, atol=atol, method=method, options=options)

            # Unpack aug_ans.
            adj_y = aug_ans[n_tensors:2 * n_tensors]
            adj_time = aug_ans[2 * n_tensors]
            adj_params = aug_ans[2 * n_tensors + 1]

            adj_y = tuple(adj_y_[1] if _check_len(adj_y_) > 0 else adj_y_
                          for adj_y_ in adj_y)
            if _check_len(adj_time) > 0:
                adj_time = adj_time[1]
            if _check_len(adj_params) > 0:
                adj_params = adj_params[1]

            adj_y = tuple(adj_y_ + grad_output_[i - 1]
                          for adj_y_, grad_output_ in zip(adj_y, grad_output))

            del aug_y0, aug_ans

        if hasattr(func, 'base_func') and hasattr(func.base_func, 'nfe'):
            nbe = func.base_func.nfe.numpy() - nfe
            func.base_func.nfe.assign(nfe)
            func.base_func.nbe.assign(nbe)

        time_vjps.append(adj_time)
        time_vjps = tf.concat(time_vjps[::-1], 0)

        # Reshape the parameters back into the correct variable shapes.
        var_flat_lens = [_numel(v, dtype=tf.int32).numpy() for v in variables]
        var_shapes = [v.shape for v in variables]

        adj_params_splits = tf.split(adj_params, var_flat_lens)
        adj_params_list = [tf.reshape(p, v_shape)
                           for p, v_shape in zip(adj_params_splits, var_shapes)]

        return (*adj_y, time_vjps), adj_params_list

    return ans, grad
class Lorenz(tf.keras.Model):
    """Lorenz system: dx/dt = sigma (y - x), dy/dt = x (rho - z) - y,
    dz/dt = x y - beta z."""

    def __init__(self, sigma, beta, rho):
        super(Lorenz, self).__init__()
        self.sigma = sigma
        self.beta = beta
        self.rho = rho

    def call(self, t, y):
        dx_dt = self.sigma * (y[1] - y[0])
        dy_dt = y[0] * (self.rho - y[2]) - y[1]
        dz_dt = y[0] * y[1] - self.beta * y[2]
        dL_dt = tf.stack([dx_dt, dy_dt, dz_dt])
        return dL_dt


t = tf.range(0.0, 100.0, 0.01, dtype=tf.float64)
initial_state = tf.convert_to_tensor([1., 1., 1.], dtype=tf.float64)

sigma = 10.
beta = 8. / 3.
rho = 28.

with tf.device(device):
    t1 = time.time()
    solution = odeint(Lorenz(sigma, beta, rho), initial_state, t).numpy()
    t2 = time.time()

print("Finished integrating ! Result shape :", solution.shape)
print("Time required (s): ", t2 - t1)

from mpl_toolkits.mplot3d import Axes3D  # needed for plotting in 3d
_ = Axes3D

fig = plt.figure(figsize=(16, 16))
ax = fig.gca(projection='3d')
ax.set_title('Lorenz Attractor')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.plot(solution[:, 0], solution[:, 1], solution[:, 2])
samp_ts = states['samp_ts']
print('Loaded ckpt from {}'.format(path))

for itr in range(1, args.niters + 1):
    # Backward in time to infer q(z_0)
    with tf.GradientTape() as tape:
        h = rec.initHidden()
        for t in reversed(range(samp_trajs.shape[1])):
            obs = samp_trajs[:, t, :]
            out, h = rec(obs, h)
        qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
        epsilon = tf.convert_to_tensor(
            np.random.randn(*qz0_mean.shape.as_list()), dtype=qz0_mean.dtype)
        z0 = epsilon * tf.exp(.5 * qz0_logvar) + qz0_mean

        # Forward in time: solve the ODE for the reconstructions
        pred_z = tf.transpose(odeint(func, z0, samp_ts), [1, 0, 2])
        pred_x = dec(pred_z)

        # Compute the loss
        noise_std_ = tf.zeros(pred_x.shape, dtype=tf.float64) + noise_std
        noise_logvar = 2. * tf.math.log(noise_std_)
        logpx = tf.reduce_sum(
            log_normal_pdf(samp_trajs, pred_x, noise_logvar), axis=-1)
        logpx = tf.reduce_sum(logpx, axis=-1)

        pz0_mean = pz0_logvar = tf.zeros(z0.shape, dtype=tf.float64)
        analytic_kl = tf.reduce_sum(
            normal_kl(qz0_mean, qz0_logvar, pz0_mean, pz0_logvar), axis=-1)
        loss = tf.reduce_mean(-logpx + analytic_kl, axis=0)

    params = (list(func.variables) + list(dec.variables) + list(rec.variables))
    grad = tape.gradient(loss, params)
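
# Hedged sketches of the Gaussian helpers referenced above -- log_normal_pdf
# and normal_kl are assumed to be the usual diagonal-Gaussian formulas; the
# real implementations may differ in broadcasting details.
import numpy as np
import tensorflow as tf

def log_normal_pdf(x, mean, logvar):
    const = tf.convert_to_tensor(np.log(2. * np.pi), dtype=x.dtype)
    return -.5 * (const + logvar + (x - mean) ** 2 / tf.exp(logvar))

def normal_kl(mu1, logvar1, mu2, logvar2):
    # Elementwise KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians.
    var1, var2 = tf.exp(logvar1), tf.exp(logvar2)
    return .5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2) ** 2) / var2 - 1.)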
device = 'gpu:' + str(args.gpu) if tf.test.is_gpu_available() else 'cpu:0'

true_y0 = tf.convert_to_tensor([[0.5, 0.01]], dtype=tf.float64)
t = tf.linspace(0., 25., args.data_size)
true_A = tf.convert_to_tensor([[-0.1, 3.0], [-3.0, -0.1]], dtype=tf.float64)


class Lambda(tf.keras.Model):

    def call(self, t, y):
        return tf.matmul(y, true_A)


with tf.device(device):
    t1 = time.time()
    true_y = odeint(Lambda(), true_y0, t, method=args.method)
    t2 = time.time()

print(true_y)
print()
print("Time taken to compute solution : ", t2 - t1)


def get_batch():
    s = np.random.choice(
        np.arange(args.data_size - args.batch_time, dtype=np.int64),
        args.batch_size, replace=False)
    temp_y = true_y.numpy()
    batch_y0 = tf.convert_to_tensor(temp_y[s])  # (M, D)
    batch_t = t[:args.batch_time]  # (T)