# Making data
# Give the noise its own PRNG key: a JAX key must not feed more than one
# random call, and `key` is still needed below for the per-run splits.
key, noise_key = random.split(key)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")  # both grids are (20, 50)
# NOTE(review): the meaning of the constants 0.1 and 1.0 is set by `burgers`,
# which is defined elsewhere -- confirm against its signature.
u = burgers(x_grid, t_grid, 0.1, 1.0)

# One row per grid point, columns ordered (t, x); targets are a single column.
X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
# Additive normal noise, scaled by `noise` times the std of the clean signal.
y += noise * jnp.std(y) * random.normal(noise_key, y.shape)

# Defining model and optimizers
model = Deepmod([30, 30, 30, 1])
optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)

# Running SBL: one independent run per subkey, each from a fresh initialisation.
# NOTE(review): the trailing `True` is passed through to `loss_fn_SBL`;
# presumably it enables the warm restart -- confirm against the loss function.
update_fn = create_update(loss_fn_SBL, (model, X, y, True))
for run_idx, subkey in enumerate(random.split(key, n_runs)):
    print(f"Starting SBL run {run_idx}")
    variables = model.init(subkey, X)
    # Separate the trainable parameters from the remaining model variables.
    state, params = variables.pop("params")
    state = (state, {"prior_init": None})  # adding prior to state
    optimizer = optimizer_def.create(params)
    train_max_iter(
        update_fn,
        optimizer,
        state,
        max_iterations,
        log_dir=script_dir + f"sbl_run_{run_idx}/",
    )
# NOTE(review): this fragment has lost its line breaks. Because the physical
# line starts with `#`, Python reads all of it as a comment; the statements
# must be put back on separate lines before any of it can execute.
# NOTE(review): the text is cut off inside the second `train_max_iter(` call,
# so its remaining arguments are not visible and are left untouched here.
# NOTE(review): the loop built on `loss_fn_pinn` prints "Starting multitask
# run" -- this looks like a copy-paste leftover; confirm the intended label.
# NOTE(review): both loops split the same `key`, so run i receives the same
# subkey in each loop -- presumably intentional, so that the two methods start
# from identical initialisations; confirm.
# Defining model and optimizers model = Deepmod([30, 30, 30, 1]) optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99) # Running PINN with overcomplete library update_fn = create_update(loss_fn_pinn, (model, X, y)) for run_idx, subkey in enumerate(random.split(key, n_runs)): print(f"Starting multitask run {run_idx}") variables = model.init(subkey, X) state, params = variables.pop("params") optimizer = optimizer_def.create(params) train_max_iter( update_fn, optimizer, state, max_iterations, log_dir=script_dir + f"burgers_pinn_run_{run_idx}/", ) # Running bayes with overcomplete library update_fn = create_update(loss_fn_bayesian_ridge, (model, X, y, True)) for run_idx, subkey in enumerate(random.split(key, n_runs)): print(f"Starting bayes warm run {run_idx}") variables = model.init(subkey, X) state, params = variables.pop("params") state = (state, {"prior_init": None}) # adding prior to state optimizer = optimizer_def.create(params) train_max_iter( update_fn, optimizer,
# NOTE(review): this fragment has lost its line breaks. Because the physical
# line starts with `#`, Python reads all of it as a comment; the statements
# must be put back on separate lines before any of it can execute.
# NOTE(review): the text is cut off inside the second `train_max_iter(` call
# (after `state,`), so its remaining arguments are not visible and are left
# untouched here.
# NOTE(review): both loops split the same `key`, so run i receives the same
# subkey in each loop -- presumably intentional, so that the two losses start
# from identical initialisations; confirm.
# Defining model and optimizers model = Deepmod([30, 30, 30, 1]) optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99) # Running multitask update_fn = create_update(loss_fn_multitask_precalc, (model, X, y)) for run_idx, subkey in enumerate(random.split(key, n_runs)): print(f"Starting multitask run {run_idx}") variables = model.init(subkey, X) state, params = variables.pop("params") optimizer = optimizer_def.create(params) train_max_iter( update_fn, optimizer, state, max_iterations, log_dir=script_dir + f"multitask_run_{run_idx}/", ) # Running bayesian multitask update_fn = create_update(loss_fn_pinn_bayes_mse_hyperprior, (model, X, y)) for run_idx, subkey in enumerate(random.split(key, n_runs)): print(f"Starting bayes run {run_idx}") variables = model.init(subkey, X) state, params = variables.pop("params") optimizer = optimizer_def.create(params) train_max_iter( update_fn, optimizer, state,
from modax.training.utils import create_update
from flax import optim
from modax.training import train_max_iter
from modax.training.losses.bayesian_regression import loss_fn_bayesian_ridge

# %% Making data
key = random.PRNGKey(42)
# Derive one subkey per random consumer: a JAX PRNG key must not be passed to
# more than one random call. `key` is rebound so it stays fresh for later use.
key, noise_key, init_key = random.split(key, 3)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")  # both grids are (20, 50)
# NOTE(review): the meaning of the constants 0.1 and 1.0 is set by `burgers`,
# which is defined elsewhere -- confirm against its signature.
u = burgers(x_grid, t_grid, 0.1, 1.0)

# One row per grid point, columns ordered (t, x); targets are a single column.
X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
# Additive normal noise with a std equal to 10% of the clean signal's std.
y += 0.10 * jnp.std(y) * random.normal(noise_key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(init_key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
# Separate the trainable parameters from the remaining model variables.
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
# NOTE(review): the trailing `True` is passed through to
# `loss_fn_bayesian_ridge`; its meaning is not visible here -- confirm.
update_fn = create_update(loss_fn_bayesian_ridge, (model, X, y, True))

# Train for 10000 iterations and keep the final optimizer and state.
optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)
# NOTE(review): this fragment has lost its line breaks. Because the physical
# line starts with `#`, Python reads all of it as a comment; the statements
# must be put back on separate lines before any of it can execute.
# NOTE(review): the text is cut off inside the second `train_max_iter(` call,
# so its remaining arguments are not visible and are left untouched here.
# NOTE(review): the loop built on `loss_fn_pinn` prints "Starting multitask
# run" -- this looks like a copy-paste leftover; confirm the intended label.
# NOTE(review): both loops split the same `key`, so run i receives the same
# subkey in each loop -- presumably intentional, so that the two methods start
# from identical initialisations; confirm.
# Defining model and optimizers model = Deepmod([30, 30, 30, 1]) optimizer_def = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99) # Running PINN update_fn = create_update(loss_fn_pinn, (model, X, y)) for run_idx, subkey in enumerate(random.split(key, n_runs)): print(f"Starting multitask run {run_idx}") variables = model.init(subkey, X) state, params = variables.pop("params") optimizer = optimizer_def.create(params) train_max_iter( update_fn, optimizer, state, max_iterations, log_dir=script_dir + f"pinn_run_{run_idx}/", ) # Running warm restart bayes update_fn = create_update(loss_fn_bayesian_ridge, (model, X, y, True)) for run_idx, subkey in enumerate(random.split(key, n_runs)): print(f"Starting bayes warm run {run_idx}") variables = model.init(subkey, X) state, params = variables.pop("params") state = (state, {"prior_init": None}) # adding prior to state optimizer = optimizer_def.create(params) train_max_iter( update_fn, optimizer,