def taylor_test(J, c, dc, dJdm=None):
    """
    Dummy `taylor_test` function for consistency of notation between discrete
    and continuous adjoint problems.

    Converts the control ``c`` and the perturbation direction ``dc`` to NumPy
    arrays and forwards them to ``opt.taylor_test`` (with ``verbose=True``).

    :arg J: reduced functional, forwarded unchanged.
    :arg c: control point. A ``Function`` contributes its full data array, a
        list of ``Function``s contributes the first dof of each, and anything
        else is passed through unchanged.
    :arg dc: perturbation direction, accepted in the same forms as ``c``.
    :kwarg dJdm: gradient callable forwarded to ``opt.taylor_test``; required.
    :return: whatever ``opt.taylor_test`` returns.
    """
    assert dJdm is not None, "taylor_test requires a gradient (dJdm)"
    c = _taylor_test_as_array(c)
    # Bug fix: the direction must be read from ``dc`` itself. The original read
    # ``c.dat.data``, but ``c`` has already been converted to an array above.
    dc = _taylor_test_as_array(dc)
    return opt.taylor_test(J, dJdm, c, delta_m=dc, verbose=True)


def _taylor_test_as_array(x):
    """Convert a ``Function`` or a list of ``Function``s to a NumPy array; pass anything else through."""
    if isinstance(x, Function):
        return x.dat.data
    # The ``x and`` guard avoids an IndexError on an empty list.
    if isinstance(x, list) and x and isinstance(x[0], Function):
        # Only the first dof of each Function is taken; presumably each is a
        # single-valued control (e.g. R-space) -- confirm.
        return np.array([xi.dat.data[0] for xi in x])
    return x
if args.taylor_test_okada:
    # Taylor test for the Okada source.
    #
    # Consider the reduced functional J(m) = e . S(m), with e a vector of ones of
    # length ``len(op.indices)``. Then dJdm is the same as propagating e through
    # the reverse mode of AD on S, i.e. ``gradient_okada(m, e)``.
    print("Taylor test Okada...")
    # np.random.seed(0)  # optional: presumably fixes any random direction drawn by opt.taylor_test

    def _rf_okada(m):
        """Reduced functional J(m) = sum(S(m)) for the Okada source S."""
        return np.sum(okada_source(m))

    def _gradient_okada(m):
        """Gradient of ``_rf_okada``: reverse-mode AD of S seeded with ones."""
        return gradient_okada(m, np.ones(len(op.indices)))

    # Test at 0.7 times the active control parameters.
    m_init = 0.7 * np.concatenate(
        [op.control_parameters[ctrl] for ctrl in op.active_controls]
    )
    minconv = opt.taylor_test(_rf_okada, _gradient_okada, m_init, verbose=True)
    # Presumably ``minconv`` is the minimum observed convergence rate of the
    # Taylor remainder; a value near 2 indicates a consistent gradient.
    assert minconv > 1.90


# --- Setup coupling

def tsunami_ic(dislocation):
    """
    Build the initial free-surface elevation for the tsunami propagation model
    from a dislocation field.

    Only an elevation field is returned, not a velocity-elevation tuple.

    :arg dislocation: values written into the dofs ``op.indices`` (presumably
        the vertical seabed displacement of the source -- confirm).
    :return: a ``Function`` on ``P1`` holding ``dislocation`` at
        ``op.indices``. Other dofs keep the value of a freshly created
        ``Function``, assumed to be zero.
    """
    surf = Function(P1)
    surf.dat.data[op.indices] = dislocation
    return surf
savefig("original_source", "plots", extensions=["jpg"])

# --- Plot the perturbed source
fig, axes = plt.subplots(figsize=(7, 7))
cbar = fig.colorbar(axes.contourf(X, Y, eta_pert, **plotting_kwargs), ax=axes)
cbar.ax.tick_params(labelsize=tick_fontsize)
cbar.set_label(r"Elevation [$\mathrm m$]", fontsize=fontsize)
axes.set_xlabel("Longitude", fontsize=fontsize)
axes.set_ylabel("Latitude", fontsize=fontsize)
use_degrees(axes)
# ``Tick.label`` (previously used per tick here) was deprecated in Matplotlib 3.6
# and removed in 3.8. Set the major tick-label size through the public
# ``tick_params`` API instead.
axes.tick_params(axis="both", which="major", labelsize=tick_fontsize)
savefig("perturbed_source", "plots", extensions=["jpg"])

# Taylor test
# NOTE(review): this source also defines a ``taylor_test(J, c, dc, dJdm=None)``
# wrapper that accepts no ``verbose`` keyword. If that is the name in scope here,
# this call raises TypeError. The argument order (J, gradient, m) matches
# ``opt.taylor_test``, so confirm which function is intended.
taylor_test(reduced_functional, gradient, op.input_vector, verbose=True)


def opt_cb(m):
    """
    Print progress after every successful line search.

    Only every 100th recorded iteration is reported. The report uses the latest
    entries of ``op.J_progress`` and ``op.dJdm_progress``.

    :arg m: current control; unused (presumably required by the optimiser's
        callback signature).
    """
    counter = len(op.J_progress)
    # Skip when nothing has been recorded yet (would otherwise index an empty
    # list) and on all iterations that are not a multiple of 100.
    if counter == 0 or counter % 100 != 0:
        return
    msg = "{:4d}: J = {:.4e} ||dJdm|| = {:.4e}"
    print(msg.format(counter, op.J_progress[-1], op.dJdm_progress[-1]))


# Inversion
op.J_progress = []
op.dJdm_progress = []