# Example 1
def plot_posterior(data: np.ndarray):
    """Plot the posterior distribution of the number of clusters for
    subsamples of ``data`` of increasing size.

    For each subsample size, the experiment is repeated ``REPEATS`` times on
    fresh random subsamples and the per-repeat cluster-count histograms are
    averaged before plotting.

    Args:
        data: full observation array; rows are sampled without replacement.
    """
    rng_key = random.PRNGKey(0)

    # PYMM parameters
    T = 20  # truncation level of the stick-breaking representation
    t = np.arange(T + 1)

    for Npoints in (500, 1000, 2000):
        y = np.zeros(T + 1)  # accumulator for the averaged histogram
        for _ in range(REPEATS):
            idx = np.random.choice(len(data), size=Npoints, replace=False)
            data_sub = data[idx]

            z = sample_posterior(rng_key, multivariate_gaussian_DPMM_isotropic, data_sub, N_SAMPLES, alpha=1, sigma=0, T=T,
                # Uncomment the line below to use HMCGibbs
                    # gibbs_fn=make_multivariate_gaussian_DPMM_gibbs_fn(data_sub), gibbs_sites=['z'],
                )
            # BUG FIX: accumulate across repeats. Previously this was
            # `y = ...`, which discarded all repeats but the last one and
            # then wrongly divided that single histogram by REPEATS.
            y += compute_n_clusters_distribution(z, T)

        y /= REPEATS  # average over repeats
        plt.plot(t, y, label=f"N={Npoints}")

    plt.ylabel(r"$P(T_n=t|X_{1:N})$")
    plt.xlabel("$t$")
    plt.title("Posterior distribution of the number of clusters")

    plt.legend()
    plt.show()
def make_synthetic_experiment_sigma(sample_data,
                                    model,
                                    make_gibbs_fn,
                                    explicit_ub=None):
    """Compare posterior cluster-count distributions across several values
    of the Pitman-Yor discount parameter sigma, on one synthetic dataset.

    Args:
        sample_data: callable ``(rng_key, n) -> data`` producing observations.
        model: model function passed to ``sample_posterior``.
        make_gibbs_fn: builds a Gibbs update for the data (used when USE_GIBBS).
        explicit_ub: optional callable giving an explicit upper bound curve.
    """
    rng_key = random.PRNGKey(0)

    # Sampling parameters
    Npoints = 1000
    data = sample_data(rng_key, Npoints)

    # DPMM/PYMM parameters
    T = 20
    t = np.arange(T + 1)

    # Gibbs arguments are identical for every sigma; build them once.
    gibbs_kwargs = {
        'gibbs_fn': make_gibbs_fn(data) if USE_GIBBS else None,
        'gibbs_sites': ['z'] if USE_GIBBS else None,
    }

    for sigma in (0, 0.1, .5):
        samples = sample_posterior(rng_key, model, data, N_SAMPLES,
                                   alpha=1, sigma=sigma, T=T, **gibbs_kwargs)
        posterior = compute_n_clusters_distribution(samples, T)
        plt.plot(t, posterior, label=rf"$\sigma={sigma}$")
        # Reuse the posterior curve's color for its bound and prior curves.
        color = plt.gca().lines[-1].get_color()

        # Upper bound
        if explicit_ub is not None:
            bound = explicit_ub(data, t, params_PY=(1, sigma))
            plt.plot(t[1:], bound[1:],
                     label=rf"Upper bound $\sigma={sigma}$",
                     color=color, linestyle="dotted", lw=1)

        # Prior
        prior = compute_PY_prior(1, sigma, [Npoints])[0]
        plt.plot(t, prior[:T + 1],
                 label=rf"Prior $\sigma={sigma}$",
                 color=color, linestyle="dashed", lw=1)
        # plt.axvline(alpha*np.log(Npoints), color=color, lw=1, linestyle="dashed")

    plt.legend()
    plt.title(r"Impact of $\sigma$")
    plt.ylabel(r"$P(T_n=t|X_{1:N})$")
    plt.xlabel("Number of clusters")
    plt.show()
# Example 3
def plot_clusters(data: np.ndarray):
    """Draw one posterior sample of cluster assignments on the first
    ``Npoints`` rows of ``data`` and scatter-plot them in 2D PCA space,
    one color per cluster.
    """
    T = 10
    t = np.arange(T + 1)
    Npoints = 100

    rng_key = random.PRNGKey(0)

    subset = data[:Npoints]
    assignments = sample_posterior(rng_key, multivariate_gaussian_DPMM, subset, Nsamples=1, T=T, alpha=1)

    # Project to two principal components for visualization.
    projected = np.ascontiguousarray(PCA(n_components=2).fit_transform(subset))

    for label in np.unique(assignments):
        points = projected[assignments == label]
        plt.scatter(points[:, 0], points[:, 1], alpha=.5)

    plt.show()
def make_synthetic_experiment(sample_data,
                              model,
                              make_gibbs_fn,
                              explicit_ub=None):
    """
    Template for small synthetic experiments with varying size of observed data.

    Left panel: posterior distribution of the number of clusters, its prior,
    and (optionally) an explicit upper bound, averaged over REPEATS runs.
    Right panel: distribution of cluster sizes as a fraction of total size.

    Args:
        sample_data: sampling function ``(rng_key, n) -> data``
        model: model function passed to ``sample_posterior``
        make_gibbs_fn: builds a Gibbs update for the data (used when USE_GIBBS)
        explicit_ub: optional callable giving an explicit upper bound curve
    """
    rng_key = random.PRNGKey(0)

    # Sampling parameters
    n_values = [100, 1000, 10000]

    # DPMM/PYMM parameters
    T = 20  # max number of component in the truncated stick breaking representation
    t = np.arange(T + 1)
    alpha = 1
    sigma = 0

    # Plotting parameters
    fig, (ax0, ax1) = plt.subplots(1, 2)

    priors = compute_PY_prior(alpha, sigma, n_values)
    for Npoints, prior in zip(n_values, priors):
        cluster_count = np.zeros(T + 1)  # cluster count histogram
        upper_bound = np.zeros(T + 1)

        cluster_size = np.zeros(Npoints + 1)

        # Repeat the experiment
        for _ in range(REPEATS):
            data = sample_data(rng_key, Npoints)
            # CONSISTENCY FIX: pass alpha=alpha (previously a hard-coded
            # alpha=1) so the posterior stays in sync with the prior and
            # upper bound, which already use the `alpha` constant above.
            z = sample_posterior(
                rng_key,
                model,
                data,
                N_SAMPLES,
                T=T,
                alpha=alpha,
                gibbs_fn=make_gibbs_fn(data) if USE_GIBBS else None,
                gibbs_sites=['z'] if USE_GIBBS else None,
            )

            cluster_count += compute_n_clusters_distribution(z, T)
            if explicit_ub is not None:
                upper_bound += explicit_ub(data, t, params_PY=(alpha, sigma))

            cluster_size += compute_cluster_size_distribution(z)

        # Average the accumulated histograms over repeats.
        cluster_count /= REPEATS
        cluster_size /= REPEATS

        # Plot cluster count histograms (ax0)
        ax0.plot(t, cluster_count, label=f"N={Npoints}")

        # Reuse the posterior curve's color for the prior and bound curves.
        color = ax0.lines[-1].get_color()
        ax0.plot(t,
                 prior[:T + 1],
                 label=f"Prior N={Npoints}",
                 color=color,
                 linestyle='dashed',
                 lw=1)

        if explicit_ub is not None:
            upper_bound /= REPEATS
            ax0.plot(t[1:],
                     upper_bound[1:],
                     label=f"Upper bound N={Npoints}",
                     color=color,
                     linestyle='dotted',
                     lw=1)

        # Plot cluster size histograms (ax1)
        bins = np.linspace(0, 1, 10, endpoint=True)
        frac = np.arange(0, Npoints + 1) / Npoints

        # TODO : use an actual histogram ?
        # Overlaying histograms doesn't really look good.
        hist, edges = np.histogram(frac,
                                   bins,
                                   density=True,
                                   weights=cluster_size)
        # Plot at bin centers so the curve approximates the histogram shape.
        ax1.plot(0.5 * (edges[1:] + edges[:-1]),
                 hist,
                 color=color,
                 label=f"N={Npoints}")

    ax0.axhline(y=1, color='black', linewidth=0.3, linestyle='dotted')
    ax0.set(title=r"Number of clusters",
            xlabel="$t$",
            ylabel=r"$P(T_n=t|X_{1:N})$")
    ax0.legend()

    ax1.set(xlabel="Fraction of total size", title="Size of clusters")
    ax1.legend()

    plt.show()