# Excerpt defining the parameter priors and the state-transition functions
# of a state-space model whose forward step has ARCH(1) form:
#     x_next = sqrt(α + β * x**2) * v
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift import construct_state_space_model_generators
from mlift.distributions import half_normal, uniform
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# NOTE(review): onp, lax, api, utils and construct_state_space_model_generators
# are not referenced in the code visible here; presumably used further down
# the file -- confirm before removing. `jax.api` was deprecated and later
# dropped from JAX, so check this import against the pinned JAX version.

# Run JAX in 64-bit precision and pin computation to the CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior for each model parameter, keyed by parameter name. Keep the key order
# unchanged: set_up_prior may rely on it (unverified).
prior_specifications = {
    "α": PriorSpecification(distribution=half_normal(1)),
    "β": PriorSpecification(distribution=uniform(0, 1)),
    "σ": PriorSpecification(distribution=half_normal(1)),
}

# set_up_prior yields four objects, exposed here as module-level names.
(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def generate_x_0(params, v_0, data):
    """Return the initial state x_0.

    The initial state is the noise input ``v_0`` itself. ``params`` and
    ``data`` are accepted (presumably to match a shared generator signature)
    but are not used.
    """
    initial_state = v_0
    return initial_state


def forward_func(params, v, x, data):
    """Propagate the state forward one step.

    Computes ``sqrt(α + β * x**2) * v``: the current state ``x`` sets the
    scale that multiplies the noise input ``v``.

    Args:
        params: Mapping providing at least the keys ``"α"`` and ``"β"``.
        v: Noise input for this step.
        x: Current state value.
        data: Not used.

    Returns:
        The propagated state value.
    """
    # Squared scale factor; this is the conditional variance of the output
    # only if v has unit variance, which is not established here.
    conditional_variance = params["α"] + params["β"] * x ** 2
    return np.sqrt(conditional_variance) * v
import numpy as onp import jax.config import jax.numpy as np import jax.lax as lax import jax.api as api from mlift.systems import IndependentAdditiveNoiseModelSystem from mlift.distributions import normal, uniform, half_cauchy from mlift.prior import PriorSpecification, set_up_prior import mlift.example_models.utils as utils jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") prior_specifications = { "μ": PriorSpecification(distribution=normal(0, 10)), "ϕ": PriorSpecification(distribution=uniform(-1, 1)), "θ": PriorSpecification(distribution=uniform(-1, 1)), "σ": PriorSpecification(distribution=half_cauchy(2.5)), } compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior( prior_specifications) def generate_from_model(u, data): params = generate_params(u, data) def step(x, y): x = params["μ"] + params["ϕ"] * y + params["θ"] * (y - x) return x, x
import jax.config import jax.numpy as np import jax.lax as lax import jax.api as api from mlift.systems import IndependentAdditiveNoiseModelSystem from mlift.distributions import normal, uniform, half_cauchy from mlift.prior import PriorSpecification, set_up_prior import mlift.example_models.utils as utils jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") prior_specifications = { "μ": PriorSpecification(distribution=normal(0, 10)), "α_0": PriorSpecification(distribution=half_cauchy(2.5)), "α_1": PriorSpecification(distribution=uniform(0, 1)), # β_1 ~ uniform(0, 1 - α_1) therefore β_1 / (1 - α_1) ~ uniform(0, 1) "β_1_over_1_minus_α_1": PriorSpecification(distribution=uniform(0, 1)), } compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior( prior_specifications) def generate_from_model(u, data): params = generate_params(u, data) params["β_1"] = params.pop("β_1_over_1_minus_α_1") * (1 - params["α_1"]) def step(x, y): x = params["α_0"] + params["α_1"] * ( y - params["μ"])**2 + params["β_1"] * x