# NOTE(review): the statement below is the tail of a function whose header and
# body precede the visible source; it concatenates `alpha` with `beta` (given a
# new leading axis) along axis 0 — confirm shapes against the full definition.
return jnp.concatenate([alpha, beta[jnp.newaxis]], axis=0)

# Noisy samples of `doublesoliton` on a 10 x 100 (t, x) grid; the added noise is
# Gaussian with a standard deviation of 10% of the clean signal's std.
key = random.PRNGKey(42)
x = jnp.linspace(-10, 10, 100)
t = jnp.linspace(0.1, 1.0, 10)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])
# Flatten the grid into a two-column design matrix: column 0 is t, column 1 is x.
X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
# NOTE(review): the same `key` is used again for `model.init` below; JAX advises
# splitting keys instead of reusing one for separate random draws.
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)
# `model.apply` returns a 4-tuple; by name presumably (prediction, time
# derivative, candidate-term matrix, coefficients) — TODO confirm with Deepmod.
prediction, dt, theta, coeffs = model.apply(variables, X)

# From here on `y` and `X` are rebound to the model outputs `dt` and `theta`;
# the noisy data arrays built above are no longer reachable under these names.
y = dt
X = theta
n_samples, n_features = theta.shape

prior_params_mse = (0.0, 0.0)
# NOTE(review): because `y` was rebound to `dt` above, this passes `dt` (not the
# noisy data) together with `prediction`; given the `_mse` naming it may have
# been meant to use the data — confirm intent.
tau = precision(y, prediction, *prior_params_mse)
alpha_prior = (1e-6, 1e-6)
# stop_gradient keeps gradients from flowing through `tau` into this prior.
beta_prior = (n_samples / 2, n_samples / (2 * jax.lax.stop_gradient(tau)))

# NOTE(review): redundant — `X` is `theta`, so this recomputes the same shape.
n_samples, n_features = X.shape
# One weight of 1 per feature followed by a single 0; this call continues past
# the visible source.
norm_weight = jnp.concatenate((jnp.ones((n_features, )), jnp.zeros((1, ))),
# Experiment settings: noise level as a fraction of the clean signal's std,
# number of independent runs, and the iteration cap passed to train_max_iter.
noise = 0.50
n_runs = 1
max_iterations = 100000

# Making data
# NOTE(review): `key` is not defined in the visible source; presumably it is a
# PRNG key created earlier in the file.
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)
# Two-column design matrix: column 0 is t, column 1 is x.
X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += noise * jnp.std(y) * random.normal(key, y.shape)

# Defining model and optimizers
model = Deepmod([30, 30, 30, 1])
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)

# Running warm restart bayes
update_fn = create_update(loss_fn_SBL, (model, X, y, True))
# Each run draws its own init key split from `key`; the data (X, y) built above
# is shared, so runs start from different initialisations of the same problem.
for run_idx, subkey in enumerate(random.split(key, n_runs)):
    print(f"Starting SBL run {run_idx}")
    variables = model.init(subkey, X)
    state, params = variables.pop("params")
    state = (state, {"prior_init": None})  # adding prior to state
    optimizer = optimizer_def.create(params)
    # This call continues past the visible source.
    train_max_iter(
        update_fn, optimizer, state, max_iterations,
from modax.linear_model.SBL import SBL

# %% Making data
# Burgers samples on a 20 x 50 (t, x) grid, corrupted with Gaussian noise whose
# standard deviation is 10% of the clean signal's.
key = random.PRNGKey(42)
x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)
# One (t, x) row per grid point.
X = jnp.stack([t_grid.ravel(), x_grid.ravel()], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

# Separate the trainable parameters from the other variable collections and
# hand them straight to an Adam optimizer.
state, params = variables.pop("params")
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99).create(params)
state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, False))

# Training is disabled; uncomment to fit before evaluating the loss below.
# optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

# Evaluate the SBL loss, its auxiliary outputs and its gradient once, at the
# optimizer's current parameters.
grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
loss_and_aux, grad = grad_fn(optimizer.target, state, model, X, y)
loss, (updated_state, metrics, output) = loss_and_aux
from modax.data.burgers import burgers
from time import time

# Making dataset (no noise is added: y_train is the clean Burgers field)
x = jnp.linspace(-3, 4, 100)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)
# Two-column design matrix: column 0 is t, column 1 is x.
X_train = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)

# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1])
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# `optimizer` is rebound from the Adam definition to the object returned by
# `.create(params)`.
optimizer = optimizer.create(params)

# Compiling train step
update = create_update(loss_fn_mse, model=model, x=X_train, y=y_train)
# The returned (optimizer, loss) is discarded, so this warm-up call does not
# advance training.
_ = update(optimizer)  # triggering compilation

# Running to convergence
max_epochs = 10001
t_start = time()
# NOTE(review): iterating over `jnp.arange` yields a JAX scalar array each step,
# so `i % 1000 == 0` below runs on device and must be transferred back to the
# host for the Python `if`; `range(max_epochs)` would avoid that per-step sync.
for i in jnp.arange(max_epochs):
    optimizer, loss = update(optimizer)
    # The body of this check continues past the visible source.
    if i % 1000 == 0:
# NOTE(review): the four indented statements below are the tail of an `if`
# branch whose header (and grid setup) precede the visible source; presumably
# the branch handles the Burgers dataset. They add 10%-of-std Gaussian noise.
    u = burgers(x_grid, t_grid, 0.1, 1.0)
    X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
    y = u.reshape(-1, 1)
    y += 0.10 * jnp.std(y) * random.normal(key, y.shape)
elif dataset == "kdv":
    # Noisy `doublesoliton` samples on a 20 x 100 (t, x) grid.
    key = random.PRNGKey(42)
    x = jnp.linspace(-10, 10, 100)
    t = jnp.linspace(0.1, 1.0, 20)
    t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
    u = doublesoliton(x_grid, t_grid, c=[5.0, 2.0], x0=[0.0, -5.0])
    X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
    y = u.reshape(-1, 1)
    y += 0.10 * jnp.std(y) * random.normal(key, y.shape)
else:
    raise NotImplementedError

# %% Building model and params
# `key` below is whichever key the selected dataset branch left bound (the
# "kdv" branch binds it explicitly).
model = Deepmod([30, 30, 30, 1], (5, 4))
variables = model.init(key, X)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
# `optimizer` is rebound from the Adam definition to the object returned by
# `.create(params)`.
optimizer = optimizer.create(params)
update_fn = create_update(loss_fn_pinn, (model, X, y, 1.0))

# %%
optimizer, state = train_max_iter(update_fn, optimizer, state, 5000)

# %%