# Module setup and the (truncated) collocation loss `loss_fn_`.
#
# NOTE(review): the statements below are collapsed onto one physical line
# (the original line breaks and indentation are lost). The line ends part-way
# through `loss_fn_`, so as written this is not valid Python. Restore the
# structure from version control rather than by hand.
#
# What the visible code does:
#   * Appends the directory three `dirname` levels above this file's absolute
#     path to `sys.path`, so that `models`, `jaxmeta`, `data` and `config`
#     can be imported.
#   * Builds `model = normalized_model(domain)` and
#     `jacobian = jacobian_fn(model)`.
#   * Sets the module constant `a = 3.0`. It is not used in the visible part
#     of the file; confirm whether it is still needed.
#   * `loss_fn_(params, batch)` (jitted) reads `batch["collocation"]` and
#     `batch["dirichlet"]`. When `collocation[0] is not None`, it:
#       - evaluates the model on `hstack([collocation.x, collocation.t])`,
#         so input column 0 is x and column 1 is t;
#       - splits the output columns into u (column 0) and v (column 1);
#       - reads first derivatives as `duv_dxt[:, out, in]`, e.g.
#         du_dt = [:, 0:1, 1] and dv_dx = [:, 1:2, 0]. This presumes
#         `jacobian` returns (batch, outputs, inputs); TODO confirm against
#         jaxmeta.grad.jacobian_fn;
#       - forms `loss_c1 = metaloss(du_dt + dv_dx, 0)`, a residual of
#         u_t + v_x scored against 0 by `config.metaloss`, whose semantics
#         are defined elsewhere.
#
# NOTE(review): `@jax.jit` and `jnp.hstack` are used, but only
# `from jax import random` is visible here. Confirm that `import jax` and
# `import jax.numpy as jnp` exist in the real file.
# NOTE(review): `l1_regularization`, `l2_regularization`, `hessian_fn`,
# `epsilon`, `u`, `v`, `dv_dt`, `du_dx` and `dirichlet` are not used in the
# visible portion. They may be used in the truncated remainder.
from jax import random import os, sys sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(__file__))))) from models import normalized_model from jaxmeta.loss import l1_regularization, l2_regularization from jaxmeta.grad import jacobian_fn, hessian_fn from data import domain, epsilon from config import metaloss model = normalized_model(domain) jacobian = jacobian_fn(model) a = 3.0 @jax.jit def loss_fn_(params, batch): collocation, dirichlet = batch["collocation"], batch["dirichlet"] if collocation[0] is not None: uv = model(params, jnp.hstack([collocation.x, collocation.t])) u, v = uv[:, 0:1], uv[:, 1:2] duv_dxt = jacobian(params, jnp.hstack([collocation.x, collocation.t])) du_dt, dv_dt = duv_dxt[:, 0:1, 1], duv_dxt[:, 1:2, 1] du_dx, dv_dx = duv_dxt[:, 0:1, 0], duv_dxt[:, 1:2, 0] loss_c1 = metaloss(du_dt + dv_dx, 0)
# Parameter initialisation, an analytic coefficient model `inverse_model`,
# and the flux term `rhs`.
#
# NOTE(review): the statements below are collapsed onto one physical line
# (the original line breaks and indentation are lost), so as written this is
# not valid Python. The enclosing file may continue past `rhs`.
#
# What the visible code does:
#   * `random.split(config.key, 3)` yields `key` plus two `subkeys`. Only
#     `subkeys[0]` is used below, to initialise `direct_params` with
#     `init_siren_params` and the `config.direct_*` settings. `subkeys[1]` is
#     unused in the visible code.
#   * `inverse_params = jnp.array([2.0])` is a one-element trainable
#     coefficient rather than a network.
#   * `inverse_model(params, x)` returns 1 + params[0] / pi * cos(2 * pi * x).
#   * `rhs(params, xt)` unpacks `params` into `(direct_params,
#     inverse_params)`, which shadows the module-level names. It then computes
#     `duv_dxt = jacobian(direct_params, xt)` and returns
#     `inverse_model(inverse_params, xt[0]) * duv_dxt[:, 0]`.
#
# NOTE(review): `xt[0]` selects the first ROW of `xt`, not the x column. If the
# coefficient is meant to vary per point, `xt[:, 0:1]` may be intended.
# TODO confirm.
# NOTE(review): if `jacobian` returns (batch, outputs, inputs), then
# `duv_dxt[:, 0]` holds BOTH d/dx and d/dt of output 0, despite the name
# `du_dx`. Verify against jaxmeta.grad.jacobian_fn.
from jaxmeta.model_init import init_siren_params, init_tanh_params from models import simple_model, normalized_model, tanh_model from jaxmeta.loss import l1_regularization, l2_regularization from jaxmeta.grad import jacobian_fn, hessian_fn from data import domain, epsilon import config key, *subkeys = random.split(config.key, 3) direct_params = init_siren_params(subkeys[0], config.direct_layers, config.direct_c0, config.direct_w0) inverse_params = jnp.array([2.0]) direct_model = normalized_model(domain) jacobian = jacobian_fn(direct_model) @jax.jit def inverse_model(params, x): return 1 + params[0] / jnp.pi * jnp.cos(2 * jnp.pi * x) @jax.jit def rhs(params, xt): direct_params, inverse_params = params duv_dxt = jacobian(direct_params, xt) du_dx = duv_dxt[:, 0] a = inverse_model(inverse_params, xt[0]) return a * du_dx
# Parameter initialisation for a direct network and a learned coefficient
# network, plus the (truncated) operator `lhs_operator`.
#
# NOTE(review): the statements below are collapsed onto one physical line
# (the original line breaks and indentation are lost). The line ends part-way
# through `lhs_operator`, so as written this is not valid Python.
# NOTE(review): on a single physical line, the `# @jax.jit` below turns
# EVERYTHING after it into a comment, including `hessian_inv` and
# `lhs_operator`. Restore the line breaks from version control.
#
# What the visible code does:
#   * `random.split(config.key, 3)` yields `key` plus two `subkeys`.
#     `subkeys[0]` initialises `direct_params` (from the `config.direct_*`
#     settings) and `subkeys[1]` initialises `inverse_params` (from the
#     `config.inverse_*` settings), both via `init_siren_params`.
#   * `direct_model = normalized_model(domain)`, with first derivatives
#     supplied by `jacobian_direct = jacobian_fn(direct_model)`.
#   * `inverse_model = simple_model()` is a learned coefficient model. The
#     commented-out code is an analytic alternative,
#     1 + exp(-(x - 0.5)**2).
#   * `hessian_inv = hessian_fn(inverse_model)`.
#   * `lhs_operator(params, x)` unpacks `(direct_params, inverse_params)`,
#     evaluates `a = inverse_model(inverse_params, x)` and takes
#     `jacobian_direct(direct_params, x)[:, 0]`. The rest of the body is not
#     visible.
from jaxmeta.model_init import init_siren_params from models import simple_model, normalized_model, tanh_model from jaxmeta.loss import l1_regularization, l2_regularization from jaxmeta.grad import jacobian_fn, hessian_fn from data import domain import config key, *subkeys = random.split(config.key, 3) direct_params = init_siren_params(subkeys[0], config.direct_layers, config.direct_c0, config.direct_w0) inverse_params = init_siren_params(subkeys[1], config.inverse_layers, config.inverse_c0, config.inverse_w0) direct_model = normalized_model(domain) jacobian_direct = jacobian_fn(direct_model) inverse_model = simple_model() # @jax.jit # def inverse_model(params, x): # return 1 + jnp.exp(-(x-0.5)**2) hessian_inv = hessian_fn(inverse_model) @jax.jit def lhs_operator(params, x): direct_params, inverse_params = params a = inverse_model(inverse_params, x) dc_dx = jacobian_direct(direct_params, x)[:, 0]
# Parameter initialisation, an analytic coefficient model `inverse_model`,
# the combined parameter list `params`, and the (truncated) loss `loss_fn_`.
#
# NOTE(review): the statements below are collapsed onto one physical line
# (the original line breaks and indentation are lost). The line ends part-way
# through `loss_fn_`, so as written this is not valid Python.
#
# What the visible code does:
#   * `random.split(config.key, 3)` yields `key` plus two `subkeys`. Only
#     `subkeys[0]` is used, to initialise `direct_params` via
#     `init_siren_params` and the `config.direct_*` settings.
#   * `inverse_params = jnp.array([2.0])` is a one-element trainable
#     coefficient.
#   * `inverse_model(params, x)` returns 1 + params / pi * cos(2 * pi * x).
#   * `params = [direct_params, inverse_params]` bundles both parameter sets.
#     `loss_fn_` unpacks them in that same order.
#   * `loss_fn_(params, batch)` reads `batch["collocation"]` and
#     `batch["dirichlet"]`. The rest of the body is not visible.
#
# NOTE(review): this `inverse_model` divides the whole `params` array, whereas
# a sibling variant uses `params[0]`. For a 0-d `x` this returns shape (1,)
# instead of a scalar. Consider `params[0]` for consistency.
from jaxmeta.model_init import init_siren_params, init_tanh_params from models import simple_model, normalized_model, tanh_model from jaxmeta.loss import l1_regularization, l2_regularization from jaxmeta.grad import jacobian_fn, hessian_fn from data import domain, epsilon import config key, *subkeys = random.split(config.key, 3) direct_params = init_siren_params(subkeys[0], config.direct_layers, config.direct_c0, config.direct_w0) inverse_params = jnp.array([2.0]) direct_model = normalized_model(domain) jacobian = jacobian_fn(direct_model) @jax.jit def inverse_model(params, x): return 1 + params / jnp.pi * jnp.cos(2 * jnp.pi * x) params = [direct_params, inverse_params] @jax.jit def loss_fn_(params, batch): collocation, dirichlet = batch["collocation"], batch["dirichlet"] direct_params, inverse_params = params