def wl_is_faster_hamming():
    """Compare Hamming-error convergence of the weak-limit and direct-assignment samplers.

    Draws HDP-HSMM (Poisson-duration) sample runs on ``hsmm_data`` with both
    the weak-limit sampler (100 runs x 300 iterations, L=10) and the
    direct-assignment sampler (24 runs x 150 iterations). Every sampled state
    sequence is scored against ``hsmm_labels`` with
    ``util.stateseq_hamming_error``, evaluated in parallel through ``dv``.
    For each sampler, the median error (solid) and 25th/75th percentiles
    (dashed) are plotted. The figure is saved to
    ``figures/wl_is_faster_hamming.pdf``.

    Returns:
        (wl_errs, da_errs): numpy arrays of Hamming errors for the weak-limit
        and direct-assignment runs. Presumably shaped (nruns, niter), i.e. one
        row per run and one column per iteration, since statistics are taken
        over axis 0 and plotted against iteration. TODO confirm.
    """
    ### get samples
    wl_samples = get_hdphsmm_wl_poisson_samples(hsmm_data, nruns=100, niter=300, L=10)
    da_samples = get_hdphsmm_da_poisson_samples(hsmm_data, nruns=24, niter=150)

    ### get hamming errors for samples
    def f(tup):
        # Shipped to the parallel engines; takes one (sample, labels) pair.
        return util.stateseq_hamming_error(tup[0], tup[1])

    def hamming_errors(samples):
        # Materialize the pairs. On Python 3 ``zip`` is a lazy iterator, while
        # older IPython.parallel ``map`` takes len() of its input in order to
        # partition it across engines.
        pairs = list(zip(samples, [hsmm_labels] * len(samples)))
        return np.array(dv.map_sync(f, pairs))

    wl_errs = hamming_errors(wl_samples)
    da_errs = hamming_errors(da_samples)

    ### plot
    plt.figure()

    for errs, samplername, color in zip([wl_errs, da_errs], ["Weak Limit", "Direct Assignment"], ["b", "g"]):
        plt.plot(np.median(errs, axis=0), color + "-", label="%s Sampler" % samplername)
        # Dashed lines bracket the interquartile range across runs.
        for per in (25, 75):
            # A copy is passed, presumably because util.scoreatpercentile may
            # modify its input in place. Verify before removing.
            plt.plot(util.scoreatpercentile(errs.copy(), per=per, axis=0), color + "--")

    plt.legend()
    plt.xlabel("iteration")
    plt.ylabel("Hamming error")

    save("figures/wl_is_faster_hamming.pdf")

    return wl_errs, da_errs
def hsmm_vs_stickyhmm():
    """Compare Hamming-error convergence of Geo-HDP-HSMM and Sticky-HDP-HMM samplers.

    The goal is to show that both samplers converge at the same rate in
    number of iterations. Draws 50 runs x 100 iterations on ``hmm_data`` from
    each direct-assignment sampler. Every sampled state sequence is scored
    against ``hmm_labels`` with ``util.stateseq_hamming_error``, evaluated in
    parallel through ``dv``. For each sampler, the median error (solid) and
    25th/75th percentiles (dashed) are plotted. The figure is saved to
    ``figures/hsmm_vs_stickyhmm.pdf``.

    Returns:
        (hsmm_errs, shmm_errs): numpy arrays of Hamming errors for the
        Geo-HDP-HSMM and Sticky-HDP-HMM runs. Presumably shaped
        (nruns, niter), since statistics are taken over axis 0 and plotted
        against iteration. TODO confirm.
    """
    ### get samples
    hsmm_samples = get_hdphsmm_da_geo_samples(hmm_data, nruns=50, niter=100)
    shmm_samples = get_shdphmm_da_samples(hmm_data, nruns=50, niter=100)

    ### get hamming errors for samples
    def f(tup):
        # Shipped to the parallel engines; takes one (sample, labels) pair.
        return util.stateseq_hamming_error(tup[0], tup[1])

    def hamming_errors(samples):
        # Materialize the pairs. On Python 3 ``zip`` is a lazy iterator, while
        # older IPython.parallel ``map`` takes len() of its input in order to
        # partition it across engines.
        pairs = list(zip(samples, [hmm_labels] * len(samples)))
        return np.array(dv.map_sync(f, pairs))

    hsmm_errs = hamming_errors(hsmm_samples)
    shmm_errs = hamming_errors(shmm_samples)

    ### plot
    plt.figure()

    for errs, samplername, color in zip([hsmm_errs, shmm_errs], ["Geo-HDP-HSMM DA", "Sticky-HDP-HMM DA"], ["b", "g"]):
        plt.plot(np.median(errs, axis=0), color + "-", label="%s Sampler" % samplername)
        # Dashed lines bracket the interquartile range across runs.
        for per in (25, 75):
            # A copy is passed, presumably because util.scoreatpercentile may
            # modify its input in place. Verify before removing.
            plt.plot(util.scoreatpercentile(errs.copy(), per=per, axis=0), color + "--")

    plt.legend()
    plt.xlabel("iteration")
    plt.ylabel("Hamming error")

    save("figures/hsmm_vs_stickyhmm.pdf")

    return hsmm_errs, shmm_errs