def reparametrize_to_standard_normal(prior_spec):
    """Re-express a prior specification in terms of a standard normal variate.

    The returned specification always has ``distribution=normal(0, 1)``.
    Its ``transform`` first maps the standard normal variate through the
    original distribution's ``from_standard_normal_transform``. If the
    original specification also had a ``transform``, that is applied to the
    result afterwards. The ``shape`` is carried over unchanged.

    Args:
        prior_spec: Prior specification whose ``distribution`` exposes a
            ``from_standard_normal_transform`` callable and which has
            ``shape`` and ``transform`` attributes.

    Returns:
        A new ``PriorSpecification`` over a standard normal variate.
    """
    to_original = prior_spec.distribution.from_standard_normal_transform
    if prior_spec.transform is None:
        # No extra transform: the distribution's own map is all that is needed.
        composed = to_original
    else:
        def composed(u):
            # Look up ``prior_spec.transform`` at call time, as the original did.
            return prior_spec.transform(to_original(u))
    return PriorSpecification(
        shape=prior_spec.shape,
        distribution=normal(0, 1),
        transform=composed,
    )
from mlift.ode import integrate_ode_rk4
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit floating point in JAX and pin computation to the CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Priors on the ODE coefficients, the observation noise scale and the
# two-dimensional initial state.
prior_specifications = {
    "α": PriorSpecification(distribution=log_normal(0, 1)),
    "β": PriorSpecification(distribution=log_normal(0, 1)),
    "γ": PriorSpecification(distribution=log_normal(0, 1)),
    "δ": PriorSpecification(distribution=log_normal(-1, 1)),
    "ϵ": PriorSpecification(distribution=log_normal(-3, 1)),
    "ζ": PriorSpecification(distribution=log_normal(-2, 1)),
    "σ": PriorSpecification(distribution=log_normal(-1, 1)),
    "x_init": PriorSpecification(shape=(2,), distribution=normal(0, 1)),
}

(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def dx_dt(x, t, params):
    """Time derivative of the two-component state.

    Computes::

        dx0/dt = α x0 - β x0**3 + γ x1
        dx1/dt = -δ x0 - ϵ x1 + ζ

    ``t`` is accepted but not used (the right-hand side has no explicit
    time dependence), presumably so the signature fits the ODE integrator.

    Args:
        x: State; only ``x[0]`` and ``x[1]`` are read.
        t: Current time (unused).
        params: Mapping with keys ``"α" "β" "γ" "δ" "ϵ" "ζ"``.

    Returns:
        Array ``(dx0/dt, dx1/dt)``.
    """
    x_0, x_1 = x[0], x[1]
    dx_0 = params["α"] * x_0 - params["β"] * x_0**3 + params["γ"] * x_1
    dx_1 = -params["δ"] * x_0 - params["ϵ"] * x_1 + params["ζ"]
    return np.array((dx_0, dx_1))


def observation_func(x):
    """Select the observed quantity from a state trajectory.

    Drops the first row of ``x`` and keeps column 0. This assumes ``x`` is
    indexed as (time, state component), with row 0 presumably the initial
    state; confirm against the caller.
    """
    observed = x[1:, 0]
    return observed
import os
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import HierarchicalLatentVariableModelSystem
from mlift.distributions import normal, half_cauchy
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit floating point in JAX and pin computation to the CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Priors on the shared location "μ" and scale "τ" of the latent values.
prior_specifications = {
    "μ": PriorSpecification(distribution=normal(0, 5)),
    "τ": PriorSpecification(distribution=half_cauchy(5)),
}

# Unpack the 4-tuple returned by set_up_prior for use by the model functions.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def generate_from_model(u, v, data):
    """Generate the parameters and latent values from ``u`` and ``v``.

    Args:
        u: Latent vector passed to ``generate_params`` to build ``params``.
        v: Local latent values, which are shifted by ``μ`` and scaled by
            ``τ``. Presumably ``v`` is standard normal, which would make this
            a non-centred parametrisation; confirm against the model.
        data: Data dictionary forwarded to ``generate_params``.

    Returns:
        Tuple ``(params, x)`` with ``x = params["μ"] + params["τ"] * v``.
    """
    params = generate_params(u, data)
    x = params["μ"] + params["τ"] * v
    return params, x


def generate_y(u, v, n, data):
    # Only the latent values x are needed here; the params are discarded.
    _, x = generate_from_model(u, v, data)
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import IndependentAdditiveNoiseModelSystem
from mlift.distributions import normal, uniform, half_cauchy
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit floating point in JAX and pin computation to the CPU backend.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Intercept "α", per-lag coefficients "β" (one per lag, with the count read
# from data["max_lag"]) and noise scale "σ".
prior_specifications = {
    "α": PriorSpecification(distribution=normal(0, 10)),
    "β": PriorSpecification(
        shape=lambda data: data["max_lag"], distribution=normal(0, 10)
    ),
    "σ": PriorSpecification(distribution=half_cauchy(2.5)),
}

(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def generate_from_model(u, data):
    """Generate the parameters and the noise-free mean from latent ``u``.

    The mean is the intercept plus a weighted sum over the last axis of
    ``data["y_windows"]``, with weights ``β``. This is presumably an
    autoregressive mean over windows of lagged observations; confirm.

    Args:
        u: Latent vector passed to ``generate_params``.
        data: Dictionary providing ``"y_windows"`` (and whatever
            ``generate_params`` reads).

    Returns:
        Tuple ``(params, x)`` with
        ``x = params["α"] + (data["y_windows"] * params["β"]).sum(-1)``.
    """
    params = generate_params(u, data)
    intercept = params["α"]
    weighted_lags = data["y_windows"] * params["β"]
    x = intercept + weighted_lags.sum(-1)
    return params, x
from mlift.systems import IndependentAdditiveNoiseModelSystem from mlift.distributions import normal, log_normal from mlift.prior import PriorSpecification, set_up_prior import mlift.example_models.utils as utils jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") prior_specifications = { "k_tau_p_1": PriorSpecification(distribution=log_normal(8, 1)), "g_bar_Na": PriorSpecification(distribution=log_normal(4, 1)), "g_bar_K": PriorSpecification(distribution=log_normal(2, 1)), "g_bar_M": PriorSpecification(distribution=log_normal(-3, 1)), "g_leak": PriorSpecification(distribution=log_normal(-3, 1)), "v_t": PriorSpecification(distribution=normal(-60, 10)), "σ": PriorSpecification(distribution=log_normal(0, 1)), } compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior( prior_specifications ) def x_over_expm1_x(x): return x / np.expm1(x) def alpha_n(v, params_and_data): return ( params_and_data["k_alpha_n_1"]
from collections import namedtuple
import numpy as np
from mlift.distributions import normal, pullback_distribution
from mlift.transforms import (
    unbounded_to_lower_bounded,
    unbounded_to_upper_bounded,
    unbounded_to_lower_and_upper_bounded,
)

# Immutable record describing the prior on one parameter.
#   shape:        parameter shape (default ``()``).
#   distribution: prior distribution (default ``normal(0, 1)``).
#   transform:    optional transform, or ``None`` (the default).
# The default ``normal(0, 1)`` is evaluated once, when this module is
# imported. Every specification that relies on the default shares that one
# instance.
PriorSpecification = namedtuple(
    "PriorSpecification",
    ("shape", "distribution", "transform"),
    defaults=((), normal(0, 1), None),
)


def reparametrize_to_unbounded_support(prior_spec):
    """Reparametrize ``prior_spec`` so its distribution has unbounded support.

    If neither support bound of ``prior_spec.distribution`` is finite,
    ``prior_spec`` is returned unchanged. Otherwise a bounding transform from
    the unbounded real line onto the support is chosen according to which
    bounds are finite. A new distribution is then built with
    ``pullback_distribution``; presumably it uses ``bounding_transform``
    (TODO: confirm against the rest of this function).
    """
    if (prior_spec.distribution.support.lower != -np.inf and
            prior_spec.distribution.support.upper != np.inf):
        # Both bounds finite: map onto the interval between lower and upper.
        bounding_transform = unbounded_to_lower_and_upper_bounded(
            prior_spec.distribution.support.lower,
            prior_spec.distribution.support.upper)
    elif prior_spec.distribution.support.lower != -np.inf:
        # Only the lower bound is finite.
        bounding_transform = unbounded_to_lower_bounded(
            prior_spec.distribution.support.lower)
    elif prior_spec.distribution.support.upper != np.inf:
        # Only the upper bound is finite.
        bounding_transform = unbounded_to_upper_bounded(
            prior_spec.distribution.support.upper)
    else:
        # Support is already the whole real line: nothing to reparametrize.
        return prior_spec
    distribution = pullback_distribution(prior_spec.distribution,