def generate_dataset(n_i, n_b, n_cx, n_ct, n_dx, n_dt):
    """Build Dirichlet and collocation sets, augmented with reference data.

    Dirichlet points are the union of ``n_i`` initial points along x at
    ``t = 0``, ``n_b`` points on each of the left and right x edges of
    ``domain``, and a strided sub-sample of the snapshots stored in
    ``data_file``. Collocation points are an ``n_cx`` x ``n_ct`` tensor grid
    over ``domain`` plus the same data-point coordinates.

    Args:
        n_i: number of initial-condition points.
        n_b: number of time samples on each spatial boundary.
        n_cx: collocation grid size in x.
        n_ct: collocation grid size in t.
        n_dx: requested data density in x; the stride into the stored x
            samples is ``len(x) // n_dx`` (clamped to at least 1).
        n_dt: requested data density in t; the stride into the stored t
            samples is ``len(t) // n_dt`` (clamped to at least 1).

    Returns:
        ``(dirichlet, collocation)`` as built by ``dataset_Dirichlet`` and
        ``dataset_Collocation``.
    """
    # Initial condition: evenly spaced x at t = 0.
    x_i = jnp.linspace(*domain[:, 0], n_i).reshape((-1, 1))
    t_i = jnp.zeros_like(x_i)
    u_i = u0_fn(x_i, t_i)
    v_i = v0_fn(x_i, t_i)

    # Boundary conditions: the left/right x edges, each at the same n_b times.
    x_l = jnp.ones((n_b, 1))*domain[0, 0]
    x_r = jnp.ones((n_b, 1))*domain[1, 0]
    t_b = jnp.linspace(*domain[:, 1], n_b).reshape((-1, 1))
    u_l = ul_fn(x_l, t_b)
    u_r = ur_fn(x_r, t_b)
    v_l = vl_fn(x_l, t_b)
    v_r = vr_fn(x_r, t_b)

    # Interior collocation grid.
    x_c = jnp.linspace(*domain[:, 0], n_cx).reshape((-1, 1))
    t_c = jnp.linspace(*domain[:, 1], n_ct).reshape((-1, 1))
    xt_c = tensor_grid([x_c, t_c])

    # Reference data, sub-sampled with a fixed stride; the first and last
    # stored sample along each axis are dropped by the 1:-1 bounds.
    data = loadmat(data_file)
    u_d, v_d = data["u_snapshots"], data["v_snapshots"]
    x_d, t_d = data["x"], data["t"].T
    # Clamp each stride to >= 1: asking for more samples than are stored
    # would otherwise yield a zero step, which NumPy rejects with
    # "slice step cannot be zero". A stride of 1 keeps every interior sample.
    x_sel = slice(1, -1, max(1, len(x_d)//n_dx))
    t_sel = slice(1, -1, max(1, len(t_d)//n_dt))
    # NOTE(review): assumes the snapshot arrays are laid out as
    # (len(x), len(t)) and that their row-major flatten matches the point
    # ordering produced by tensor_grid([x_d, t_d]) — confirm.
    u_d = u_d[x_sel, t_sel].reshape((-1, 1))
    v_d = v_d[x_sel, t_sel].reshape((-1, 1))
    x_d = x_d[x_sel]
    t_d = t_d[t_sel]
    xt_d = tensor_grid([x_d, t_d])

    # Data points serve both as extra Dirichlet targets and as collocation points.
    dirichlet = dataset_Dirichlet(
        jnp.vstack([x_i, x_l, x_r, xt_d[:, 0:1]]),
        jnp.vstack([t_i, t_b, t_b, xt_d[:, 1:2]]),
        jnp.vstack([u_i, u_l, u_r, u_d]),
        jnp.vstack([v_i, v_l, v_r, v_d]))
    collocation = dataset_Collocation(
        jnp.vstack([xt_c[:, 0:1], xt_d[:, 0:1]]),
        jnp.vstack([xt_c[:, 1:2], xt_d[:, 1:2]]))
    return dirichlet, collocation
def generate_dataset(n_i, n_cx, n_ct, n_quad):
    """Build initial-condition, collocation and quadrature data sets.

    The ``v`` coordinate is discretized with an ``n_quad``-point
    Gauss-Legendre rule, affinely mapped from [-1, 1] onto [0, 1].

    Args:
        n_i: number of x points on the t = 0 slice (each paired with every
            quadrature node).
        n_cx: collocation grid size in x.
        n_ct: collocation grid size in t.
        n_quad: number of Gauss-Legendre nodes in v.

    Returns:
        ``(dirichlet, collocation, quad)`` as built by ``dataset_Dirichlet``,
        ``dataset_Collocation`` and ``dataset_Quadrature``.
    """
    ref_nodes, ref_weights = np.polynomial.legendre.leggauss(n_quad)
    # Map s in [-1, 1] to v = (s + 1) / 2 in [0, 1]; the Jacobian halves
    # every weight. Both are stored as float32 column vectors.
    v_nodes = jnp.array(0.5*(ref_nodes+1), dtype=jnp.float32).reshape((-1, 1))
    v_weights = jnp.array(0.5*ref_weights, dtype=jnp.float32).reshape((-1, 1))

    # t = 0 slice: every x sample crossed with every v node.
    init_grid = tensor_grid([
        jnp.linspace(*domain[:, 0], n_i).reshape((-1, 1)),
        v_nodes,
    ])
    x_0 = init_grid[:, 0:1]
    v_0 = init_grid[:, 1:2]
    t_0 = jnp.zeros_like(x_0)
    dirichlet = dataset_Dirichlet(
        x_0, t_0, v_0, r0_fn(x_0, t_0, v_0), j0_fn(x_0, t_0, v_0))

    # Interior points: the full (x, t, v) tensor grid.
    interior = tensor_grid([
        jnp.linspace(*domain[:, 0], n_cx).reshape((-1, 1)),
        jnp.linspace(*domain[:, 1], n_ct).reshape((-1, 1)),
        v_nodes,
    ])
    collocation = dataset_Collocation(
        *(interior[:, col:col + 1] for col in range(3)))

    return dirichlet, collocation, dataset_Quadrature(v_nodes, v_weights)
def generate_dataset(n_i, n_b, n_cx, n_ct):
    """Build the Dirichlet and collocation training sets.

    Dirichlet points stack, in this order: ``n_i`` initial points along x at
    ``t = 0``, then ``n_b`` points on the left x edge, then ``n_b`` points on
    the right x edge (both edges share the same time samples). Collocation
    points form an ``n_cx`` x ``n_ct`` tensor grid over ``domain``.

    Returns:
        ``(dirichlet, collocation)`` as built by ``dataset_Dirichlet`` and
        ``dataset_Collocation``.
    """
    x_span = domain[:, 0]
    t_span = domain[:, 1]

    # t = 0 slice.
    x0 = jnp.linspace(*x_span, n_i).reshape((-1, 1))
    t0 = jnp.zeros_like(x0)
    u0, v0 = u0_fn(x0, t0), v0_fn(x0, t0)

    # Spatial edges, sampled at the same n_b times.
    edge_shape = (n_b, 1)
    x_left = jnp.ones(edge_shape)*domain[0, 0]
    x_right = jnp.ones(edge_shape)*domain[1, 0]
    t_edge = jnp.linspace(*t_span, n_b).reshape((-1, 1))
    u_left, u_right = ul_fn(x_left, t_edge), ur_fn(x_right, t_edge)
    v_left, v_right = vl_fn(x_left, t_edge), vr_fn(x_right, t_edge)

    grid = tensor_grid([
        jnp.linspace(*x_span, n_cx).reshape((-1, 1)),
        jnp.linspace(*t_span, n_ct).reshape((-1, 1)),
    ])

    # One vstack per field (x, t, u, v), each ordered initial / left / right.
    dirichlet = dataset_Dirichlet(*(
        jnp.vstack(pieces) for pieces in (
            (x0, x_left, x_right),
            (t0, t_edge, t_edge),
            (u0, u_left, u_right),
            (v0, v_left, v_right),
        )
    ))
    collocation = dataset_Collocation(grid[:, 0:1], grid[:, 1:2])
    return dirichlet, collocation
# Evaluate every trained model over the (n_ib, n_c) sweep against a reference
# solution loaded from "epsilon_1e-12.mat", storing one metaloss value per
# configuration in `errors`, then start a figure for plotting.
# NOTE(review): `np`, `sys` and `os` are used here but not imported in this
# chunk — presumably imported earlier in the file; confirm.

# Put the directory three levels above this file on the import path so the
# project modules below (loss, config, run, data) can be found.
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.abspath(__file__)))))
from loss import model
from jaxmeta.data import load_params, tensor_grid
import config
from run import n_ib, n_c
from data import domain
from scipy.io import loadmat

# Reference solution; only `u` is compared below — `v_true` is unused in
# this visible chunk.
data_true = loadmat("epsilon_1e-12.mat")
u_true, v_true, x_test = data_true["u"], data_true["v"], data_true["x"]
# Evaluate at a single time, domain[1, 1] (presumably the upper t bound of
# `domain` — confirm against the data module).
t_test = [domain[1, 1]]
xt_test = tensor_grid([x_test, t_test])

# errors[i, j] holds the loss for the i-th n_ib setting and j-th n_c setting.
errors = np.zeros((len(n_ib), len(n_c)))
for i, b in enumerate(n_ib):
    for j, c in enumerate(n_c):
        # NOTE(review): the checkpoint directory is named by the loop
        # *indices* i, j, not by the values b, c — confirm this matches how
        # training saved the models.
        path = "models/ib_{}_c_{}/iteration_{}/params.npy".format(
            i, j, config.iterations)
        params = load_params(path)
        uv_pred = model(params, xt_test)
        # Column 0 of the prediction is compared against u_true.
        errors[i, j] = config.metaloss(u_true, uv_pred[:, 0:1])
# print(errors)

import matplotlib.pyplot as plt
f, ax = plt.subplots(1, 1, figsize=(10, 10))
# Product of each n_c entry's "cx" and "ct" values (presumably the total
# collocation-grid size per configuration).
n_c_ = [c["cx"] * c["ct"] for c in n_c]