コード例 #1
0
ファイル: loss.py プロジェクト: dalerxli/MultiScale-PINN
from jax import random

import os, sys
# Put the directory two levels above the one containing this file on the
# import path -- presumably the repository root, so that the project-local
# imports that follow (models, jaxmeta, data, config) resolve; confirm layout.
# Joining three ``os.pardir`` components onto the file path and normalising
# with ``abspath`` is lexically equivalent to applying ``dirname`` three times.
sys.path.append(
    os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, os.pardir)))

from models import normalized_model
from jaxmeta.loss import l1_regularization, l2_regularization
from jaxmeta.grad import jacobian_fn, hessian_fn

from data import domain, epsilon
from config import metaloss

# Scalar constant for this module; its meaning is not visible in this
# snippet (presumably a coefficient used by the loss terms -- TODO document).
a = 3.0

# Network built by ``normalized_model`` for ``domain`` (presumably rescales
# inputs to the domain -- confirm in models.py), plus a function returning
# its Jacobian with respect to the network inputs.
model = normalized_model(domain)
jacobian = jacobian_fn(model)


@jax.jit
def loss_fn_(params, batch):
    collocation, dirichlet = batch["collocation"], batch["dirichlet"]

    if collocation[0] is not None:
        uv = model(params, jnp.hstack([collocation.x, collocation.t]))
        u, v = uv[:, 0:1], uv[:, 1:2]
        duv_dxt = jacobian(params, jnp.hstack([collocation.x, collocation.t]))
        du_dt, dv_dt = duv_dxt[:, 0:1, 1], duv_dxt[:, 1:2, 1]
        du_dx, dv_dx = duv_dxt[:, 0:1, 0], duv_dxt[:, 1:2, 0]
        loss_c1 = metaloss(du_dt + dv_dx, 0)
コード例 #2
0
ファイル: loss.py プロジェクト: dalerxli/MultiScale-PINN
from jaxmeta.model_init import init_siren_params, init_tanh_params
from models import simple_model, normalized_model, tanh_model
from jaxmeta.loss import l1_regularization, l2_regularization
from jaxmeta.grad import jacobian_fn, hessian_fn

from data import domain, epsilon
import config

# Split config.key into three keys: the first stays in ``key``; the other
# two land in ``subkeys`` (only subkeys[0] is consumed here).
key, *subkeys = random.split(config.key, 3)

# Direct (forward) network parameters, initialised by init_siren_params.
direct_params = init_siren_params(
    subkeys[0],
    config.direct_layers,
    config.direct_c0,
    config.direct_w0,
)

# Inverse-problem parameter: one scalar (the cosine amplitude read by
# ``inverse_model``), starting at 2.0.
inverse_params = jnp.array([2.0])

# Direct network and its input Jacobian.
direct_model = normalized_model(domain)
jacobian = jacobian_fn(direct_model)


@jax.jit
def inverse_model(params, x):
    """Evaluate the coefficient a(x) = 1 + (p / pi) * cos(2 * pi * x).

    Args:
        params: Array whose first entry ``p`` is the cosine amplitude.
        x: Point(s) at which to evaluate; any shape that ``jnp.cos``
            accepts.

    Returns:
        Array with the shape of ``x``.
    """
    amplitude = params[0] / jnp.pi
    oscillation = jnp.cos(2 * jnp.pi * x)
    return 1 + amplitude * oscillation


@jax.jit
def rhs(params, xt):
    """Return the coefficient-weighted derivative a * du/dx.

    Args:
        params: ``(direct_params, inverse_params)`` pair.
        xt: Input fed to both the direct model's Jacobian and, via
            ``xt[0]``, to ``inverse_model``.

    Returns:
        ``inverse_model(inverse_params, xt[0])`` times column 0 of the
        direct model's Jacobian (treated as du/dx).

    NOTE(review): ``xt[0]`` is the x coordinate, and ``[:, 0]`` is the
    x-derivative, only if ``xt`` is a single (x, t) point (e.g. under vmap).
    For a batched (N, 2) input, ``xt[0]`` would be the first row instead --
    confirm the calling convention.
    """
    direct_params, inverse_params = params
    coefficient = inverse_model(inverse_params, xt[0])
    return coefficient * jacobian(direct_params, xt)[:, 0]
コード例 #3
0
from jaxmeta.model_init import init_siren_params
from models import simple_model, normalized_model, tanh_model
from jaxmeta.loss import l1_regularization, l2_regularization
from jaxmeta.grad import jacobian_fn, hessian_fn

from data import domain
import config

# Split config.key into three keys; ``key`` keeps the first, and the two
# entries of ``subkeys`` seed the direct and inverse networks respectively.
key, *subkeys = random.split(config.key, 3)
direct_params = init_siren_params(
    subkeys[0], config.direct_layers, config.direct_c0, config.direct_w0
)
inverse_params = init_siren_params(
    subkeys[1], config.inverse_layers, config.inverse_c0, config.inverse_w0
)

# Direct network and its input Jacobian.
direct_model = normalized_model(domain)
jacobian_direct = jacobian_fn(direct_model)

# The coefficient a(x) is represented by a second network, together with
# its Hessian function. A disabled alternative in this file replaced the
# network with the fixed analytic form a(x) = 1 + exp(-(x - 0.5)**2).
inverse_model = simple_model()
hessian_inv = hessian_fn(inverse_model)


@jax.jit
def lhs_operator(params, x):
    direct_params, inverse_params = params
    a = inverse_model(inverse_params, x)
    dc_dx = jacobian_direct(direct_params, x)[:, 0]
コード例 #4
0
ファイル: loss.py プロジェクト: dalerxli/MultiScale-PINN
from jaxmeta.model_init import init_siren_params, init_tanh_params
from models import simple_model, normalized_model, tanh_model
from jaxmeta.loss import l1_regularization, l2_regularization
from jaxmeta.grad import jacobian_fn, hessian_fn

from data import domain, epsilon
import config

# Split config.key into three keys: the first stays in ``key``; the other
# two land in ``subkeys`` (only subkeys[0] is consumed here).
key, *subkeys = random.split(config.key, 3)

# Direct (forward) network parameters, initialised by init_siren_params.
direct_params = init_siren_params(
    subkeys[0],
    config.direct_layers,
    config.direct_c0,
    config.direct_w0,
)

# Inverse-problem parameter: a length-1 array (the cosine amplitude read by
# ``inverse_model``), starting at 2.0.
inverse_params = jnp.array([2.0])

# Direct network and its input Jacobian.
direct_model = normalized_model(domain)
jacobian = jacobian_fn(direct_model)


@jax.jit
def inverse_model(params, x):
    """Evaluate the coefficient 1 + (params / pi) * cos(2 * pi * x).

    Args:
        params: Amplitude array, broadcast against ``x`` (e.g. shape (1,)).
        x: Point(s) at which to evaluate.

    Returns:
        The broadcast of ``params`` and ``x`` shapes.
    """
    scaled_amplitude = params / jnp.pi
    wave = jnp.cos(2 * jnp.pi * x)
    return 1 + scaled_amplitude * wave


# Both parameter sets, in the [direct, inverse] order that loss_fn_ unpacks.
params = list((direct_params, inverse_params))


@jax.jit
def loss_fn_(params, batch):
    collocation, dirichlet = batch["collocation"], batch["dirichlet"]
    direct_params, inverse_params = params