Example #1
0
    def to_minimize(flat_theta):
        """Return the negated objective ``-(lik - kl + log_prior)``.

        ``flat_theta`` is unflattened with ``summary`` and then passed through
        ``apply_transformation(..., "log_", jnp.exp, "")``. That call apparently
        maps ``log_``-prefixed entries to their exponentiated, un-prefixed names,
        since ``w_prior_var`` is read below.

        The log prior combines two terms:
        - a Gamma(shape=0.5, scale=1/n_c) density on ``w_prior_var``;
        - a Normal(0, sd=sqrt(1/n_c)) density on ``w_prior_mean``.
        """
        params = apply_transformation(
            reconstruct(flat_theta, summary, jnp.reshape), "log_", jnp.exp, ""
        )

        log_lik = calculate_likelihood(
            params,
            species_ids,
            fg_covs,
            bg_covs,
            fg_covs_thin,
            bg_covs_thin,
            quad_weights,
            counts,
            n_s,
            n_fg,
        )
        kl_div = calculate_kl(params)

        inv_n_c = 1.0 / n_c
        var_term = jnp.sum(gamma.logpdf(params["w_prior_var"], 0.5, scale=inv_n_c))
        mean_term = jnp.sum(
            norm.logpdf(params["w_prior_mean"], 0.0, scale=jnp.sqrt(inv_n_c))
        )
        log_prior = var_term + mean_term

        # Negated so that a minimiser maximises lik - kl + log_prior.
        return -(log_lik - kl_div + log_prior)
Example #2
0
    def annotated_with_grad(flat_theta, summary):
        """Evaluate objective and gradient, print progress, and trap bad values.

        Designed for ``scipy.optimize.minimize(..., jac=True)``, which expects
        the objective and its gradient returned together.

        Args:
            flat_theta: Flat parameter vector supplied by the optimiser.
            summary: Flattening summary. It is used only to rebuild the
                structured parameters for inspection when a non-finite value
                is hit.

        Returns:
            ``(obj, grad)`` converted to NumPy ``float64``.
        """
        flat_theta = jnp.array(flat_theta)

        obj, grad = with_grad(flat_theta)

        # Progress line: objective value and gradient norm.
        print(obj, jnp.linalg.norm(grad))

        # Trap NaN *and* inf in both objective and gradient. The previous
        # check missed infinite gradient entries.
        if not (jnp.isfinite(obj) and jnp.all(jnp.isfinite(grad))):
            # Structured view of the offending parameters, kept so it can be
            # inspected from the debugger session opened below.
            problem = reconstruct(flat_theta, summary, jnp.reshape)  # noqa: F841

            # The built-in breakpoint() replaces a hard ipdb import, so a
            # missing optional package can no longer mask the failure.
            # PYTHONBREAKPOINT selects the debugger: "ipdb.set_trace" restores
            # the previous behaviour, and "0" disables the trap.
            breakpoint()

        return np.array(obj).astype(np.float64), np.array(grad).astype(
            np.float64)
Example #3
0
    def to_minimize(flat_theta):
        """Return the negated objective ``-(expected_lik - kl + log_prior)``.

        Steps:
        1. Unflatten ``flat_theta`` with ``summary``.
        2. Build the kernel and an ``sv.SVGPSpec`` from the ``mu``, ``L_elts``
           and ``Z`` entries.
        3. Project onto ``X`` to get predictive mean and variance.
        4. Combine the summed ``expectation_1d`` of ``likelihood_fun``, the
           spec's KL term, and ``prior_fun`` evaluated on
           ``transformation_fun(params)``.
        """
        params = reconstruct(flat_theta, summary, jnp.reshape)

        kernel = get_kernel_fun(kernel_currier, params, transformation_fun)
        svgp = sv.SVGPSpec(
            m=params["mu"],
            L_elts=params["L_elts"],
            Z=params["Z"],
            kern_fn=kernel,
        )

        mean_f, var_f = sv.project_to_x(svgp, X)
        kl_term = sv.calculate_kl(svgp)
        expected_lik = jnp.sum(expectation_1d(likelihood_fun, mean_f, var_f))
        log_prior = prior_fun(transformation_fun(params))

        # Negated so that a minimiser maximises the bound plus the prior.
        return -(expected_lik - kl_term + log_prior)
Example #4
0
def fit(
    fg_covs,
    bg_covs,
    species_ids,
    quad_weights,
    counts,
    fg_covs_thin=None,
    bg_covs_thin=None,
):
    """Fit the model with L-BFGS-B and return the transformed parameters.

    Builds an initial parameter dict, flattens it, and minimises
    ``-(likelihood - KL + log prior)`` with ``scipy.optimize.minimize``. The
    objective's value and gradient come from a jitted JAX ``value_and_grad``.

    Args:
        fg_covs: Covariate array. ``shape[0]`` is used as ``n_fg`` and
            ``shape[1]`` as the number of covariates ``n_c``.
        bg_covs: Covariate array. ``shape[0]`` is used as ``n_bg``.
        species_ids: Species labels. The number of distinct values gives
            ``n_s``. NOTE(review): presumably integer ids ``0..n_s-1``;
            confirm against ``calculate_likelihood``.
        quad_weights: Passed through to ``calculate_likelihood``.
        counts: Passed through to ``calculate_likelihood``.
        fg_covs_thin, bg_covs_thin: Optional thinning covariates, assumed
            constant across species. When omitted, a zero-width
            ``(n, 0)`` array is substituted.

    Returns:
        Dict of fitted parameters after
        ``apply_transformation(..., "log_", jnp.exp, "")``.
    """
    n_c = fg_covs.shape[1]
    n_s = len(np.unique(species_ids))
    n_c_thin = 0 if fg_covs_thin is None else fg_covs_thin.shape[1]
    n_fg = fg_covs.shape[0]
    n_bg = bg_covs.shape[0]

    # NOTE(review): n_c_thin is derived from fg_covs_thin only. If exactly one
    # of fg_covs_thin / bg_covs_thin is supplied, the other becomes zero-width
    # and the widths disagree. Confirm callers pass both or neither.
    if fg_covs_thin is None:
        fg_covs_thin = jnp.zeros((n_fg, 0))
    if bg_covs_thin is None:
        bg_covs_thin = jnp.zeros((n_bg, 0))

    # Key order is kept as-is: it is forwarded as keyword arguments to
    # flatten_and_summarise.
    init_theta = {
        "w_means": jnp.zeros((n_c, n_s)),
        "log_w_vars": jnp.log(jnp.tile(1.0 / n_c, (n_c, n_s))) - 5,
        "intercept_means": jnp.zeros(n_s) - 5,
        "log_intercept_vars": jnp.zeros(n_s) - 5,
        "w_prior_mean": jnp.zeros((n_c, 1)),
        "log_w_prior_var": jnp.log(jnp.tile(1.0 / n_c, (n_c, 1))) - 5,
        # Thinning is assumed constant across species
        "w_means_thin": jnp.zeros((n_c_thin, 1)),
        # Without thinning covariates, 1.0 / n_c_thin would divide by zero,
        # so empty / scalar placeholders are used instead.
        "log_w_vars_thin": (
            jnp.zeros((n_c_thin, 1)) - 10
            if n_c_thin == 0
            else jnp.log(jnp.tile(1.0 / n_c_thin, (n_c_thin, 1))) - 10
        ),
        "log_w_prior_var_thin": (
            jnp.array(0.0) if n_c_thin == 0 else jnp.log(1.0 / n_c_thin) - 10
        ),
    }

    flat_theta, summary = flatten_and_summarise(**init_theta)

    def to_minimize(flat_theta):
        """Negated objective ``-(lik - kl + log_prior)`` for a flat vector."""
        theta = reconstruct(flat_theta, summary, jnp.reshape)
        theta = apply_transformation(theta, "log_", jnp.exp, "")

        lik = calculate_likelihood(
            theta,
            species_ids,
            fg_covs,
            bg_covs,
            fg_covs_thin,
            bg_covs_thin,
            quad_weights,
            counts,
            n_s,
            n_fg,
        )
        kl = calculate_kl(theta)

        prior = jnp.sum(
            gamma.logpdf(theta["w_prior_var"], 0.5, scale=1.0 / n_c))
        prior = prior + jnp.sum(
            norm.logpdf(theta["w_prior_mean"], 0.0, scale=jnp.sqrt(1.0 / n_c)))

        return -(lik - kl + prior)

    with_grad = jit(value_and_grad(to_minimize))

    def annotated_with_grad(flat_theta, summary):
        """Evaluate objective and gradient, print progress, trap bad values.

        Returns ``(obj, grad)`` as NumPy float64, as passed to
        ``minimize(..., jac=True)`` below.
        """
        flat_theta = jnp.array(flat_theta)

        obj, grad = with_grad(flat_theta)

        # Progress line: objective value and gradient norm.
        print(obj, jnp.linalg.norm(grad))

        # Trap NaN *and* inf in both objective and gradient. The previous
        # check missed infinite gradient entries.
        if not (jnp.isfinite(obj) and jnp.all(jnp.isfinite(grad))):
            # Structured view of the offending parameters, kept so it can be
            # inspected from the debugger session opened below.
            problem = reconstruct(flat_theta, summary, jnp.reshape)  # noqa: F841

            # The built-in breakpoint() replaces a hard ipdb import, so a
            # missing optional package can no longer mask the failure.
            # PYTHONBREAKPOINT selects the debugger: "ipdb.set_trace" restores
            # the previous behaviour, and "0" disables the trap.
            breakpoint()

        return np.array(obj).astype(np.float64), np.array(grad).astype(
            np.float64)

    result = minimize(
        partial(annotated_with_grad, summary=summary),
        flat_theta,
        method="L-BFGS-B",
        jac=True,
    )
    final_theta = reconstruct(result.x, summary, jnp.reshape)
    final_theta = apply_transformation(final_theta, "log_", jnp.exp, "")

    return final_theta
Example #5
0
def fit(
    X,
    init_kernel_params,
    kernel_currier,
    likelihood_fun,
    prior_fun,
    transformation_fun=constrain_positive,
    n_inducing=100,
    verbose=False,
    Z=None,
):
    """Fit an ``sv.SVGPSpec`` model with L-BFGS-B.

    Args:
        X: Inputs given to ``sv.project_to_x``. They are also given to
            ``find_starting_z`` when ``Z`` is None.
        init_kernel_params: Dict of initial kernel parameters. It is merged
            after ``mu``, ``L_elts`` and ``Z`` into the optimised parameter
            dict, so duplicate keys override those entries.
        kernel_currier: Forwarded to ``get_kernel_fun``.
        likelihood_fun: Forwarded to ``expectation_1d`` with the predictive
            mean and variance.
        prior_fun: Evaluated on ``transformation_fun(params)``. Its value is
            added to ``lik - kl`` in the objective.
        transformation_fun: Applied before the prior and to the returned
            parameters.
        n_inducing: Number of starting inducing points when ``Z`` is None.
        verbose: Forwarded to ``convert_decorator``.
        Z: Optional starting inducing points.

    Returns:
        ``(spec, params)``: the fitted ``sv.SVGPSpec`` and
        ``transformation_fun`` applied to the fitted parameter dict.
    """
    inducing = find_starting_z(X, n_inducing) if Z is None else Z

    start_kernel = get_kernel_fun(
        kernel_currier, init_kernel_params, transformation_fun
    )
    start_spec = sv.initialise_using_kernel_fun(start_kernel, inducing)

    # Merge into a dict first (rather than passing keywords directly) so that
    # duplicate keys from init_kernel_params override instead of raising.
    start_params = {
        "mu": start_spec.m,
        "L_elts": start_spec.L_elts,
        "Z": jnp.array(inducing),
        **init_kernel_params,
    }
    flat_start, summary = flatten_and_summarise(**start_params)

    def build_spec(params):
        # Build the kernel, then wrap the variational parameters around it.
        kernel = get_kernel_fun(kernel_currier, params, transformation_fun)
        return sv.SVGPSpec(
            m=params["mu"],
            L_elts=params["L_elts"],
            Z=params["Z"],
            kern_fn=kernel,
        )

    def to_minimize(flat_theta):
        params = reconstruct(flat_theta, summary, jnp.reshape)
        spec = build_spec(params)

        mean_f, var_f = sv.project_to_x(spec, X)
        kl_term = sv.calculate_kl(spec)
        expected_lik = jnp.sum(expectation_1d(likelihood_fun, mean_f, var_f))
        log_prior = prior_fun(transformation_fun(params))

        return -(expected_lik - kl_term + log_prior)

    objective = convert_decorator(
        jit(value_and_grad(to_minimize)), verbose=verbose
    )

    result = minimize(objective, flat_start, method="L-BFGS-B", jac=True)

    fitted = reconstruct(result.x, summary, jnp.reshape)
    return build_spec(fitted), transformation_fun(fitted)