def get_kldivs(chains,ncomputepoints,meshsize=100,params=None):
    '''
    Compare each chain's samples against a mesh-based reference density at
    several points along the chain.

    The reference density is built as follows:
      1. log_censored_dirichlet_density is evaluated on mesh(meshsize).
      2. The values are exponentiated (shifted by their max for numerical
         stability) and normalized to sum to one.
      3. The result is wrapped in kde(0.05, mesh, weights).

    chains: array whose first axis indexes chains, second axis indexes
        samples and third axis has length 3 (only 3-dimensional examples
        are supported).
    ncomputepoints: number of evaluation points per chain; passed to
        chunk_indices together with the chain length.
    meshsize: resolution passed to mesh().
    params: dict with keys 'alpha' and 'data'. A 'beta' key may be present
        but is not used here. Defaults to
        {'alpha': 2., 'beta': 30., 'data': [[0,2,0],[0,0,0],[0,0,0]]}.

    Returns an array of shape (chains.shape[0], ncomputepoints). Entry
    [i, j] is kldist_samples applied to chain i's samples
    [sampleidx//2 : sampleidx]. Here sampleidx is the j-th value yielded by
    chunk_indices, so the first half of each prefix is discarded.
    '''
    # build the default per call instead of sharing one mutable dict (and
    # one mutable array) across every call
    if params is None:
        params = {'alpha':2.,'beta':30.,'data':np.array([[0,2,0],[0,0,0],[0,0,0]])}
    alpha, data = params['alpha'], params['data']
    p = chains.shape[2]
    assert p == 3, 'this test only works on 3 dimensional examples'

    ### construct a 'true' density object by discrete approximate integration
    # get density evaluated on a mesh, (mesh3D, dvals)
    mesh3D = mesh(meshsize)
    dvals = log_censored_dirichlet_density(mesh3D, alpha, data=data)
    # subtract the max before exponentiating to avoid overflow, then normalize
    dvals = np.exp(dvals - dvals.max())
    dvals /= dvals.sum()

    # interpolate into a density function
    true_density = kde(0.05,mesh3D,dvals)

    ### get kl divergence to truth at compute points
    # preallocate outputs
    dists = np.zeros((chains.shape[0],ncomputepoints))

    # loop over chains
    for chainidx, chain in enumerate(chains):
        # loop over chunks
        for chunkidx, sampleidx in enumerate(chunk_indices(chain.shape[0], ncomputepoints)):
            # keep only the second half of the prefix seen so far
            # (the first half is dropped, presumably as burn-in)
            samples = chain[sampleidx//2:sampleidx]
            # compute kldiv against true_density
            dists[chainidx,chunkidx] = kldist_samples(samples,true_density)

    return dists
def aux_posterior_2D(meshsize=250,alpha=2.,data=None):
    '''
    Plot a KDE of auxiliary-variable posterior samples over the simplex,
    projected to 2D, and save the figure.

    meshsize: resolution passed to simplex.mesh.
    alpha: passed to sampling.generate_pi_samples_withauxvars.
    data: 3x3 array (or nested sequence) passed to the sampler. Defaults to
        [[0,2,0],[0,0,0],[0,0,0]].

    Steps:
      1. Draw 10000 samples and discard the first len//20 of them.
      2. Fit density.kde with bandwidth 0.005 and evaluate it on the 3D mesh.
      3. Cubic-interpolate the values onto a 2000x2000 grid in projected
         2D coordinates.
      4. Draw the grid with plt.imshow and pass
         '../writeup/figures/dirichlet_censored_auxvar_posterior_2D.pdf'
         to save().
    '''
    # build the default per call instead of sharing one mutable array that is
    # handed to an external sampler
    if data is None:
        data = np.array([[0,2,0],[0,0,0],[0,0,0]])
    data = np.asarray(data)
    assert data.shape == (3,3)

    mesh3D = simplex.mesh(meshsize)
    mesh2D = simplex.proj_to_2D(mesh3D) # use specialized b/c it plays nicer with triangulation algorithm

    # get samples
    auxsamples = sampling.generate_pi_samples_withauxvars(alpha,10000,data)

    # evaluate a kde based on the samples, dropping the first 1/20 of them
    aux_kde = density.kde(0.005,auxsamples[len(auxsamples)//20:])
    aux_kde_vals = aux_kde(mesh3D)

    ### plot

    # used for grid interpolation
    xi = np.linspace(mesh2D[:,0].min(), mesh2D[:,0].max(), 2000, endpoint=True)
    yi = np.linspace(mesh2D[:,1].min(), mesh2D[:,1].max(), 2000, endpoint=True)

    plt.figure(figsize=(8,8))
    # NOTE(review): imshow defaults to origin='upper', so row 0 (smallest yi)
    # is drawn at the top and the image is vertically flipped relative to
    # mesh2D coordinates -- confirm this orientation is intended.
    plt.imshow(griddata((mesh2D[:,0],mesh2D[:,1]),aux_kde_vals,(xi[na,:],yi[:,na]),method='cubic'))
    plt.axis('off')

    save('../writeup/figures/dirichlet_censored_auxvar_posterior_2D.pdf')