Example #1
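# NOTE: this snippet is truncated at the top. The imports below and the
# helper signature are reconstructions; the modax module paths and the
# function name `pack_prior` are assumptions, not verbatim from the source.
import jax
import jax.numpy as jnp
from jax import random

from modax.models import Deepmod              # assumed path
from modax.data.kdv import doublesoliton      # assumed path
from modax.linear_model.SBL import precision  # assumed path


def pack_prior(alpha, beta):
    # Stack the per-term precisions alpha with the scalar noise precision
    # beta into a single parameter vector.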
    return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)


key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

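# One forward pass returns the network prediction, its time derivative dt,
# the library of candidate terms theta, and the current coefficients.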
prediction, dt, theta, coeffs = model.apply(variables, X)
y = dt  # regression target: the time derivative
X = theta  # regression features: the candidate-term library

n_samples, n_features = theta.shape
prior_params_mse = (0.0, 0.0)
tau = precision(y, prediction, *prior_params_mse)

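# Hyperpriors for the SBL coefficient precisions (alpha) and the noise
# precision (beta); the beta prior is centred on the MSE-based estimate tau.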
alpha_prior = (1e-6, 1e-6)
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))

n_samples, n_features = X.shape
norm_weight = jnp.concatenate(
    (jnp.ones((n_features,)), jnp.zeros((1,))), axis=0
)  # snippet truncated mid-call; the axis=0 completion is assumed
Example #2
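# Imports and the PRNG key are not shown in this truncated snippet; the
# reconstruction below uses assumed modax paths.
import jax.numpy as jnp
from jax import random
from flax import optim

from modax.models import Deepmod          # assumed path
from modax.data.burgers import burgers    # assumed path
from modax.training import create_update, train_max_iter  # assumed paths
from modax.training.losses import loss_fn_SBL  # assumed path

key = random.PRNGKey(42)  # assumed, matching the other examples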
noise = 0.50
n_runs = 1
max_iterations = 100000

# Making data
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += noise * jnp.std(y) * random.normal(key, y.shape)

# Defining model and optimizers
model = Deepmod([30, 30, 30, 1])
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)

# Running warm restart bayes
update_fn = create_update(loss_fn_SBL, (model, X, y, True))
for run_idx, subkey in enumerate(random.split(key, n_runs)):
    print(f"Starting SBL run {run_idx}")
    variables = model.init(subkey, X)
    state, params = variables.pop("params")
    state = (state, {"prior_init": None})  # adding prior to state
    optimizer = optimizer_def.create(params)
    optimizer, state = train_max_iter(
        update_fn,
        optimizer,
        state,
        max_iterations,
    )  # call closed and result captured; the snippet was truncated mid-call
Example #3
from modax.linear_model.SBL import SBL
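# Remaining imports are not shown in this snippet; reconstructed with
# assumed modax paths.
import jax
import jax.numpy as jnp
from jax import random
from flax import optim

from modax.models import Deepmod          # assumed path
from modax.data.burgers import burgers    # assumed path
from modax.training import create_update, train_max_iter  # assumed paths
from modax.training.losses import loss_fn_SBL  # assumed path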

# %% Making data
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, False))

# optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

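# Single manual step: loss_fn_SBL returns (loss, aux), so value_and_grad is
# called with has_aux=True; aux carries the updated state, metrics, and the
# model output.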
grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
(loss, (updated_state, metrics, output)), grad = grad_fn(
    optimizer.target, state, model, X, y
)
Example #4
from modax.data.burgers import burgers
from time import time
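# Remaining imports reconstructed (assumed paths, not verbatim).
import jax.numpy as jnp
from jax import random
from flax import optim

from modax.models import Deepmod          # assumed path
from modax.training import create_update  # assumed path
from modax.training.losses import loss_fn_mse  # assumed path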

# Making dataset
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)

t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)

# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1])
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(params)

# Compiling train step
update = create_update(loss_fn_mse, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

# Running to convergence
max_epochs = 10001
t_start = time()
for i in range(max_epochs):  # plain Python range; no need to iterate a jnp array
    optimizer, loss = update(optimizer)
    if i % 1000 == 0:
        # Progress log; the snippet was truncated here, so the body below
        # is a reconstruction.
        print(f"step {i}: loss {loss}, {time() - t_start:.1f}s elapsed")
Example #5
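# This snippet opens inside a dataset switch; the imports and the start of
# the burgers branch are truncated. The reconstruction below mirrors the
# burgers setup of Example #2 and uses assumed modax paths.
import jax.numpy as jnp
from jax import random
from flax import optim

from modax.models import Deepmod            # assumed path
from modax.data.burgers import burgers      # assumed path
from modax.data.kdv import doublesoliton    # assumed path
from modax.training import create_update, train_max_iter  # assumed paths
from modax.training.losses import loss_fn_pinn  # assumed path

dataset = "burgers"  # assumed switch variable
if dataset == "burgers":
    key = random.PRNGKey(42)
    x = jnp.linspace(-3, 4, 50)
    t = jnp.linspace(0.5, 5.0, 20)
    t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")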
    u = burgers(x_grid, t_grid, 0.1, 1.0)

    X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
    y = u.reshape(-1, 1)
    y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

elif dataset == "kdv":
    key = random.PRNGKey(42)
    x = jnp.linspace(-10, 10, 100)
    t = jnp.linspace(0.1, 1.0, 20)
    t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
    u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])

    X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
    y = u.reshape(-1, 1)
    y += 0.10 * jnp.std(y) * random.normal(key, y.shape)
else:
    raise NotImplementedError
# %% Building model and params
model = Deepmod([30, 30, 30, 1], (5, 4))  # second arg presumably sets the library size (e.g. polynomial and derivative orders); assumed
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)
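# create_update packs the loss arguments into a tuple; the trailing 1.0 is
# presumably the PINN loss weight (assumption).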
update_fn = create_update(loss_fn_pinn, (model, X, y, 1.0))
# %%
optimizer, state = train_max_iter(update_fn, optimizer, state, 5000)

# %%