def generate_pi_samples_mh(alpha,n_samples,data,beta):
    """Draw samples of pi from the censored-Dirichlet target with Metropolis-Hastings.

    Initialization: the chain starts from a random draw from Dirichlet(alpha * ones(K)).

    Proposal: pi' ~ Dirichlet(beta * pi). Its mean is the current pi, and
    `beta` sets how concentrated, i.e. how small, the steps are. The
    proposal is asymmetric, so the acceptance ratio includes the Hastings
    correction log q(pi | pi') - log q(pi' | pi).

    Args:
        alpha: Dirichlet concentration. It is passed to
            log_censored_dirichlet_density and used as alpha * np.ones(K)
            for the initial draw. Presumably a scalar; TODO confirm whether
            arrays are supported by the density helper.
        n_samples: number of chain states to return.
        data: array whose first dimension gives K, the number of
            components; passed through to log_censored_dirichlet_density.
        beta: concentration multiplier of the Dirichlet proposal.

    Returns:
        List of length n_samples of pi arrays, one per MH iteration. A state
        repeats whenever a proposal is rejected. No burn-in or thinning is
        applied.
    """
    K = data.shape[0]

    # Random initial state from the symmetric Dirichlet(alpha).
    pi = np.random.dirichlet(alpha * np.ones(K))
    current_val = log_censored_dirichlet_density(pi,alpha=alpha,data=data)

    samples = []
    while len(samples) < n_samples:
        pi_prime = np.random.dirichlet(beta * pi)
        new_val = log_censored_dirichlet_density(pi_prime,alpha=alpha,data=data)

        # A zero-density proposal has acceptance probability 0. Treat it as a
        # rejection and still record the current state below. Skipping the
        # iteration instead would condition the chain on "valid" proposals,
        # which can bias it.
        if new_val > -np.inf:
            log_ratio = (new_val - current_val
                         + log_dirichlet_density(pi,alpha=beta*pi_prime)
                         - log_dirichlet_density(pi_prime,alpha=beta*pi))
            # exp(min(0, r)) == min(1, exp(r)) for every r (including inf and
            # nan), but it never overflows the exponential.
            accept_prob = np.exp(min(0., log_ratio))
            if np.random.rand() < accept_prob:
                pi = pi_prime
                current_val = new_val

        samples.append(pi)

    return samples
def prior_posterior_2D(meshsize=250,alpha=2.,data=None,outdir='../writeup/figures',gridsize=2000):
    """Render a Dirichlet prior and two posteriors on the 2-simplex and save them as PDFs.

    Three figures are written, one per density evaluated on the same
    simplex mesh:
      * <outdir>/dirichlet_prior_2D.pdf: the Dirichlet(alpha) density.
      * <outdir>/dirichlet_uncensored_posterior_2D.pdf: the Dirichlet density
        evaluated with data=data.sum(0).
      * <outdir>/dirichlet_censored_posterior_2D.pdf: the censored density,
        normalized to sum to 1 over the mesh points.

    Args:
        meshsize: resolution passed to simplex.mesh.
        alpha: Dirichlet concentration.
        data: 3x3 array-like, presumably of counts. Defaults to
            [[0,2,0],[0,0,0],[0,0,0]]. A fresh array is built per call
            instead of sharing a mutable default.
        outdir: directory prefix for the saved figures.
        gridsize: points per axis of the rectangular interpolation grid.

    Raises:
        ValueError: if data does not have shape (3, 3).
    """
    if data is None:
        data = np.array([[0,2,0],[0,0,0],[0,0,0]])
    data = np.asarray(data)
    if data.shape != (3,3):
        raise ValueError('data must have shape (3, 3), got %r' % (data.shape,))

    mesh3D = simplex.mesh(meshsize)
    # Use the specialized projection because it plays nicer with the triangulation algorithm.
    mesh2D = simplex.proj_to_2D(mesh3D)

    priorvals = np.exp(dirichlet.log_dirichlet_density(mesh3D,alpha))

    posteriorvals_uncensored = np.exp(dirichlet.log_dirichlet_density(mesh3D,alpha,data=data.sum(0)))

    # Subtract the max log value before exponentiating to avoid overflow,
    # then normalize over the mesh (direct discretized integration).
    logvals = dirichlet.log_censored_dirichlet_density(mesh3D,alpha,data=data)
    unnormalized = np.exp(logvals - logvals.max())
    posteriorvals_censored = unnormalized / unnormalized.sum()

    # Rectangular grid covering the projected mesh, used for interpolation.
    xi = np.linspace(mesh2D[:,0].min(), mesh2D[:,0].max(), gridsize, endpoint=True)
    yi = np.linspace(mesh2D[:,1].min(), mesh2D[:,1].max(), gridsize, endpoint=True)

    for vals, filename in [
            (priorvals, 'dirichlet_prior_2D.pdf'),
            (posteriorvals_uncensored, 'dirichlet_uncensored_posterior_2D.pdf'),
            (posteriorvals_censored, 'dirichlet_censored_posterior_2D.pdf')]:
        _save_simplex_image(mesh2D, vals, xi, yi, '%s/%s' % (outdir, filename))


def _save_simplex_image(mesh2D,vals,xi,yi,path):
    """Interpolate `vals` (given at the mesh2D points) onto the xi/yi grid, draw it, and save to `path`."""
    plt.figure(figsize=(8,8))
    # Interpolating onto a rectangular-pixel grid keeps the file small but
    # blurs the image. For an exact rendering, replace the imshow call with
    #   plt.tripcolor(mesh2D[:,0],mesh2D[:,1],vals)
    # which draws each triangle of a Delaunay triangulation individually
    # (no blurring, but large files and slow draw times).
    # NOTE(review): imshow's default origin='upper' puts yi.min() at the top,
    # so this image is vertically flipped relative to the tripcolor
    # alternative. Confirm the intended orientation.
    plt.imshow(griddata((mesh2D[:,0],mesh2D[:,1]),vals,(xi[na,:],yi[:,na]),method='cubic'))
    plt.axis('off')
    save(path)