Example #1
0
def reparametrize_to_standard_normal(prior_spec):
    """Re-express a prior specification over a ``normal(0, 1)`` base variable.

    The distribution's ``from_standard_normal_transform`` becomes (part of)
    the transform of the returned specification. If ``prior_spec`` already
    carries a transform, that transform is applied on top of it.

    Args:
        prior_spec: Specification exposing ``shape``, ``distribution`` and
            ``transform`` attributes.

    Returns:
        A new ``PriorSpecification`` with the same shape, a ``normal(0, 1)``
        distribution and the composed transform.
    """
    base_map = prior_spec.distribution.from_standard_normal_transform
    if prior_spec.transform is None:
        # Nothing to compose with: the base map is the whole transform.
        composed = base_map
    else:

        def composed(u):
            # Map from the standard normal first, then apply the
            # specification's own transform.
            return prior_spec.transform(base_map(u))

    return PriorSpecification(
        shape=prior_spec.shape,
        distribution=normal(0, 1),
        transform=composed,
    )
from mlift.ode import integrate_ode_rk4
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Set JAX configuration flags: switch on the "jax_enable_x64" option and
# select "cpu" as the platform name.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior specification for each named model parameter, keyed by parameter
# name. Entries that omit `shape` use whatever default PriorSpecification
# provides; "x_init" is explicitly given shape (2,).
# NOTE(review): the two arguments to log_normal / normal are presumably a
# location and a scale - confirm against mlift.distributions.
prior_specifications = {
    "α": PriorSpecification(distribution=log_normal(0, 1)),
    "β": PriorSpecification(distribution=log_normal(0, 1)),
    "γ": PriorSpecification(distribution=log_normal(0, 1)),
    "δ": PriorSpecification(distribution=log_normal(-1, 1)),
    "ϵ": PriorSpecification(distribution=log_normal(-3, 1)),
    "ζ": PriorSpecification(distribution=log_normal(-2, 1)),
    "σ": PriorSpecification(distribution=log_normal(-1, 1)),
    "x_init": PriorSpecification(shape=(2, ), distribution=normal(0, 1)),
}

# set_up_prior returns four objects built from the specifications; they are
# bound to module-level names for use by the model functions below.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def dx_dt(x, t, params):
    """Evaluate the time derivative of the two-component state ``x``.

    Args:
        x: State; only ``x[0]`` and ``x[1]`` are read.
        t: Time. Accepted but not used by this right-hand side.
        params: Mapping providing the coefficients "α", "β", "γ", "δ",
            "ϵ" and "ζ".

    Returns:
        Array holding the derivatives of ``x[0]`` and ``x[1]``, in that order.
    """
    x_0, x_1 = x[0], x[1]
    # First component: linear and cubic terms in x_0, linear coupling to x_1.
    dx_0 = params["α"] * x_0 - params["β"] * x_0**3 + params["γ"] * x_1
    # Second component: linear in both states plus a constant offset.
    dx_1 = -params["δ"] * x_0 - params["ϵ"] * x_1 + params["ζ"]
    return np.array((dx_0, dx_1))


def observation_func(x):
    """Select the observed values from ``x``.

    Drops the entry at index 0 along the first axis and keeps index 0 along
    the second axis, i.e. ``x[1:, 0]``.
    """
    all_but_first = slice(1, None)
    return x[all_but_first, 0]
import os
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import HierarchicalLatentVariableModelSystem
from mlift.distributions import normal, half_cauchy
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Set JAX configuration flags: switch on the "jax_enable_x64" option and
# select "cpu" as the platform name.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior specification for the two named parameters "μ" and "τ".
# NOTE(review): "μ" and "τ" are used below as an offset and a multiplier of
# `v`; the meaning of the numeric distribution arguments should be confirmed
# against mlift.distributions.
prior_specifications = {
    "μ": PriorSpecification(distribution=normal(0, 5)),
    "τ": PriorSpecification(distribution=half_cauchy(5)),
}

# set_up_prior returns four objects built from the specifications; they are
# bound to module-level names for use by the model functions below.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def generate_from_model(u, v, data):
    """Generate the parameters and the variable ``x`` from ``u`` and ``v``.

    Args:
        u: Passed, together with ``data``, to ``generate_params``.
        v: Variable that is scaled by "τ" and shifted by "μ".
        data: Passed through to ``generate_params``.

    Returns:
        Tuple ``(params, x)`` where ``x = params["μ"] + params["τ"] * v``.
    """
    params = generate_params(u, data)
    offset, scale = params["μ"], params["τ"]
    return params, offset + scale * v


def generate_y(u, v, n, data):
    _, x = generate_from_model(u, v, data)
Example #4
0
import numpy as onp
import jax.config
import jax.numpy as np
import jax.lax as lax
import jax.api as api
from mlift.systems import IndependentAdditiveNoiseModelSystem
from mlift.distributions import normal, uniform, half_cauchy
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Set JAX configuration flags: switch on the "jax_enable_x64" option and
# select "cpu" as the platform name.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

# Prior specification for each named model parameter. The shape of "β" is
# given as a callable that reads data["max_lag"], so it is determined by the
# data rather than fixed here.
prior_specifications = {
    "α":
    PriorSpecification(distribution=normal(0, 10)),
    "β":
    PriorSpecification(shape=lambda data: data["max_lag"],
                       distribution=normal(0, 10)),
    "σ":
    PriorSpecification(distribution=half_cauchy(2.5)),
}

# set_up_prior returns four objects built from the specifications; they are
# bound to module-level names for use by the model functions below.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications)


def generate_from_model(u, data):
    """Generate the parameters and the variable ``x`` from ``u``.

    Args:
        u: Passed, together with ``data``, to ``generate_params``.
        data: Mapping that must provide "y_windows"; also passed through to
            ``generate_params``.

    Returns:
        Tuple ``(params, x)`` where ``x`` is "α" plus the sum over the last
        axis of ``data["y_windows"] * params["β"]``.
    """
    params = generate_params(u, data)
    intercept = params["α"]
    weighted_windows = data["y_windows"] * params["β"]
    # Collapse the last axis of the weighted windows.
    x = intercept + weighted_windows.sum(-1)
    return params, x
from mlift.systems import IndependentAdditiveNoiseModelSystem
from mlift.distributions import normal, log_normal
from mlift.prior import PriorSpecification, set_up_prior
import mlift.example_models.utils as utils

# Set JAX configuration flags: switch on the "jax_enable_x64" option and
# select "cpu" as the platform name.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


# Prior specification for each named model parameter, keyed by parameter
# name. All entries rely on PriorSpecification's default shape.
# NOTE(review): the two arguments to log_normal / normal are presumably a
# location and a scale - confirm against mlift.distributions.
prior_specifications = {
    "k_tau_p_1": PriorSpecification(distribution=log_normal(8, 1)),
    "g_bar_Na": PriorSpecification(distribution=log_normal(4, 1)),
    "g_bar_K": PriorSpecification(distribution=log_normal(2, 1)),
    "g_bar_M": PriorSpecification(distribution=log_normal(-3, 1)),
    "g_leak": PriorSpecification(distribution=log_normal(-3, 1)),
    "v_t": PriorSpecification(distribution=normal(-60, 10)),
    "σ": PriorSpecification(distribution=log_normal(0, 1)),
}

# set_up_prior returns four objects built from the specifications; they are
# bound to module-level names for use by the model functions below.
compute_dim_u, generate_params, prior_neg_log_dens, sample_from_prior = set_up_prior(
    prior_specifications
)


def x_over_expm1_x(x):
    """Evaluate ``x / expm1(x)``, returning the limit value 1 at ``x == 0``.

    The naive quotient is 0 / 0 (NaN) at exactly zero even though the
    function has a removable singularity there with limit 1 and slope -1/2.

    Args:
        x: Scalar or array of values.

    Returns:
        ``x / expm1(x)`` elementwise, with entries where ``x == 0`` replaced
        by the first-order expansion ``1 - x / 2`` (equal to 1 there).
    """
    at_zero = x == 0
    # Replace zeros by a harmless non-zero value *before* dividing, so the
    # branch discarded by the final `where` never computes 0 / 0. This keeps
    # both the value and any derivative taken through it finite.
    safe_x = np.where(at_zero, np.ones_like(x), x)
    # At zero use the expansion 1 - x / 2, which supplies the correct limit
    # value and first derivative; elsewhere use the exact quotient.
    return np.where(at_zero, 1.0 - x / 2, safe_x / np.expm1(safe_x))


def alpha_n(v, params_and_data):
    return (
        params_and_data["k_alpha_n_1"]
Example #6
0
from collections import namedtuple
import numpy as np
from mlift.distributions import normal, pullback_distribution
from mlift.transforms import (
    unbounded_to_lower_bounded,
    unbounded_to_upper_bounded,
    unbounded_to_lower_and_upper_bounded,
)

# Record describing the prior placed on a single parameter. All three fields
# have defaults: an empty-tuple `shape`, a `normal(0, 1)` `distribution` and
# no `transform` (None).
# NOTE: the default `normal(0, 1)` object is created once, when this statement
# runs, and is shared by every PriorSpecification that omits `distribution`.
PriorSpecification = namedtuple(
    "PriorSpecification",
    ("shape", "distribution", "transform"),
    defaults=((), normal(0, 1), None),
)


def reparametrize_to_unbounded_support(prior_spec):
    if (prior_spec.distribution.support.lower != -np.inf
            and prior_spec.distribution.support.upper != np.inf):
        bounding_transform = unbounded_to_lower_and_upper_bounded(
            prior_spec.distribution.support.lower,
            prior_spec.distribution.support.upper)
    elif prior_spec.distribution.support.lower != -np.inf:
        bounding_transform = unbounded_to_lower_bounded(
            prior_spec.distribution.support.lower)
    elif prior_spec.distribution.support.upper != np.inf:
        bounding_transform = unbounded_to_upper_bounded(
            prior_spec.distribution.support.upper)
    else:
        return prior_spec
    distribution = pullback_distribution(prior_spec.distribution,