Ejemplo n.º 1
0
        def calc_p_single(y, theta, mu_single, sigma_single, sigma_multiple, mu_multiple_scalar):
            """Return the normalised weight of the single term and both log-terms.

            Two log-terms are evaluated at ``y``:

            * ``ln_s = log(theta) + utils.normal_lpdf(y, mu_single, sigma_single)``
            * ``ln_m = log(1 - theta) + utils.lognormal_lpdf(y, mu_multiple,
              sigma_multiple)`` with ``mu_multiple = log(mu_single +
              mu_multiple_scalar * sigma_single) + sigma_multiple**2``.

            NOTE(review): ``utils.normal_lpdf`` / ``utils.lognormal_lpdf`` are
            presumably normal and log-normal log-densities -- confirm in ``utils``.

            Returns
            -------
            tuple
                ``(p_single, ln_s, ln_m)`` where ``ln_s`` and ``ln_m`` are at
                least one-dimensional and
                ``p_single = exp(ln_s - logsumexp([ln_s, ln_m]))``.
            """
            # Every warning raised in this block is silenced on purpose: taking
            # logs of out-of-range values is tolerated here.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")

                shifted_location = mu_single + mu_multiple_scalar * sigma_single
                mu_multiple = np.log(shifted_location) + sigma_multiple**2

                # TODO: fix bad support. A disabled experiment swapped the two
                # log-terms wherever y <= mu_single - 2 * sigma_single and the
                # multiple term exceeded the single term.
                ln_s = np.atleast_1d(
                    np.log(theta) + utils.normal_lpdf(y, mu_single, sigma_single))
                ln_m = np.atleast_1d(
                    np.log(1-theta) + utils.lognormal_lpdf(y, mu_multiple, sigma_multiple))

                # One row per evaluation point, columns = (single, multiple).
                # The values are not checked for finiteness.
                log_terms = np.array([ln_s, ln_m]).T
                log_total = special.logsumexp(log_terms, axis=1)

                p_single = np.exp(log_terms[:, 0] - log_total)

            return (p_single, ln_s, ln_m)
Ejemplo n.º 2
0
def calc_p_single(y,
                  theta,
                  mu_single,
                  sigma_single,
                  sigma_multiple,
                  mu_multiple_scalar,
                  N=100):
    """Return the normalised weight of the single term and both log-terms.

    Two log-terms are evaluated at ``y``:

    * ``ln_s = log(theta) + utils.normal_lpdf(y, mu_single, sigma_single)``
    * ``ln_m = log(1 - theta) + utils.lognormal_lpdf(y, mu_multiple,
      sigma_multiple) + log(sigmoid(w * (y - mu_single)))`` where
      ``mu_multiple = log(mu_single + mu_multiple_scalar * sigma_single)
      + sigma_multiple**2`` and ``w = get_sigmoid_weight(sigma_single)``.

    NOTE(review): ``utils.normal_lpdf`` / ``utils.lognormal_lpdf`` are
    presumably normal and log-normal log-densities -- confirm in ``utils``.

    Parameters
    ----------
    y : scalar or array
        Point(s) at which both terms are evaluated.
    theta : scalar or array
        Weight whose log is added to the single term; ``log(1 - theta)`` is
        added to the multiple term.
    mu_single, sigma_single : scalar or array
        Passed to ``utils.normal_lpdf``; also used to build ``mu_multiple``
        and the sigmoid argument.
    sigma_multiple : scalar or array
        Passed to ``utils.lognormal_lpdf`` and used to build ``mu_multiple``.
    mu_multiple_scalar : scalar
        Multiplier on ``sigma_single`` inside ``mu_multiple``.
    N : int, optional
        Unused; retained so existing callers keep working.

    Returns
    -------
    tuple
        ``(p_single, ln_s, ln_m)`` where ``ln_s`` and ``ln_m`` are at least
        one-dimensional and ``p_single = exp(ln_s - logsumexp([ln_s, ln_m]))``.
    """
    # Every warning raised in this block is silenced on purpose: taking logs
    # of out-of-range values is tolerated here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        mu_multiple = np.log(mu_single + mu_multiple_scalar *
                             sigma_single) + sigma_multiple**2

        sigmoid_weight = get_sigmoid_weight(sigma_single)

        # log(sigmoid(z)) = -log(1 + exp(-z)) = -logaddexp(0, -z).
        # Forming the sigmoid first and then taking its log collapses to -inf
        # as soon as exp(-z) overflows (z below roughly -709.78), even though
        # the true value is simply ~z. logaddexp stays finite there.
        ln_sigmoid = -np.logaddexp(0, -sigmoid_weight * (y - mu_single))

        ln_s = np.log(theta) + utils.normal_lpdf(y, mu_single, sigma_single)
        ln_m = np.log(1 - theta) + utils.lognormal_lpdf(
            y, mu_multiple, sigma_multiple)

        # Out-of-place add so the result may broadcast to a larger shape.
        ln_m = ln_m + ln_sigmoid

        ln_s = np.atleast_1d(ln_s)
        ln_m = np.atleast_1d(ln_m)

        # One row per evaluation point, columns = (single, multiple).
        lp = np.array([ln_s, ln_m]).T

        p_single = np.exp(lp[:, 0] - special.logsumexp(lp, axis=1))

    return (p_single, ln_s, ln_m)
Ejemplo n.º 3
0
            # Do in reverse order because the last edge cases are more problematic.
            i = -(ii + 1)

            # Check prior.
            if not np.isfinite(
                    ln_prior(theta, mu_single, sigma_single, sigma_multiple)):
                # Don't evaluate support at places outside our prior space.
                continue

            mu_multiple = get_mu_multiple(mu_single, sigma_single,
                                          sigma_multiple, mu_multiple_scalar)

            ln_s = np.log(theta) + utils.normal_lpdf(xi, mu_single,
                                                     sigma_single)
            ln_m = np.log(1 - theta) + utils.lognormal_lpdf(
                xi, mu_multiple, sigma_multiple)

            # Add sigmoid
            sigmoid_weight = (1.0 / sigma_single) * np.log(
                (2 * np.pi * sigma_single)**0.5 * np.exp(0.5 * M**2) - 1)
            sigmoid = 1 / (1 + np.exp(-sigmoid_weight * (xi - mu_single)))

            ln_m = np.log(np.exp(ln_m) * sigmoid)

            def plot_it():
                fig, axes = plt.subplots(2, sharex=True)
                axes[0].plot(xi, ln_s, c="tab:blue")
                axes[0].plot(xi, ln_m, c="tab:red")

                axes[1].plot(xi, np.exp(ln_s), c="tab:blue")
                axes[1].plot(xi, np.exp(ln_m), c="tab:red")
Ejemplo n.º 4
0
def check_support(theta,
                  mu_single,
                  sigma_single,
                  sigma_multiple,
                  mu_multiple_scalar,
                  M=2,
                  N=1000,
                  max_sigma_single_away=10):
    """Assert that the two log-terms order themselves as expected on a grid.

    A grid of ``N`` points from 0.01 to
    ``max(mu_single + max_sigma_single_away * sigma_single)`` is built. For
    each index ``i`` in ``range(theta.size)`` two conditions are asserted:

    * the single term ``ln_s`` exceeds the sigmoid-weighted multiple term at
      every grid point to the left of ``mu_single[0, i]``;
    * from the first grid point where the weighted multiple term exceeds
      ``ln_s``, it keeps exceeding it through the end of the grid.

    If either assertion fails, or no such grid point exists, diagnostic curves
    are drawn with matplotlib and the exception is re-raised.

    Parameters
    ----------
    theta, mu_single, sigma_single, sigma_multiple : scalar or array
        Model parameters; each is passed through ``np.atleast_2d``.
    mu_multiple_scalar : scalar
        Multiplier on ``sigma_single`` inside ``mu_multiple``.
    M : optional
        Forwarded to ``get_sigmoid_weight``.
    N : int, optional
        Number of grid points.
    max_sigma_single_away : optional
        Sets the upper end of the grid in units of ``sigma_single``.

    Returns
    -------
    bool
        ``True`` when every index passes.

    Raises
    ------
    AssertionError, IndexError
        Re-raised after plotting when a check fails.
    """
    # The grid's upper end is taken from the inputs before they are reshaped.
    grid_stop = np.max(mu_single + max_sigma_single_away * sigma_single)
    grid_start = 0.01
    # Column vector so the grid broadcasts against the 2-D parameters.
    x = np.linspace(grid_start, grid_stop, N)[:, np.newaxis]

    theta, mu_single, sigma_single, sigma_multiple = np.atleast_2d(
        theta, mu_single, sigma_single, sigma_multiple)

    # Computed outside the warning filter, as in the original layout.
    shifted_location = mu_single + mu_multiple_scalar * sigma_single
    mu_multiple = np.log(shifted_location) + sigma_multiple**2

    with warnings.catch_warnings():
        # Logs of out-of-range values are tolerated; silence every warning.
        warnings.simplefilter("ignore")

        steepness = get_sigmoid_weight(sigma_single, M=M)
        sigmoid = 1 / (1 + np.exp(-steepness * (x - mu_single)))

        ln_s = np.log(theta) + utils.normal_lpdf(x, mu_single, sigma_single)
        ln_m = np.log(1 - theta) + utils.lognormal_lpdf(
            x, mu_multiple, sigma_multiple)

        # Multiple term weighted by the sigmoid, back in log space.
        ln_m_truncated = np.log(np.exp(ln_m) * sigmoid)

        for i in range(theta.size):
            try:
                # Left of mu_single the single term must win everywhere.
                left_end = np.searchsorted(x[:, 0], mu_single[0, i])
                assert np.all(
                    ln_s[:left_end, i] > ln_m_truncated[:left_end, i])

                # Once the weighted multiple term wins it must keep winning.
                # Indexing [0] raises IndexError if it never wins.
                crossover = np.nonzero(
                    ln_m_truncated[:, i] > ln_s[:, i])[0][0]
                assert np.all(
                    ln_m_truncated[crossover:, i] > ln_s[crossover:, i])

            except (AssertionError, IndexError):
                # Top panel: log values; bottom panel: exponentiated values.
                _, axes = plt.subplots(2)
                panels = ((axes[0], lambda values: values), (axes[1], np.exp))
                curves = ((ln_s, "tab:blue"), (ln_m, "tab:red"),
                          (ln_m_truncated, "k"))
                for ax, transform in panels:
                    for values, colour in curves:
                        ax.plot(x, transform(values[:, i]), c=colour)

                # Same weighted term with twice the log-normal width (green).
                ln_m_wide = np.log(1 - theta) + utils.lognormal_lpdf(
                    x, mu_multiple, 2 * sigma_multiple)
                ln_m_wide_truncated = np.log(np.exp(ln_m_wide) * sigmoid)
                for ax, transform in panels:
                    ax.plot(x, transform(ln_m_wide_truncated[:, i]), c="g")

                raise

    return True
Ejemplo n.º 5
0
        ln_m = g.create_dataset("multiple", data=np.nan * np.ones(N))
        sigma_m, sigma_m_var = results[
            f"{model_name}/gp_predictions/sigma_multiple"][()].T

        if "mu_multiple" in results[f"{model_name}/gp_predictions"]:
            mu_m = results[f"{model_name}/gp_predictions/mu_multiple"][()][:,
                                                                           0]

            # TODO: draws in mu_m

        else:
            scalar = model_config["mu_multiple_scalar"]
            mu_m = np.log(mu_s + scalar * sigma_s) + sigma_m**2

        ln_m[:] = np.log(1 - w) + utils.lognormal_lpdf(y, mu_m, sigma_m)

        # If y < mu_s then it is clearly a single star, but sometimes the log-normal can have a
        # little bit more support.

        print("ANDY FIX SUPPORT PROBLEM")

        lp = np.array([ln_s[()], ln_m[()]]).T

        with np.errstate(under="ignore"):
            ratio = np.exp(lp[:, 0] - special.logsumexp(lp, axis=1))

        full_ratios = np.zeros(sources["source_id"].size, dtype=float)
        full_ratios[data_indices] = ratio
        g.create_dataset("ratio_single", data=full_ratios)