def statistic_convergence(nsamples=5000,ncomputepoints=50,nruns=50,ndims=10):
    """Compare mean/variance convergence of the MH and Aux. Var. samplers.

    Builds a synthetic ``ndims x ndims`` data matrix, obtains ``nruns`` runs
    of ``nsamples`` samples from each sampler via
    ``sampling.load_or_run_samples``, and computes convergence curves with
    ``tests.get_statistic_convergence`` at ``ncomputepoints`` checkpoints.
    After checking that both samplers agree on the estimated "true" mean and
    variance, it saves four PDFs under ``../writeup/figures/``. For each of
    mean and variance there is one plot against sample index and one against
    seconds. The seconds axis is the sample index scaled by the per-sampler
    factor from ``timing``.

    Parameters
    ----------
    nsamples : int
        Samples per run.
    ncomputepoints : int
        Number of checkpoints at which the statistics are evaluated.
    nruns : int
        Number of runs per sampler.
    ndims : int
        Side length of the square data matrix.

    Raises
    ------
    AssertionError
        If the summed squared difference between the two samplers'
        estimated true mean, or true variance, is not below 1e-5.
    """
    # get samples
    # Column j (j < ndims//2) gets a single 10 in row (j-1) mod ndims//2
    # (np.roll shifts the row indices by one); every other entry stays 0.
    data = np.zeros((ndims,ndims))
    data[np.roll(np.arange(ndims//2),1),np.arange(ndims//2)] = 10 # fill half the dims with data
    alpha = 2. # Dirichlet prior hyperparameter
    beta = 160. # MH proposal distribution parameter, set so acceptance rate is about 0.24 with ndims=10
    mhsamples, auxsamples = map(np.array,
            sampling.load_or_run_samples(nruns,nsamples,alpha,beta,data))

    # compute statistics
    (mhmeans, mhvars), (mh_truemean, mh_truevar), (mh_mean_ds, mh_var_ds) = \
            tests.get_statistic_convergence(mhsamples,ncomputepoints)
    (auxmeans, auxvars), (aux_truemean, aux_truevar), (aux_mean_ds, aux_var_ds) = \
            tests.get_statistic_convergence(auxsamples,ncomputepoints)

    # Check that the estimated "true" statistics agree. Use an explicit raise
    # rather than `assert` so the check still runs under `python -O`.
    # Written as `not (a < tol and b < tol)` so NaN errors also fail the check.
    mean_err = ((mh_truemean - aux_truemean)**2).sum()
    var_err = ((mh_truevar - aux_truevar)**2).sum()
    if not (mean_err < 1e-5 and var_err < 1e-5):
        raise AssertionError(
                'samplers disagree on estimated true statistics: '
                'mean sq. err %g, variance sq. err %g' % (mean_err, var_err))

    # get time scaling (multiplies a sample index into seconds on the x axis)
    aux_timing = timing.get_auxvar_timing(data=data,alpha=alpha)
    mh_timing = timing.get_mh_timing(data=data,beta=beta,alpha=alpha)

    # Checkpoint indices do not change per plot; compute them once.
    indices = np.array(tests.chunk_indices(nsamples,ncomputepoints))

    # plot
    for samplerds, statisticname in zip(((aux_mean_ds,mh_mean_ds),(aux_var_ds,mh_var_ds)),('mean','variance')):
        # sample index scaling
        _plot_statistic_curves(
                (indices, indices), samplerds, 'sample index', statisticname,
                '../writeup/figures/statisticconvergence_%dD_%s.pdf' % (ndims,statisticname))

        # time scaling
        _plot_statistic_curves(
                (indices*aux_timing, indices*mh_timing), samplerds, 'seconds', statisticname,
                '../writeup/figures/statisticconvergence_timescaling_%dD_%s.pdf' % (ndims,statisticname))


def _plot_statistic_curves(xs, samplerds, xlabel, statisticname, path):
    """Draw one convergence figure for both samplers and save it to ``path``.

    For each sampler (Aux. Var. in blue, then MH in green), plot the mean
    over axis 0 of its curve array as a solid line, and the 10th and 90th
    percentiles over axis 0 as dashed lines. ``xs`` holds the x values for
    each sampler, in the same order as ``samplerds``. Axis 0 of each curve
    array presumably indexes runs; confirm against
    tests.get_statistic_convergence.
    """
    plt.figure()

    for x, ds, samplername, color in zip(xs, samplerds, ['Aux. Var.','MH'], ['b','g']):
        plt.plot(x, ds.mean(0), color+'-', label='%s Sampler' % samplername)
        plt.plot(x, scoreatpercentile(ds,per=10,axis=0), color+'--')
        plt.plot(x, scoreatpercentile(ds,per=90,axis=0), color+'--')

    plt.legend()
    plt.xlabel(xlabel)
    plt.title('%s Convergence' % statisticname.capitalize())

    save(path)


def Rhatp(nsamples=1000,ncomputepoints=25,nruns=50,ndims=10):
    """Plot the MSPRF diagnostic of the MH and Aux. Var. samplers.

    Uses the same synthetic data as the statistic-convergence experiment.
    It obtains ``nruns`` runs of ``nsamples`` samples per sampler via
    ``sampling.load_or_run_samples``, then computes the diagnostic with
    ``tests.get_Rhat`` at ``ncomputepoints`` checkpoints. It saves two PDFs
    under ``../writeup/figures/``: one against sample index and one against
    seconds. The seconds axis is the sample index times the per-sampler
    factor from ``timing``.

    In both plots the y axis is limited to 1.1x the MH curve's maximum. The
    time plot's x axis ends at the MH sampler's total time
    (``mh_timing*nsamples``), so a slower Aux. Var. curve would be clipped.

    Parameters
    ----------
    nsamples : int
        Samples per run.
    ncomputepoints : int
        Number of checkpoints at which the diagnostic is evaluated.
    nruns : int
        Number of runs per sampler.
    ndims : int
        Side length of the square data matrix.
    """
    # get samples
    # Column j (j < ndims//2) gets a single 10 in row (j-1) mod ndims//2;
    # every other entry stays 0.
    data = np.zeros((ndims,ndims))
    data[np.roll(np.arange(ndims//2),1),np.arange(ndims//2)] = 10 # fill half the dims with data
    alpha = 2. # Dirichlet prior hyperparameter
    beta = 160. # MH proposal distribution parameter, set so acceptance rate is about 0.24 with ndims=10
    mhsamples, auxsamples = map(np.array,
            sampling.load_or_run_samples(nruns,nsamples,alpha,beta,data))

    # get Rhatps
    aux_R = tests.get_Rhat(auxsamples,ncomputepoints=ncomputepoints)
    mh_R = tests.get_Rhat(mhsamples,ncomputepoints=ncomputepoints)

    # Checkpoint indices are shared by every curve below; compute them once.
    indices = np.array(tests.chunk_indices(nsamples,ncomputepoints))

    ### plot without time scaling
    plt.figure()

    plt.plot(indices,aux_R,'bx-',label='Aux. Var. Sampler')
    plt.plot(indices,mh_R,'gx-',label='MH Sampler')
    plt.ylim(0,1.1*mh_R.max())
    # Span the full sampled range instead of a hard-coded 1000, which was
    # wrong whenever nsamples != 1000.
    plt.xlim(0,nsamples)
    plt.xlabel('sample index')
    plt.legend()
    plt.title('MH and Aux. Var. Samplers MSPRF vs Sample Indices')

    save('../writeup/figures/MSPRF_sampleindexscaling_%dD.pdf' % ndims)

    ### plot with time scaling
    plt.figure()

    # compute time per sample
    aux_timing = timing.get_auxvar_timing(data=data,alpha=alpha)
    mh_timing = timing.get_mh_timing(data=data,beta=beta,alpha=alpha)

    plt.plot(indices*aux_timing,aux_R,'bx-',label='Aux. Var. Sampler')
    plt.plot(indices*mh_timing,mh_R,'gx-',label='MH Sampler')
    plt.ylim(0,1.1*mh_R.max())
    plt.xlim(0,mh_timing*nsamples)
    plt.xlabel('seconds')
    plt.legend()
    plt.title('MH and Aux. Var. Sampler MSPRF vs Computation Time')

    save('../writeup/figures/MSPRF_timescaling_%dD.pdf' % ndims)