Example #1
# Making data
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += noise * jnp.std(y) * random.normal(key, y.shape)
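
For reference, these snippets call burgers(x_grid, t_grid, 0.1, 1.0) without showing its definition. A minimal sketch of the classical delta-peak solution of Burgers' equation u_t + u u_x = nu * u_xx, which examples of this kind typically evaluate (an assumption; the actual burgers function may differ), with viscosity nu and amplitude A:

import jax.numpy as jnp
from jax.scipy.special import erfc

def burgers_delta(x, t, nu, A):
    """Delta-peak solution of u_t + u u_x = nu * u_xx (assumed form)."""
    R = A / (2 * nu)              # peak Reynolds number
    z = x / jnp.sqrt(4 * nu * t)  # similarity variable
    return (
        jnp.sqrt(nu / (jnp.pi * t))
        * (jnp.exp(R) - 1) * jnp.exp(-z**2)
        / (1 + (jnp.exp(R) - 1) / 2 * erfc(z))
    )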

# Defining model and optimizers
model = Deepmod([30, 30, 30, 1])
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)

# Running warm-restart Bayes (SBL)
update_fn = create_update(loss_fn_SBL, (model, X, y, True))
for run_idx, subkey in enumerate(random.split(key, n_runs)):
    print(f"Starting SBL run {run_idx}")
    variables = model.init(subkey, X)
    state, params = variables.pop("params")
    state = (state, {"prior_init": None})  # adding prior to state
    optimizer = optimizer_def.create(params)
    train_max_iter(
        update_fn,
        optimizer,
        state,
        max_iterations,
        log_dir=script_dir + f"sbl_run_{run_idx}/",
    )
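
The loss_fn_SBL runs above presumably select library terms via Sparse Bayesian Learning. As a reference point, here is a sketch of the classic MacKay/ARD evidence updates such a loss could build on (an assumption; loss_fn_SBL itself is not shown in these snippets):

import jax.numpy as jnp

def sbl_step(theta, dt, alpha, beta):
    """One MacKay/ARD update over the library (sketch, not the script's code).
    theta: (n, m) library matrix, dt: (n, 1) time derivative,
    alpha: (m,) per-term precisions, beta: scalar noise precision."""
    n = theta.shape[0]
    sigma = jnp.linalg.inv(jnp.diag(alpha) + beta * theta.T @ theta)
    mu = beta * sigma @ theta.T @ dt             # posterior mean coefficients
    gamma = 1 - alpha * jnp.diag(sigma)          # effective degrees of freedom
    alpha_new = gamma / mu.ravel() ** 2          # large alpha prunes a term
    beta_new = (n - gamma.sum()) / jnp.sum((dt - theta @ mu) ** 2)
    return mu, alpha_new, beta_new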
Example #2
# Making data
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += noise * jnp.std(y) * random.normal(key, y.shape)

# Defining model and optimizers
model = Deepmod([30, 30, 30, 1])
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)

# Running PINN with overcomplete library
update_fn = create_update(loss_fn_pinn, (model, X, y))
for run_idx, subkey in enumerate(random.split(key, n_runs)):
    print(f"Starting multitask run {run_idx}")
    variables = model.init(subkey, X)
    state, params = variables.pop("params")
    optimizer = optimizer_def.create(params)
    train_max_iter(
        update_fn,
        optimizer,
        state,
        max_iterations,
        log_dir=script_dir + f"burgers_pinn_run_{run_idx}/",
    )

# Running Bayes with overcomplete library
update_fn = create_update(loss_fn_bayesian_ridge, (model, X, y, True))
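
loss_fn_bayesian_ridge is likewise not shown; its core step is presumably a ridge-style fit of the library theta against the time derivative dt. A minimal sketch of that regression step, with lam a hypothetical regularization strength:

import jax.numpy as jnp

def ridge_coeffs(theta, dt, lam=1e-5):
    """Closed-form ridge fit: (theta^T theta + lam * I)^-1 theta^T dt.
    `lam` is a hypothetical regularization strength, not taken from this script."""
    gram = theta.T @ theta + lam * jnp.eye(theta.shape[1])
    return jnp.linalg.solve(gram, theta.T @ dt)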
Example #3
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, False))

# optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
(loss, (updated_state, metrics, output)), grad = grad_fn(
    optimizer.target, state, model, X, y
)


# %%
model_state, loss_state = state
variables = {"params": params, **model_state}
(prediction, dt, theta, coeffs), updated_model_state = model.apply(
    variables, X, mutable=list(model_state.keys())
)  # assumed completion; the original snippet breaks off after model.apply(
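
The unpacking above uses Flax's mutable apply: when Module.apply is given mutable collections, it returns (outputs, updated_variables) instead of outputs alone. A small self-contained illustration of the pattern:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(8)(x)
        # BatchNorm stores running statistics in the "batch_stats" collection
        return nn.BatchNorm(use_running_average=False)(x)

net = Net()
x = jnp.ones((4, 2))
variables = net.init(jax.random.PRNGKey(0), x)
# Declaring a collection mutable makes apply return (outputs, updated_state)
out, updated_state = net.apply(variables, x, mutable=["batch_stats"])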
Example #4
if dataset == "burgers":
    # Reconstructed branch header; the snippet starts mid-branch, and the
    # grid values are assumed from the matching examples above.
    key = random.PRNGKey(42)
    x = jnp.linspace(-3, 4, 50)
    t = jnp.linspace(0.5, 5.0, 20)
    t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
    u = burgers(x_grid, t_grid, 0.1, 1.0)

    X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
    y = u.reshape(-1, 1)
    y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

elif dataset == "kdv":
    key = random.PRNGKey(42)
    x = jnp.linspace(-10, 10, 100)
    t = jnp.linspace(0.1, 1.0, 20)
    t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
    u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

    X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
    y = u.reshape(-1, 1)
    y += 0.10 * jnp.std(y) * random.normal(key, y.shape)
else:
    raise NotImplementedError
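
doublesoliton presumably evaluates the exact two-soliton solution of the KdV equation with speeds c and offsets x0. Its closed form is lengthy, but the single-soliton building block, assuming the normalization u_t + 6 u u_x + u_xxx = 0, is:

import jax.numpy as jnp

def soliton(x, t, c, x0):
    """Single KdV soliton, assuming u_t + 6 u u_x + u_xxx = 0:
    u = (c / 2) * sech(sqrt(c) / 2 * (x - c * t - x0))**2."""
    arg = jnp.sqrt(c) / 2 * (x - c * t - x0)
    return c / 2 / jnp.cosh(arg) ** 2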
# %% Building model and params
model = Deepmod([30, 30, 30, 1], (5, 4))
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_pinn, (model, X, y, 1.0))
# %%
optimizer, state = train_max_iter(update_fn, optimizer, state, 5000)

# %%
Example #5
# Making data
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.1, 5.0, 2)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += noise * jnp.std(y) * random.normal(key, y.shape)

# Defining model and optimizers
model = Deepmod([30, 30, 30, 1])
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)

# Running multitask
update_fn = create_update(loss_fn_multitask_precalc, (model, X, y))
for run_idx, subkey in enumerate(random.split(key, n_runs)):
    print(f"Starting multitask run {run_idx}")
    variables = model.init(subkey, X)
    state, params = variables.pop("params")
    optimizer = optimizer_def.create(params)
    train_max_iter(
        update_fn,
        optimizer,
        state,
        max_iterations,
        log_dir=script_dir + f"multitask_run_{run_idx}/",
    )

# Running Bayesian multitask
update_fn = create_update(loss_fn_pinn_bayes_mse_hyperprior, (model, X, y))
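
create_update itself never appears in these snippets, but the grad_fn call in Example #3 suggests the loss signature loss_fn(params, state, model, X, y, ...) with has_aux=True. A hypothetical sketch of what create_update could return under that assumption, using the legacy flax.optim API seen throughout:

import jax

def create_update_sketch(loss_fn, loss_args):
    """Hypothetical reconstruction of create_update: returns a jitted step
    that takes (optimizer, state) and applies one gradient update."""
    model, X, y, *extra = loss_args
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    @jax.jit
    def step(optimizer, state):
        (loss, (state, metrics, output)), grads = grad_fn(
            optimizer.target, state, model, X, y, *extra
        )
        optimizer = optimizer.apply_gradient(grads)  # legacy flax.optim update
        return (optimizer, state), metrics

    return step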
Example #6
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_pinn, (model, X_train, y_train, 1.0))

# Validation loss is the MSE on the test set.
val_fn = jit(
    lambda opt, state: loss_fn_mse(opt.target, state, model, X_test, y_test)[0]
)
# %%
optimizer, state = train_early_stop(
    update_fn, val_fn, optimizer, state, max_epochs=10000, delta=0.0, patience=2000
)

# %%
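
train_early_stop presumably monitors val_fn and stops once the validation loss has not improved by more than delta for patience consecutive epochs. A minimal sketch of that standard pattern (assumed semantics; the real implementation may differ):

def train_early_stop_sketch(update_fn, val_fn, optimizer, state,
                            max_epochs, delta=0.0, patience=1000):
    """Stop once validation loss stops improving by more than `delta`
    for `patience` consecutive epochs (assumed semantics)."""
    best_loss, best_epoch = float("inf"), 0
    for epoch in range(max_epochs):
        (optimizer, state), metrics = update_fn(optimizer, state)
        val_loss = val_fn(optimizer, state)
        if val_loss < best_loss - delta:
            best_loss, best_epoch = val_loss, epoch  # new best: reset window
        elif epoch - best_epoch > patience:
            break  # patience exhausted
    return optimizer, state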