Example #1
0
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift import construct_state_space_model_generators
from mlift.distributions import half_normal, uniform
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Use 64-bit floating point and run all JAX computation on the CPU.
for _option, _value in (
    ("jax_enable_x64", True),
    ("jax_platform_name", "cpu"),
):
    jax.config.update(_option, _value)
del _option, _value

# Prior distribution for each model parameter, keyed by parameter name.
# Keys stay string literals (not keyword arguments) so that non-ASCII names
# are not NFKC-normalised; insertion order matches the parameter listing in
# case set_up_prior depends on it.
prior_specifications = {
    name: PriorSpecification(distribution=distribution)
    for name, distribution in (
        ("α", half_normal(1)),
        ("β", uniform(0, 1)),
        ("σ", half_normal(1)),
    )
}

(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def generate_x_0(params, v_0, data):
    """Return the initial state, which is simply the noise input ``v_0``.

    ``params`` and ``data`` are accepted for interface compatibility but are
    not used.
    """
    initial_state = v_0
    return initial_state


def forward_func(params, v, x, data):
    """Scale the noise ``v`` by ``sqrt(α + β * x**2)``.

    ``params`` must provide the keys ``"α"`` and ``"β"``; ``data`` is unused.
    NOTE(review): this reads like an ARCH(1)-style update where ``x`` is the
    previous state -- confirm against the model description.
    """
    conditional_variance = params["α"] + params["β"] * x**2
    return v * np.sqrt(conditional_variance)

Example #2
0
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import IndependentAdditiveNoiseModelSystem
from mlift.distributions import normal, uniform, half_cauchy
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Use 64-bit floating point and run all JAX computation on the CPU.
for _option, _value in (
    ("jax_enable_x64", True),
    ("jax_platform_name", "cpu"),
):
    jax.config.update(_option, _value)
del _option, _value

# Prior distribution for each model parameter, keyed by parameter name.
# Keys stay string literals (not keyword arguments) so that names such as
# "ϕ" are not NFKC-normalised; insertion order matches the parameter listing
# in case set_up_prior depends on it.
prior_specifications = {
    name: PriorSpecification(distribution=distribution)
    for name, distribution in (
        ("μ", normal(0, 10)),
        ("ϕ", uniform(-1, 1)),
        ("θ", uniform(-1, 1)),
        ("σ", half_cauchy(2.5)),
    )
}

(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def generate_from_model(u, data):
    """Build model parameters from ``u`` and define the per-step recursion.

    ``params`` is obtained from ``generate_params(u, data)`` and indexed by the
    parameter names "μ", "ϕ" and "θ".

    NOTE(review): as shown in this listing the body ends after defining
    ``step``; there is no return statement, so the function returns None and
    ``step`` is never used. The remainder (presumably a ``lax.scan`` over
    ``step``, since ``lax`` is imported) appears to be missing -- confirm
    against the original source.
    """
    params = generate_params(u, data)

    def step(x, y):
        # Update x from the current observation y and the residual (y - x).
        # Returns (x, x), i.e. the new value as both carry and output --
        # presumably the lax.scan (carry, input) -> (carry, output) convention;
        # verify against the caller.
        x = params["μ"] + params["ϕ"] * y + params["θ"] * (y - x)
        return x, x
Example #3
0
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import IndependentAdditiveNoiseModelSystem
from mlift.distributions import normal, uniform, half_cauchy
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Use 64-bit floating point and run all JAX computation on the CPU.
for _option, _value in (
    ("jax_enable_x64", True),
    ("jax_platform_name", "cpu"),
):
    jax.config.update(_option, _value)
del _option, _value

# Prior distribution for each model parameter, keyed by parameter name.
# Keys stay string literals (not keyword arguments) so that non-ASCII names
# are not NFKC-normalised; insertion order matches the parameter listing in
# case set_up_prior depends on it.
prior_specifications = {
    name: PriorSpecification(distribution=distribution)
    for name, distribution in (
        ("μ", normal(0, 10)),
        ("α_0", half_cauchy(2.5)),
        ("α_1", uniform(0, 1)),
        # β_1 ~ uniform(0, 1 - α_1), hence β_1 / (1 - α_1) ~ uniform(0, 1):
        # the prior is placed on the rescaled quantity instead of on β_1.
        ("β_1_over_1_minus_α_1", uniform(0, 1)),
    )
}

(
    compute_dim_u,
    generate_params,
    prior_neg_log_dens,
    sample_from_prior,
) = set_up_prior(prior_specifications)


def generate_from_model(u, data):
    params = generate_params(u, data)
    params["β_1"] = params.pop("β_1_over_1_minus_α_1") * (1 - params["α_1"])

    def step(x, y):
        x = params["α_0"] + params["α_1"] * (
            y - params["μ"])**2 + params["β_1"] * x