import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import IndependentAdditiveNoiseModelSystem
from mlift.distributions import half_normal, half_cauchy, beta
from mlift.ode import integrate_ode_rk4
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit (double precision) JAX arrays and pin computation to the CPU
# backend for this example model.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior distribution for each named model parameter.
# k_1, k_2, α_21 and α_12 are the coefficients read by dx_dt's right-hand side.
# γ and σ are not used in the dx_dt lines shown here. σ is presumably the
# observation-noise scale, which would match the additive-noise system
# imported above (TODO confirm).
# NOTE(review): set_up_prior may depend on key order (e.g. for the layout of
# the latent vector), so keep the entries in this order unless verified.
prior_specifications = {
    "k_1": PriorSpecification(distribution=half_normal(1)),
    "k_2": PriorSpecification(distribution=half_normal(1)),
    "α_21": PriorSpecification(distribution=half_normal(1)),
    "α_12": PriorSpecification(distribution=half_normal(1)),
    "γ": PriorSpecification(distribution=beta(10, 1)),
    "σ": PriorSpecification(distribution=half_cauchy(1)),
}

# set_up_prior builds four prior-related helpers from the specification.
# Going by their names, they compute the latent dimension, map latents to
# parameters, evaluate the negative log prior density, and sample from the
# prior. Confirm against mlift.prior.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def dx_dt(x, t, params):
    return np.array((
        -params["k_1"] * x[0] + params["α_12"] * params["k_2"] * x[1],
        -params["k_2"] * x[1] + params["α_21"] * params["k_1"] * x[0],
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import GeneralGaussianProcessModelSystem
from mlift.transforms import standard_normal_to_students_t
from mlift.distributions import half_normal, inverse_gamma, students_t
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit (double precision) JAX arrays and pin computation to the CPU
# backend for this example model.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior distribution for each named model parameter.
# λ is given one entry per column of data["x"], via the shape callable below.
# squared_exp_covar divides input differences element-wise by params["λ"], so
# these are presumably per-input-dimension length scales (TODO confirm).
# α and σ are scalar. Going by their names they are presumably the kernel
# amplitude and the noise scale. Verify against the rest of the model.
prior_specifications = {
    "α":
    PriorSpecification(distribution=half_normal(3)),
    "λ":
    PriorSpecification(shape=lambda data: data["x"].shape[1],
                       distribution=inverse_gamma(4, 10)),
    "σ":
    PriorSpecification(distribution=half_normal(1)),
}

# set_up_prior builds four prior-related helpers from the specification.
# Going by their names, they compute the latent dimension, map latents to
# parameters, evaluate the negative log prior density, and sample from the
# prior. Confirm against mlift.prior.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def squared_exp_covar(x, params):
    def sq_exp(x1, x2):
        return np.exp(-(((x1 - x2) / params["λ"])**2).sum() / 2)
# Example #3
# 0
import os
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift import construct_state_space_model_generators
from mlift.distributions import half_normal, uniform
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Enable 64-bit (double precision) JAX arrays and pin computation to the CPU
# backend for this example model.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior distribution for each named model parameter.
# α and β enter forward_func as sqrt(α + β * x**2). The half-normal and
# uniform(0, 1) priors are presumably chosen so that this argument stays
# non-negative (TODO confirm the supports in mlift.distributions).
# σ is not used in the functions shown here. It is presumably the
# observation-noise scale.
prior_specifications = {
    "α": PriorSpecification(distribution=half_normal(1)),
    "β": PriorSpecification(distribution=uniform(0, 1)),
    "σ": PriorSpecification(distribution=half_normal(1)),
}

# set_up_prior builds four prior-related helpers from the specification.
# Going by their names, they compute the latent dimension, map latents to
# parameters, evaluate the negative log prior density, and sample from the
# prior. Confirm against mlift.prior.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def generate_x_0(params, v_0, data):
    """Return the initial state, which is simply the initial noise ``v_0``.

    ``params`` and ``data`` are accepted but ignored. Presumably they are kept
    so the signature matches what construct_state_space_model_generators
    expects (TODO confirm).
    """
    initial_state = v_0
    return initial_state


def forward_func(params, v, x, data):
    """Advance the state by scaling the noise ``v`` with a state-dependent factor.

    Returns ``sqrt(params["α"] + params["β"] * x**2) * v``. This is an
    ARCH(1)-style update (presumably). ``data`` is accepted but ignored.
    """
    alpha, beta = params["α"], params["β"]
    scale = np.sqrt(alpha + beta * x**2)
    return scale * v