import numpy as onp import jax.config import jax.numpy as np import jax.lax as lax import jax.api as api from mlift.systems import IndependentAdditiveNoiseModelSystem from mlift.distributions import half_normal, half_cauchy, beta from mlift.ode import integrate_ode_rk4 from mlift.prior import PriorSpecification, set_up_prior import mlift.example_models.utils as utils jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") prior_specifications = { "k_1": PriorSpecification(distribution=half_normal(1)), "k_2": PriorSpecification(distribution=half_normal(1)), "α_21": PriorSpecification(distribution=half_normal(1)), "α_12": PriorSpecification(distribution=half_normal(1)), "γ": PriorSpecification(distribution=beta(10, 1)), "σ": PriorSpecification(distribution=half_cauchy(1)), } compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior( prior_specifications) def dx_dt(x, t, params): return np.array(( -params["k_1"] * x[0] + params["α_12"] * params["k_2"] * x[1], -params["k_2"] * x[1] + params["α_21"] * params["k_1"] * x[0],
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api

from mlift.systems import GeneralGaussianProcessModelSystem
from mlift.transforms import standard_normal_to_students_t
from mlift.distributions import half_normal, inverse_gamma, students_t
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit arrays in JAX and select the CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior distribution for each model parameter, keyed by parameter name.
prior_specifications = {
    "α": PriorSpecification(distribution=half_normal(3)),
    # The shape of λ is given as a function of the data: one entry per
    # column of data["x"].
    "λ": PriorSpecification(
        shape=lambda data: data["x"].shape[1],
        distribution=inverse_gamma(4, 10),
    ),
    "σ": PriorSpecification(distribution=half_normal(1)),
}

# Unpack the four objects returned by set_up_prior for the specification above.
(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def squared_exp_covar(x, params):

    def sq_exp(x1, x2):
        return np.exp(-(((x1 - x2) / params["λ"])**2).sum() / 2)
import os

import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api

from mlift import construct_state_space_model_generators
from mlift.distributions import half_normal, uniform
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit arrays in JAX and select the CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior distribution for each model parameter, keyed by parameter name.
prior_specifications = {
    "α": PriorSpecification(distribution=half_normal(1)),
    "β": PriorSpecification(distribution=uniform(0, 1)),
    "σ": PriorSpecification(distribution=half_normal(1)),
}

# Unpack the four objects returned by set_up_prior for the specification above.
(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def generate_x_0(params, v_0, data):
    """Return the initial state, which is the noise input `v_0` unchanged.

    `params` and `data` are accepted but not used here (presumably so the
    signature matches the other model callbacks -- confirm against callers).
    """
    return v_0


def forward_func(params, v, x, data):
    """Map the previous state `x` and noise input `v` to the next state.

    Computes ``sqrt(α + β * x**2) * v``: the noise is multiplied by a scale
    that depends on the previous state, matching the form of an ARCH(1)-style
    recursion. `data` is not used.
    """
    state_scale = np.sqrt(params["α"] + params["β"] * x**2)
    return state_scale * v