Example #1
def plot_mean_locations():
    ### Plot the sampled locations for a few neurons
    _, _, _, _, Ls = result
    Ls_rot = []
    for L in Ls:
        R = compute_optimal_rotation(L, pfs)
        Ls_rot.append(L.dot(R))
    Ls_rot = np.array(Ls_rot)

    Ls_mean = np.mean(Ls_rot, 0)

    fig = create_figure(figsize=(1.4, 2.5))
    plt.subplot(211, aspect='equal')

    wheel_cmap = gradient_cmap(
        [colors[0], colors[3], colors[2], colors[1], colors[0]])

    for i, k in enumerate(node_perm):
        color = wheel_cmap((np.pi + pfs_th[k]) / (2 * np.pi))
        plt.plot(pfs[k, 0],
                 pfs[k, 1],
                 'o',
                 markerfacecolor=color,
                 markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("True Place Fields")
    plt.xlim(-45, 45)
    # plt.xlabel("$x$")
    plt.xticks([-40, -20, 0, 20, 40], [])
    plt.ylim(-45, 45)
    # plt.ylabel("$y$")
    plt.yticks([-40, -20, 0, 20, 40], [])

    # Now plot the inferred locations

    plt.subplot(212, aspect='equal')

    for i, k in enumerate(node_perm):
        color = wheel_cmap((np.pi + pfs_th[k]) / (2 * np.pi))
        plt.plot(Ls_mean[k, 0],
                 Ls_mean[k, 1],
                 'o',
                 markerfacecolor=color,
                 markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("Mean Locations")
    plt.xlim(-30, 30)
    plt.xticks([])
    plt.ylim(-30, 30)
    plt.yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "hipp_mean_locations.pdf"))
    plt.show()
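The alignment step above relies on compute_optimal_rotation, which is defined elsewhere in the project (a later example calls it with a scale=False keyword). As a hedged illustration only, a minimal sketch that solves the orthogonal Procrustes problem (find the rotation R minimizing ||L R - pfs||_F) could look like the following; the real helper may differ, e.g. by also fitting a scale:

import numpy as np

def compute_optimal_rotation(L, pfs):
    # Hypothetical sketch: orthogonal Procrustes alignment of L onto pfs.
    U, _, Vt = np.linalg.svd(L.T.dot(pfs))
    R = U.dot(Vt)
    if np.linalg.det(R) < 0:  # restrict to a proper rotation
        U[:, -1] *= -1
        R = U.dot(Vt)
    return R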
def plot_impulse_responses(models, results):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((6, 6))
    col = harvard_colors()
    plt.grid()

    y_max = 0

    for i, (model, result) in enumerate(zip(models, results)):
        smpl = result.samples[-1]
        W = smpl.W_effective
        if "continuous" in str(smpl.__class__).lower():
            t, irs = smpl.impulses

            for k1 in range(K):
                for k2 in range(K):
                    plt.subplot(K, K, k1 * K + k2 + 1)
                    plt.plot(t, W[k1, k2] * irs[:, k1, k2], color=col[i], lw=2)
        else:
            irs = smpl.impulses
            for k1 in range(K):
                for k2 in range(K):
                    plt.subplot(K, K, k1 * K + k2 + 1)
                    plt.plot(W[k1, k2] * irs[:, k1, k2], color=col[i], lw=2)

        y_max = max(y_max, (W * irs).max())

    for k1 in range(K):
        for k2 in range(K):
            plt.subplot(K, K, k1 * K + k2 + 1)
            plt.ylim(0, y_max * 1.05)
    plt.show()
def plot_pred_ll_vs_time(models, results, burnin=0,
                         std_ll=np.nan,
                         true_ll=np.nan):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((4,3))
    ax = fig.add_subplot(111)
    col = harvard_colors()
    plt.grid()

    t_start = 0
    t_stop = 0

    for i, (model, result) in enumerate(zip(models, results)):
        plt.plot(result.timestamps[burnin:], result.test_lls[burnin:], lw=2, color=col[i], label=model)

        # Update time limits
        t_start = min(t_start, result.timestamps[burnin:].min())
        t_stop = max(t_stop, result.timestamps[burnin:].max())

    # plt.legend(loc="outside right")

    # Plot the standard Hawkes test ll
    plt.plot([t_start, t_stop], std_ll*np.ones(2), lw=2, color=col[len(models)], label="Std.")

    # Plot the true ll
    plt.plot([t_start, t_stop], true_ll*np.ones(2), '--k', lw=2, label="True")

    ax.set_xlim(t_start, t_stop)
    ax.set_xlabel("time [sec]")
    ax.set_ylabel("Pred. Log Lkhd.")
    plt.show()
def plot_roc_curves(fprs, tprs):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors
    col = harvard_colors()

    fig = create_figure((3,3))
    ax = fig.add_subplot(111)
    ax.plot(fprs['xcorr'], tprs['xcorr'], color=col[7], lw=1.5, label="xcorr")
    ax.plot(fprs['bfgs'], tprs['bfgs'], color=col[3], lw=1.5, label="Std.")
    ax.plot(fprs['svi'], tprs['svi'], color=col[0], lw=1.5, label="MAP")
    ax.plot([0,1], [0,1], '-k', lw=0.5)
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")

    # this is another inset axes over the main axes
    parchment = np.array([243,243,241])/255.
    inset = plt.axes([0.55, 0.275, .265, .265], facecolor=parchment)
    inset.plot(fprs['xcorr'], tprs['xcorr'], color=col[7], lw=1.5,)
    inset.plot(fprs['bfgs'], tprs['bfgs'], color=col[3], lw=1.5,)
    inset.plot(fprs['svi'], tprs['svi'], color=col[0], lw=1.5, )
    inset.plot([0,1], [0,1], '-k', lw=0.5)
    plt.setp(inset, xlim=(0,.2), ylim=(0,.2), xticks=[0, 0.2], yticks=[0,0.2], aspect=1.0)
    inset.yaxis.tick_right()

    plt.legend(loc=4)
    ax.set_title("ROC Curve")

    plt.subplots_adjust(bottom=0.2, left=0.2)

    plt.savefig("figure3c.pdf")
    plt.show()
def plot_impulse_responses(models, results):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((6,6))
    col = harvard_colors()
    plt.grid()

    y_max = 0

    for i, (model, result) in enumerate(zip(models, results)):
        smpl = result.samples[-1]
        W = smpl.W_effective
        if "continuous" in str(smpl.__class__).lower():
            t, irs = smpl.impulses

            for k1 in range(K):
                for k2 in range(K):
                    plt.subplot(K,K,k1*K + k2 + 1)
                    plt.plot(t, W[k1,k2] * irs[:,k1,k2], color=col[i], lw=2)
        else:
            irs = smpl.impulses
            for k1 in range(K):
                for k2 in range(K):
                    plt.subplot(K,K,k1*K + k2 + 1)
                    plt.plot(W[k1,k2] * irs[:,k1,k2], color=col[i], lw=2)

        y_max = max(y_max, (W*irs).max())

    for k1 in range(K):
        for k2 in range(K):
            plt.subplot(K,K,k1*K+k2+1)
            plt.ylim(0,y_max*1.05)
    plt.show()
def plot_prc_curves(precs, recalls, fig_path="./"):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors
    col = harvard_colors()

    fig = create_figure((3, 3))
    ax = fig.add_subplot(111)
    if "xcorr" in recalls:
        ax.plot(recalls['xcorr'],
                precs['xcorr'],
                color=col[7],
                lw=1.5,
                label="xcorr")
    if "bfgs" in recalls:
        ax.plot(recalls['bfgs'],
                precs['bfgs'],
                color=col[3],
                lw=1.5,
                label="MAP")
    if "svi" in recalls:
        ax.plot(recalls['svi'],
                precs['svi'],
                color=col[0],
                lw=1.5,
                label="SVI")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")

    plt.legend(loc=1)
    ax.set_title("Network %d" % net)
    plt.subplots_adjust(bottom=0.25, left=0.25)

    plt.savefig(os.path.join(os.path.dirname(fig_path), "figure3d.pdf"))
    plt.show()
Example #7
def plot_results(alpha_a_0s, Ks_alpha_a_0,
                 gamma_a_0s, Ks_gamma_a_0,
                 figdir="."):

    # Plot the number of inferred states as a function of params
    fig = create_figure((5,1.5))

    ax = create_axis_at_location(fig, 0.6, 0.5, 1.7, .8, transparent=True)
    plt.figtext(0.05/5, 1.25/1.5, "A")
    ax.boxplot(Ks_alpha_a_0, positions=np.arange(1,1+len(alpha_a_0s)),
               boxprops=dict(color=allcolors[1]),
               whiskerprops=dict(color=allcolors[0]),
               flierprops=dict(color=allcolors[1]))
    ax.set_xticklabels(alpha_a_0s)
    plt.xlim(0.5,4.5)
    plt.ylim(40,90)
    # plt.yticks(np.arange(0,101,20))
    ax.set_xlabel("$a_{\\alpha_0}$")
    ax.set_ylabel("Number of States")

    ax = create_axis_at_location(fig, 3.1, 0.5, 1.7, .8, transparent=True)
    plt.figtext(2.55/5, 1.25/1.5, "B")
    ax.boxplot(Ks_gamma_a_0, positions=np.arange(1,1+len(gamma_a_0s)),
               boxprops=dict(color=allcolors[1]),
               whiskerprops=dict(color=allcolors[0]),
               flierprops=dict(color=allcolors[1]))
    ax.set_xticklabels(gamma_a_0s)
    plt.xlim(0.5,4.5)
    plt.ylim(40,90)
    # plt.yticks(np.arange(0,101,20))
    ax.set_xlabel("$a_{\\gamma}$")
    ax.set_ylabel("Number of States")

    plt.savefig(os.path.join(figdir, "figure7.pdf"))
    plt.savefig(os.path.join(figdir, "figure7.png"))
Example #8
def plot_pred_log_likelihood(
        timestamp_list,
        pred_ll_list,
        names,
        results_dir,
        outname="ctm_pred_ll_vs_time.pdf",
        title=None,
        smooth=True,
        burnin=3,
        normalizer=4632.  # Number of words in test dataset
):
    # Plot the log likelihood
    width = 5.25 / 3.  # Three NIPS panels
    fig = create_figure(figsize=(width, 2.25), transparent=True)
    fig.set_tight_layout(True)

    min_time = np.min([np.min(times[burnin + 2:]) for times in timestamp_list])
    max_time = np.max([np.max(times[burnin + 2:]) for times in timestamp_list])

    for i, (times, pred_ll,
            name) in enumerate(list(zip(timestamp_list, pred_ll_list, names))[::-1]):

        # Smooth the log likelihood
        smooth_pred_ll = logma(pred_ll)
        plt.plot(np.clip(times[burnin + 2:], 1e-3, np.inf),
                 smooth_pred_ll[burnin:] / normalizer,
                 lw=2,
                 color=colors[3 - i],
                 label=name)

        # plt.plot(np.clip(times[burnin:], 1e-3,np.inf),
        #          pred_ll[burnin:] / normalizer,
        #          lw=2, color=colors[3-i], label=None)

        N = len(pred_ll)
        avg_pll = logsumexp(pred_ll[N // 2:]) - np.log(N - N // 2)

        plt.plot([min_time, max_time],
                 avg_pll / normalizer * np.ones(2),
                 ls='--',
                 color=colors[3 - i])

    plt.xlabel('Time [s] (log scale) ', fontsize=9)
    plt.xscale("log")
    plt.xlim(min_time, max_time)

    plt.ylabel("Pred. Log Lkhd. [nats/word]", fontsize=9)
    plt.ylim(-2.6, -2.47)
    plt.yticks([-2.6, -2.55, -2.5])
    # plt.subplots_adjust(left=0.05)

    if title:
        plt.title(title)

    # plt.ylim(-9.,-8.4)
    plt.savefig(os.path.join(results_dir, outname))
    plt.show()
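The smoothing helper logma used above is not shown in this snippet. A minimal sketch, assuming it is a trailing moving average of the predictive log likelihood computed stably in the log domain (mirroring the logsumexp smoothing that appears inline in a later example), with an assumed window size of 10:

import numpy as np
from scipy.special import logsumexp

def logma(pred_ll, win=10):
    # Hypothetical helper: log of a trailing-window average of exp(pred_ll).
    pad = np.concatenate((pred_ll[0] * np.ones(win), pred_ll))
    return np.array([logsumexp(pad[j - win:j + 1]) - np.log(win + 1)
                     for j in range(win, pad.size)])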
Example #9
def plot_mean_locations():
    ### Plot the sampled locations for a few neurons
    _, _, _, _, Ls = result
    Ls_rot = []
    for L in Ls:
        R = compute_optimal_rotation(L, pfs)
        Ls_rot.append(L.dot(R))
    Ls_rot = np.array(Ls_rot)

    Ls_mean = np.mean(Ls_rot, 0)

    fig = create_figure(figsize=(1.4,2.5))
    plt.subplot(211, aspect='equal')

    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(pfs[k,0], pfs[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("True Place Fields")
    plt.xlim(-45, 45)
    # plt.xlabel("$x$")
    plt.xticks([-40, -20, 0, 20, 40], [])
    plt.ylim(-45, 45)
    # plt.ylabel("$y$")
    plt.yticks([-40, -20, 0, 20, 40], [])

    # Now plot the inferred locations


    plt.subplot(212, aspect='equal')

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Ls_mean[k,0], Ls_mean[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("Mean Locations")
    plt.xlim(-30, 30)
    plt.xticks([])
    plt.ylim(-30, 30)
    plt.yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "hipp_mean_locations.pdf"))
    plt.show()
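gradient_cmap is another plotting helper assumed to be available here; it builds the cyclic color wheel used to color neurons by place-field angle. A minimal stand-in using matplotlib, offered only as a sketch of the idea:

from matplotlib.colors import LinearSegmentedColormap

def gradient_cmap(color_list, name="wheel"):
    # Hypothetical stand-in: interpolate linearly between the listed colors;
    # repeating the first color at the end makes the map cyclic.
    return LinearSegmentedColormap.from_list(name, color_list)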
Example #10
def plot_pred_log_likelihood(timestamp_list,
                             pred_ll_list, names,
                             results_dir,
                             outname="ctm_pred_ll_vs_time.pdf",
                             title=None,
                             smooth=True, burnin=3,
                             normalizer=4632.       # Number of words in test dataset
                            ):
    # Plot the log likelihood
    width = 5.25/3.  # Three NIPS panels
    fig = create_figure(figsize=(width, 2.25), transparent=True)
    fig.set_tight_layout(True)

    min_time = np.min([np.min(times[burnin+2:]) for times in timestamp_list])
    max_time = np.max([np.max(times[burnin+2:]) for times in timestamp_list])

    for i,(times, pred_ll, name) in enumerate(list(zip(timestamp_list, pred_ll_list, names))[::-1]):

        # Smooth the log likelihood
        smooth_pred_ll = logma(pred_ll)
        plt.plot(np.clip(times[burnin+2:], 1e-3,np.inf),
                 smooth_pred_ll[burnin:] / normalizer,
                 lw=2, color=colors[3-i], label=name)

        # plt.plot(np.clip(times[burnin:], 1e-3,np.inf),
        #          pred_ll[burnin:] / normalizer,
        #          lw=2, color=colors[3-i], label=None)

        N = len(pred_ll)
        avg_pll = logsumexp(pred_ll[N//2:]) - np.log(N-N//2)

        plt.plot([min_time, max_time], avg_pll / normalizer * np.ones(2), ls='--', color=colors[3-i])


    plt.xlabel('Time [s] (log scale) ', fontsize=9)
    plt.xscale("log")
    plt.xlim(min_time, max_time)

    plt.ylabel("Pred. Log Lkhd. [nats/word]", fontsize=9)
    plt.ylim(-2.6, -2.47)
    plt.yticks([-2.6, -2.55, -2.5])
    # plt.subplots_adjust(left=0.05)

    if title:
        plt.title(title)

    # plt.ylim(-9.,-8.4)
    plt.savefig(os.path.join(results_dir, outname))
    plt.show()
Example #11
def make_figure_a(S, F, C):
    """
    Plot fluorescence traces, filtered fluorescence, and spike times
    for two neurons
    """
    col = harvard_colors()
    dt = 0.02
    T_start = 0
    T_stop = 1 * 50 * 60
    t = dt * np.arange(T_start, T_stop)

    ks = [0, 1]
    nk = len(ks)
    fig = create_figure((3, 3))
    for ind, k in enumerate(ks):
        ax = fig.add_subplot(nk, 1, ind + 1)
        ax.plot(t, F[T_start:T_stop, k], color=col[1],
                label="$F$")  # Plot the raw flourescence in blue
        ax.plot(t,
                C[T_start:T_stop, k],
                color=col[0],
                lw=1.5,
                label="$\widehat{F}$")  # Plot the filtered flourescence in red
        spks = np.where(S[T_start:T_stop, k])[0]
        ax.plot(t[spks], C[spks, k], 'ko',
                label="S")  # Plot the spike times in black

        # Make a legend
        if ind == 0:
            # Put a legend above
            plt.legend(bbox_to_anchor=(0., 1.02, 1., .102),
                       loc=3,
                       ncol=3,
                       mode="expand",
                       borderaxespad=0.,
                       prop={'size': 9})

        # Add labels
        ax.set_ylabel("$F_%d(t)$" % (k + 1))
        if ind == nk - 1:
            ax.set_xlabel("Time $t$ [sec]")

        # Format the ticks
        ax.set_ylim([-0.1, 1.0])
        plt.locator_params(nbins=5, axis="y")

    plt.subplots_adjust(left=0.2, bottom=0.2)
    fig.savefig("figure3a.pdf")
    plt.show()
def plot_pred_ll_vs_time(models,
                         results,
                         burnin=0,
                         std_ll=np.nan,
                         true_ll=np.nan):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((4, 3))
    ax = fig.add_subplot(111)
    col = harvard_colors()
    plt.grid()

    t_start = 0
    t_stop = 0

    for i, (model, result) in enumerate(zip(models, results)):
        plt.plot(result.timestamps[burnin:],
                 result.test_lls[burnin:],
                 lw=2,
                 color=col[i],
                 label=model)

        # Update time limits
        t_start = min(t_start, result.timestamps[burnin:].min())
        t_stop = max(t_stop, result.timestamps[burnin:].max())

    # plt.legend(loc="outside right")

    # Plot the standard Hawkes test ll
    plt.plot([t_start, t_stop],
             std_ll * np.ones(2),
             lw=2,
             color=col[len(models)],
             label="Std.")

    # Plot the true ll
    plt.plot([t_start, t_stop],
             true_ll * np.ones(2),
             '--k',
             lw=2,
             label="True")

    ax.set_xlim(t_start, t_stop)
    ax.set_xlabel("time [sec]")
    ax.set_ylabel("Pred. Log Lkhd.")
    plt.show()
def plot_prc_curves(precs, recalls):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors
    col = harvard_colors()

    fig = create_figure((3,3))
    ax = fig.add_subplot(111)
    ax.plot(recalls['xcorr'], precs['xcorr'], color=col[7], lw=1.5, label="xcorr")
    ax.plot(recalls['bfgs'], precs['bfgs'], color=col[3], lw=1.5, label="MAP")
    ax.plot(recalls['svi'], precs['svi'], color=col[0], lw=1.5, label="SVI")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")

    plt.legend(loc=1)
    ax.set_title("Precision-Recall Curve")
    plt.subplots_adjust(bottom=0.25, left=0.25)

    plt.savefig("figure3d.pdf")
    plt.show()
Example #14
def make_figure_a(S, F, C):
    """
    Plot fluorescence traces, filtered fluorescence, and spike times
    for two neurons
    """
    col = harvard_colors()
    dt = 0.02
    T_start = 0
    T_stop = 1 * 50 * 60
    t = dt * np.arange(T_start, T_stop)

    ks = [0,1]
    nk = len(ks)
    fig = create_figure((3,3))
    for ind,k in enumerate(ks):
        ax = fig.add_subplot(nk,1,ind+1)
        ax.plot(t, F[T_start:T_stop, k], color=col[1], label="$F$")    # Plot the raw fluorescence in blue
        ax.plot(t, C[T_start:T_stop, k], color=col[0], lw=1.5, label=r"$\widehat{F}$")    # Plot the filtered fluorescence in red
        spks  = np.where(S[T_start:T_stop, k])[0]
        ax.plot(t[spks], C[spks,k], 'ko', label="S")            # Plot the spike times in black

        # Make a legend
        if ind == 0:
            # Put a legend above
            plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
                       ncol=3, mode="expand", borderaxespad=0.,
                       prop={'size':9})

        # Add labels
        ax.set_ylabel("$F_%d(t)$" % (k+1))
        if ind == nk-1:
            ax.set_xlabel("Time $t$ [sec]")

        # Format the ticks
        ax.set_ylim([-0.1,1.0])
        plt.locator_params(nbins=5, axis="y")


    plt.subplots_adjust(left=0.2, bottom=0.2)
    fig.savefig("figure3a.pdf")
    plt.show()
def plot_roc_curves(fprs, tprs, fig_path="./"):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    col = harvard_colors()

    fig = create_figure((3, 3))
    ax = fig.add_subplot(111)

    # Plot the ROC curves
    if "xcorr" in fprs:
        ax.plot(fprs["xcorr"], tprs["xcorr"], color=col[7], lw=1.5, label="xcorr")
    if "bfgs" in fprs:
        ax.plot(fprs["bfgs"], tprs["bfgs"], color=col[3], lw=1.5, label="MAP")
    if "svi" in fprs:
        ax.plot(fprs["svi"], tprs["svi"], color=col[0], lw=1.5, label="SVI")

    # Plot the diagonal
    ax.plot([0, 1], [0, 1], "-k", lw=0.5)
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")

    # this is another inset axes over the main axes
    # parchment = np.array([243,243,241])/255.
    # inset = plt.axes([0.55, 0.275, .265, .265], axisbg=parchment)
    # inset.plot(fprs['xcorr'], tprs['xcorr'], color=col[7], lw=1.5,)
    # inset.plot(fprs['bfgs'], tprs['bfgs'], color=col[3], lw=1.5,)
    # inset.plot(fprs['svi'], tprs['svi'], color=col[0], lw=1.5, )
    # inset.plot([0,1], [0,1], '-k', lw=0.5)
    # plt.setp(inset, xlim=(0,.2), ylim=(0,.2), xticks=[0, 0.2], yticks=[0,0.2], aspect=1.0)
    # inset.yaxis.tick_right()

    plt.legend(loc=4)
    ax.set_title("ROC Curve")

    plt.subplots_adjust(bottom=0.2, left=0.2)

    plt.savefig(os.path.join(os.path.dirname(fig_path), "figure3c.pdf"))
    plt.show()
def plot_prc_curves(precs, recalls, fig_path="./"):
    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    col = harvard_colors()

    fig = create_figure((3, 3))
    ax = fig.add_subplot(111)
    if "xcorr" in recalls:
        ax.plot(recalls["xcorr"], precs["xcorr"], color=col[7], lw=1.5, label="xcorr")
    if "bfgs" in recalls:
        ax.plot(recalls["bfgs"], precs["bfgs"], color=col[3], lw=1.5, label="MAP")
    if "svi" in recalls:
        ax.plot(recalls["svi"], precs["svi"], color=col[0], lw=1.5, label="SVI")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")

    plt.legend(loc=1)
    ax.set_title("Network %d" % net)
    plt.subplots_adjust(bottom=0.25, left=0.25)

    plt.savefig(os.path.join(os.path.dirname(fig_path), "figure3d.pdf"))
    plt.show()
                ct_times.append(fit_continuous_time_model_gibbs(S_ct, C_ct, N_samples))

        with open(res_file, "w") as f:
            cPickle.dump((events_per_bin, dt_times, ct_times), f, protocol=-1)

    events_per_bin = np.array(events_per_bin)
    dt_times = np.array(dt_times)
    ct_times = np.array(ct_times)
    perm = np.argsort(events_per_bin)

    events_per_bin = events_per_bin[perm]
    dt_times = dt_times[perm]
    ct_times = ct_times[perm]

    # Plot the results
    fig = create_figure(figsize=(2.5,2.5))
    fig.set_tight_layout(True)
    ax = fig.add_subplot(111)

    # Plot DT data
    ax.plot(events_per_bin, dt_times, 'o', linestyle="none",
            markerfacecolor=colors[2], markeredgecolor=colors[2], markersize=4,
            label="Discrete")

    # Plot linear fit
    p_dt = np.poly1d(np.polyfit(events_per_bin, dt_times, deg=1))
    dt_pred = p_dt(events_per_bin)
    ax.plot(events_per_bin, dt_pred, ':', lw=2, color=colors[2])

    # Plot CT data
    ax.plot(events_per_bin, ct_times, 's', linestyle="none",
Example #18
def draw_mixture_figure(Ns, Ss, z, lmbda, filename="figure1.png", saveargs=dict(dpi=300)):
    fig = create_figure((5.5, 2.7))
    ax = create_axis_at_location(fig, .75, .5, 4., 1.375)
    ymax = 105
    # Plot the rates
    for i in range(n):
        ax.add_patch(Rectangle([i*D,0], D, lmbda[z[i]],
                               color=colors[z[i]], ec="none", alpha=0.5))
        ax.plot([i*D, (i+1)*D], lmbda[z[i]] * np.ones(2), '-k', lw=2)

        if i < n-1:
            ax.plot([(i+1)*D, (i+1)*D], [lmbda[z[i]], lmbda[z[i+1]]], '-k', lw=2)
            
        # Plot boundaries
        ax.plot([(i+1)*D, (i+1)*D], [0, ymax], ':k', lw=1)
        
        
    # Plot x axis
    plt.plot([0,T], [0,0], '-k', lw=2)

    # Plot spike times
    for s in Ss:
        plt.plot([s,s], [0,60], '-ko', markerfacecolor='k', markersize=5)

    plt.xlabel("time [ms]")
    plt.ylabel("firing rate [Hz]")
    plt.xlim(0,T)
    plt.ylim(-5,ymax)

    ## Now plot the spike count above
    ax = create_axis_at_location(fig, .75, 2., 4., .25)
    for i in range(n):
        # Plot boundaries
        ax.plot([(i+1)*D, (i+1)*D], [0, 10], '-k', lw=1)
        ax.text(i*D + D/3.5, 3, "%d" % Ns[i], fontdict={"size":9})
    ax.set_xlim(0,T)
    ax.set_ylim(0,10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.yaxis.labelpad = 30
    ax.set_ylabel("${s_t}$", rotation=0,  verticalalignment='center')

    ## Now plot the latent state above that above
    ax = create_axis_at_location(fig, .75, 2.375, 4., .25)
    for i in range(n):
        # Plot boundaries
        ax.add_patch(Rectangle([i*D,0], D, 10,
                            color=colors[z[i]], ec="none", alpha=0.5))

        ax.plot([(i+1)*D, (i+1)*D], [0, 10], '-k', lw=1)
        ax.text(i*D + D/3.5, 3, "u" if z[i]==0 else "d", fontdict={"size":9})
    ax.set_xlim(0,T)
    ax.set_ylim(0,10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.yaxis.labelpad = 30
    ax.set_ylabel("${z_t}$", rotation=0,  verticalalignment='center')

    
    #fig.savefig(filename + ".pdf")
    fig.savefig(filename, **saveargs)

    plt.close(fig)
Example #19
def plot_pred_log_likelihood(results, names, results_dir,
                             outname="pred_ll_vs_time.pdf",
                             smooth=True, burnin=0):

    # Get the baseline pred ll
    baseline = 0
    normalizer = 0
    for Xtr, Xte in zip(Xtrain, Xtest):
        pi_emp = Xtr.sum(0) / float(Xtr.sum())
        pi_emp = np.clip(pi_emp, 1e-8, np.inf)
        pi_emp /= pi_emp.sum()
        baseline += Multinomial(weights=pi_emp, K=Xtr.shape[1]).log_likelihood(Xte).sum()
        normalizer += Xte.sum()

    # Plot the log likelihood
    # fig = plt.figure(figsize=(2.25,2.5))
    fig = create_figure(figsize=(2.25, 2.5), transparent=True)
    fig.set_tight_layout(True)
    for i,(result, name) in enumerate(zip(results, names)):
        if result.pred_lls.ndim == 2:
            pred_ll = result.pred_lls[:,0]
        else:
            pred_ll = result.pred_lls


        # Smooth the log likelihood
        if smooth:
            win = 10
            pad_pred_ll = np.concatenate((pred_ll[0] * np.ones(win), pred_ll))
            smooth_pred_ll = np.array([logsumexp(pad_pred_ll[j-win:j+1])-np.log(win)
                                       for j in range(win, pad_pred_ll.size)])

            plt.plot(np.clip(result.timestamps[burnin:], 1e-3,np.inf),
                     (smooth_pred_ll[burnin:] - baseline) / normalizer,
                     lw=2, color=colors[i], label=name)

        else:
            plt.plot(np.clip(result.timestamps[burnin:], 1e-3,np.inf),
                     result.pred_lls[burnin:],
                     lw=2, color=colors[i], label=name)


        # if result.pred_lls.ndim == 2:
        #     plt.errorbar(np.clip(result.timestamps, 1e-3,np.inf),
        #                  result.pred_lls[:,0],
        #                  yerr=result.pred_lls[:,1],
        #                  lw=2, color=colors[i], label=name)
        # else:
        #     plt.plot(np.clip(result.timestamps, 1e-3,np.inf), result.pred_lls, lw=2, color=colors[i], label=name)


    xmin = 10**0
    xmax = 10**4.2
    plt.plot([xmin, xmax], np.zeros(2), ':k', lw=0.5)

    # plt.plot(gauss_lds_lls, lw=2, color=colors[2], label="Gaussian LDS")
    # plt.legend(loc="lower right")
    plt.xlabel('Time [sec] (log scale)')
    plt.xlim(xmin, xmax)
    plt.xscale("log")
    plt.ylabel("Pred. Log Lkhd. (nats/word)")
    plt.ylim(-4, 1)
    plt.title("Alice")
    plt.savefig(os.path.join(results_dir, outname))
Example #20
def plot_pred_ll_vs_D(all_results, Ds, Xtrain, Xtest,
                      results_dir, models=None):
    # Create a big matrix of shape (len(Ds) x 5 x T) for the pred lls
    N = len(Ds)                             # Number of dimensions tests
    M = len(all_results[0])                 # Number of models tested
    T = len(all_results[0][0].pred_lls)     # Number of MCMC iters
    pred_lls = np.zeros((N,M,T))
    for n in range(N):
        for m in range(M):
            if all_results[n][m].pred_lls.ndim == 2:
                pred_lls[n,m] = all_results[n][m].pred_lls[:,0]
            else:
                pred_lls[n,m] = all_results[n][m].pred_lls

    # Compute the mean and standard deviation on burned in samples
    burnin = T // 2
    pred_ll_mean = logsumexp(pred_lls[:,:,burnin:], axis=-1) - np.log(T-burnin)

    # Use bootstrap to compute error bars
    pred_ll_std = np.zeros_like(pred_ll_mean)
    for n in range(N):
        for m in range(M):
            samples = np.random.choice(pred_lls[n,m,burnin:], size=(100, (T-burnin)), replace=True)
            pll_samples = logsumexp(samples, axis=1) - np.log(T-burnin)
            pred_ll_std[n,m] = pll_samples.std()

    # Get the baseline pred ll
    baseline = 0
    normalizer = 0
    for Xtr, Xte in zip(Xtrain, Xtest):
        pi_emp = Xtr.sum(0) / float(Xtr.sum())
        pi_emp = np.clip(pi_emp, 1e-8, np.inf)
        pi_emp /= pi_emp.sum()
        baseline += Multinomial(weights=pi_emp, K=Xtr.shape[1]).log_likelihood(Xte).sum()
        normalizer += Xte.sum()

    # Make a bar chart with errorbars
    from hips.plotting.layout import create_figure
    fig = create_figure(figsize=(1.25,2.5), transparent=True)
    fig.set_tight_layout(True)
    ax = fig.add_subplot(111)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

    width = np.min(np.diff(Ds)) / (M+1.0) if len(Ds)>1 else 1.
    for m in range(M):
        ax.bar(Ds+m*width,
               (pred_ll_mean[:,m] - baseline) / normalizer,
               yerr=pred_ll_std[:,m] / normalizer,
               width=0.9*width, color=colors[m], ecolor='k')
        #
        # ax.text(Ds+(m-1)*width, yloc, rankStr, horizontalalignment=align,
        #     verticalalignment='center', color=clr, weight='bold')

    # Plot the zero line
    ax.plot([Ds.min()-width, Ds.max()+(M+1)*width], np.zeros(2), '-k')

    # Set the tick labels
    ax.set_xlim(Ds.min()-width, Ds.max()+(M+1)*width)
    # ax.set_xticks(Ds + (M*width)/2.)
    # ax.set_xticklabels(Ds)
    # ax.set_xticks(Ds + width * np.arange(M) + width/2. )
    # ax.set_xticklabels(models, rotation=45)
    ax.set_xticks([])

    # ax.set_xlabel("D")
    ax.set_ylabel("Pred. Log Lkhd. (nats/word)")
    ax.set_title("AP News")

    plt.savefig(os.path.join(results_dir, "pred_ll_vs_D.pdf"))
Example #21
def plot_pred_ll_vs_time(models, results, burnin=0,
                         homog_ll=np.nan,
                         std_ll=np.nan,
                         nlin_ll=np.nan,
                         true_ll=np.nan,
                         baseline=0,
                         normalizer=1,
                         output_dir=".",
                         xlim=None, ylim=None,
                         title=None):
    from hips.plotting.layout import create_figure, create_axis_at_location
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((3,2))
    ax = create_axis_at_location(fig, 0.7, 0.4, 2.25, 1.35)
    col = harvard_colors()
    ax.grid()

    t_start = 0
    t_stop = 0

    def standardize(x):
        return (x-baseline)/normalizer

    for i, (model, result) in enumerate(zip(models, results)):
        ax.plot(result.timestamps[burnin:],
                 standardize(result.test_lls[burnin:]),
                 lw=2, color=col[i], label=model)

        # Update time limits
        t_start = min(t_start, result.timestamps[burnin:].min())
        t_stop = max(t_stop, result.timestamps[burnin:].max())

    if xlim is not None:
        t_start = xlim[0]
        t_stop = xlim[1]

    # plt.legend(loc="outside right")

    # Plot baselines
    ax.plot([t_start, t_stop], standardize(homog_ll)*np.ones(2),
             lw=2, color='k', label="Homog")

    # Plot the standard Hawkes test ll
    ax.plot([t_start, t_stop], standardize(std_ll)*np.ones(2),
             lw=2, color=col[len(models)], label="Std.")

    # Plot the Nonlinear Hawkes test ll
    ax.plot([t_start, t_stop], standardize(nlin_ll)*np.ones(2),
             lw=2, ls='--', color=col[len(models)+1], label="Nonlinear")

    # Plot the true ll
    ax.plot([t_start, t_stop], standardize(true_ll)*np.ones(2),
             '--k',  lw=2,label="True")


    ax.set_xlabel("time [sec]")
    ax.set_ylabel("Pred. LL. [nats/event]")

    ax.set_xscale("log")
    if xlim is not None:
        ax.set_xlim(xlim)
    else:
        ax.set_xlim(t_start, t_stop)
    if ylim is not None:
        ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)

    output_file = os.path.join(output_dir, "pred_ll_vs_time.pdf")
    fig.savefig(output_file)
    plt.show()
Example #22
def plot_qualitative_results(X, key, psi_lds, z_lds):
    start = 50
    stop = 70

    # Get the corresponding protein labels
    import operator
    id_to_char = dict([(v, k) for k, v in key.items()])
    sorted_chars = [
        idc[1].upper()
        for idc in sorted(id_to_char.items(), key=operator.itemgetter(0))
    ]
    X_inds = np.where(X)[1]
    prot_str = [id_to_char[v].upper() for v in X_inds]

    from pgmult.utils import psi_to_pi
    pi_lds = psi_to_pi(psi_lds)

    # Plot the true and inferred states
    fig = create_figure(figsize=(3., 3.1))

    # Plot the string of protein labels
    # ax1 = create_axis_at_location(fig, 0.5, 2.5, 2.25, 0.25)
    # for n in xrange(start, stop):
    #     ax1.text(n, 0.5, prot_str[n].upper())
    # # ax1.get_xaxis().set_visible(False)
    # ax1.axis("off")
    # ax1.set_xlim([start-1,stop])
    # ax1.set_title("Protein Sequence")

    # ax2 = create_axis_at_location(fig, 0.5, 2.25, 2.25, 0.5)
    # ax2 = fig.add_subplot(311)
    # plt.imshow(X[start:stop,:].T, interpolation="none", vmin=0, vmax=1, cmap="Blues", aspect="auto")
    # ax2.set_title("One-hot Encoding")

    # ax3 = create_axis_at_location(fig, 0.5, 1.25, 2.25, 0.5)
    ax3 = fig.add_subplot(211)
    im3 = plt.imshow(np.kron(pi_lds[start:stop, :].T, np.ones((50, 50))),
                     interpolation="none",
                     vmin=0,
                     vmax=1,
                     cmap="Blues",
                     aspect="auto",
                     extent=(0, stop - start, K + 1, 1))
    # Circle true symbol
    from matplotlib.patches import Rectangle
    for n in range(start, stop):
        ax3.add_patch(
            Rectangle((n - start, X_inds[n] + 1),
                      1,
                      1,
                      facecolor="none",
                      edgecolor="k"))

    # Print protein labels on y axis
    # ax3.set_yticks(np.arange(K))
    # ax3.set_yticklabels(sorted_chars)

    # Print protein sequence as xticks
    ax3.set_xticks(0.5 + np.arange(0, stop - start))
    ax3.set_xticklabels(prot_str[start:stop])
    ax3.xaxis.tick_top()
    ax3.xaxis.set_tick_params(width=0)

    ax3.set_yticks(0.5 + np.arange(1, K + 1, 5))
    ax3.set_yticklabels(np.arange(1, K + 1, 5))
    ax3.set_ylabel("$k$")

    ax3.set_title("Inferred Protein Probability", y=1.25)

    # Add a colorbar
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes("right", size="3%", pad=0.05)
    cbar = plt.colorbar(im3, cax=cax, ticks=[0, 0.25, 0.5, 0.75, 1])
    cbar.set_label("Probability", labelpad=10)

    # ax4 = create_axis_at_location(fig, 0.5, 0.5, 2.25, 0.55)
    lim = np.amax(abs(z_lds[start:stop]))
    ax4 = fig.add_subplot(212)
    im4 = plt.imshow(np.kron(z_lds[start:stop, :].T, np.ones((50, 50))),
                     interpolation="none",
                     vmin=-lim,
                     vmax=lim,
                     cmap="RdBu",
                     extent=(0, stop - start, D + 1, 1))
    ax4.set_xlabel("Position $t$")
    ax4.set_yticks(0.5 + np.arange(1, D + 1))
    ax4.set_yticklabels(np.arange(1, D + 1))
    ax4.set_ylabel("$d$")

    ax4.set_title("Latent state sequence")

    # Add a colorbar
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(ax4)
    cax = divider.append_axes("right", size="3%", pad=0.05)
    # cbar_ticks = np.round(np.linspace(-lim, lim, 3))
    cbar_ticks = [-4, 0, 4]
    cbar = plt.colorbar(im4, cax=cax, ticks=cbar_ticks)
    # cbar.set_label("Probability", labelpad=10)

    # plt.subplots_adjust(top=0.9)
    # plt.tight_layout(pad=0.2)
    plt.savefig("dna_lds_1.png")
    plt.savefig("dna_lds_1.pdf")
    plt.show()
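psi_to_pi is imported from pgmult.utils above; its exact implementation is not reproduced here. As a rough, hedged sketch only, the stick-breaking logistic construction that this class of models is based on maps a length K-1 real vector psi to a length K probability vector, for example:

import numpy as np

def psi_to_pi_sketch(psi):
    # Hedged sketch (not pgmult's actual code): stick-breaking logistic map
    # from psi with shape (..., K-1) to probabilities with shape (..., K).
    sigma = 1.0 / (1.0 + np.exp(-psi))
    sticks = np.cumprod(1.0 - sigma, axis=-1)
    head = sigma * np.concatenate(
        [np.ones(psi.shape[:-1] + (1,)), sticks[..., :-1]], axis=-1)
    return np.concatenate([head, sticks[..., -1:]], axis=-1)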
Example #23
File: dna_lds.py Project: fivejjs/pgmult
def plot_qualitative_results(X, key, psi_lds, z_lds):
    start = 50
    stop = 70

    # Get the corresponding protein labels
    import operator
    id_to_char = dict([(v,k) for k,v in key.items()])
    sorted_chars = [idc[1].upper() for idc in sorted(id_to_char.items(), key=operator.itemgetter(0))]
    X_inds = np.where(X)[1]
    prot_str = [id_to_char[v].upper() for v in X_inds]


    from pgmult.utils import psi_to_pi
    pi_lds = psi_to_pi(psi_lds)

    # Plot the true and inferred states
    fig = create_figure(figsize=(3., 3.1))

    # Plot the string of protein labels
    # ax1 = create_axis_at_location(fig, 0.5, 2.5, 2.25, 0.25)
    # for n in xrange(start, stop):
    #     ax1.text(n, 0.5, prot_str[n].upper())
    # # ax1.get_xaxis().set_visible(False)
    # ax1.axis("off")
    # ax1.set_xlim([start-1,stop])
    # ax1.set_title("Protein Sequence")

    # ax2 = create_axis_at_location(fig, 0.5, 2.25, 2.25, 0.5)
    # ax2 = fig.add_subplot(311)
    # plt.imshow(X[start:stop,:].T, interpolation="none", vmin=0, vmax=1, cmap="Blues", aspect="auto")
    # ax2.set_title("One-hot Encoding")

    # ax3 = create_axis_at_location(fig, 0.5, 1.25, 2.25, 0.5)
    ax3 = fig.add_subplot(211)
    im3 = plt.imshow(np.kron(pi_lds[start:stop,:].T, np.ones((50,50))),
                             interpolation="none", vmin=0, vmax=1, cmap="Blues", aspect="auto",
               extent=(0,stop-start,K+1,1))
    # Circle true symbol
    from matplotlib.patches import Rectangle
    for n in range(start, stop):
        ax3.add_patch(Rectangle((n-start, X_inds[n]+1), 1, 1, facecolor="none", edgecolor="k"))

    # Print protein labels on y axis
    # ax3.set_yticks(np.arange(K))
    # ax3.set_yticklabels(sorted_chars)

    # Print protein sequence as xticks
    ax3.set_xticks(0.5+np.arange(0, stop-start))
    ax3.set_xticklabels(prot_str[start:stop])
    ax3.xaxis.tick_top()
    ax3.xaxis.set_tick_params(width=0)

    ax3.set_yticks(0.5+np.arange(1,K+1, 5))
    ax3.set_yticklabels(np.arange(1,K+1, 5))
    ax3.set_ylabel("$k$")

    ax3.set_title("Inferred Protein Probability", y=1.25)

    # Add a colorbar
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(ax3)
    cax = divider.append_axes("right", size="3%", pad=0.05)
    cbar = plt.colorbar(im3, cax=cax, ticks=[0, 0.25, 0.5, 0.75, 1])
    cbar.set_label("Probability", labelpad=10)


    # ax4 = create_axis_at_location(fig, 0.5, 0.5, 2.25, 0.55)
    lim = np.amax(abs(z_lds[start:stop]))
    ax4 = fig.add_subplot(212)
    im4 = plt.imshow(np.kron(z_lds[start:stop, :].T, np.ones((50,50))),
                     interpolation="none", vmin=-lim, vmax=lim, cmap="RdBu",
                     extent=(0,stop-start, D+1,1))
    ax4.set_xlabel("Position $t$")
    ax4.set_yticks(0.5+np.arange(1,D+1))
    ax4.set_yticklabels(np.arange(1,D+1))
    ax4.set_ylabel("$d$")

    ax4.set_title("Latent state sequence")

    # Add a colorbar
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(ax4)
    cax = divider.append_axes("right", size="3%", pad=0.05)
    # cbar_ticks = np.round(np.linspace(-lim, lim, 3))
    cbar_ticks = [-4, 0, 4]
    cbar = plt.colorbar(im4, cax=cax,  ticks=cbar_ticks)
    # cbar.set_label("Probability", labelpad=10)


    # plt.subplots_adjust(top=0.9)
    # plt.tight_layout(pad=0.2)
    plt.savefig("dna_lds_1.png")
    plt.savefig("dna_lds_1.pdf")
    plt.show()
Example #24
def plot_mean_and_pca_locations(result):
    ### Plot the sampled locations for a few neurons
    _, _, _, _, Ls = result
    Ls_rot = []
    for L in Ls:
        R = compute_optimal_rotation(L, pfs, scale=False)
        Ls_rot.append(L.dot(R))
    Ls_rot = np.array(Ls_rot)
    Ls_mean = np.mean(Ls_rot, 0)

    # Bin the data
    from pyhawkes.utils.utils import convert_continuous_to_discrete
    S_dt = convert_continuous_to_discrete(S, C, 0.25, 0, T)

    # Smooth the data to get a firing rate
    from scipy.ndimage import gaussian_filter1d
    S_smooth = np.array([gaussian_filter1d(s, 4) for s in S_dt.T]).T

    # Run PCA to get an embedding
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    pca.fit(S_smooth)
    Z = pca.components_.T

    # Rotate
    R = compute_optimal_rotation(Z, pfs, scale=False)
    Z = Z.dot(R)

    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])
    fig = create_figure(figsize=(1.4,2.9))
    # plt.subplot(211, aspect='equal')
    ax = create_axis_at_location(fig, .3, 1.7, 1, 1)


    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Ls_mean[k,0], Ls_mean[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("Mean Locations")
    plt.xlim(-3, 3)
    plt.xticks([-2, 0, 2])
    plt.ylim(-3, 3)
    plt.yticks([-2, 0, 2])


    # plt.subplot(212, aspect='equal')
    ax = create_axis_at_location(fig, .3, .2, 1, 1)

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Z[k,0], Z[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("PCA Locations")
    plt.xlim(-.5, .5)
    # plt.xlabel("$x$")
    plt.xticks([-.4, 0, .4])
    plt.ylim(-.5, .5)
    # plt.ylabel("$y$")
    plt.yticks([-.4, 0, .4])

    # plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "hipp_mean_pca_locations.pdf"))
    plt.show()
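convert_continuous_to_discrete is imported from pyhawkes above and is not reproduced here. A hedged stand-in, assuming it bins event times S with process labels C into a (num_bins x num_processes) count matrix over [T_min, T_max) with bin width dt:

import numpy as np

def convert_continuous_to_discrete_sketch(S, C, dt, T_min, T_max):
    # Hypothetical stand-in, not the pyhawkes implementation.
    bins = np.arange(T_min, T_max + dt, dt)
    K = int(C.max()) + 1
    S_dt = np.zeros((len(bins) - 1, K))
    for k in range(K):
        S_dt[:, k], _ = np.histogram(S[C == k], bins)
    return S_dt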
Example #25
def draw_mixture_figure(Ns,
                        Ss,
                        z,
                        lmbda,
                        filename="figure1.png",
                        saveargs=dict(dpi=300)):
    fig = create_figure((5.5, 2.7))
    ax = create_axis_at_location(fig, .75, .5, 4., 1.375)
    ymax = 105
    # Plot the rates
    for i in range(n):
        ax.add_patch(
            Rectangle([i * D, 0],
                      D,
                      lmbda[z[i]],
                      color=colors[z[i]],
                      ec="none",
                      alpha=0.5))
        ax.plot([i * D, (i + 1) * D], lmbda[z[i]] * np.ones(2), '-k', lw=2)

        if i < n - 1:
            ax.plot([(i + 1) * D, (i + 1) * D], [lmbda[z[i]], lmbda[z[i + 1]]],
                    '-k',
                    lw=2)

        # Plot boundaries
        ax.plot([(i + 1) * D, (i + 1) * D], [0, ymax], ':k', lw=1)

    # Plot x axis
    plt.plot([0, T], [0, 0], '-k', lw=2)

    # Plot spike times
    for s in Ss:
        plt.plot([s, s], [0, 60], '-ko', markerfacecolor='k', markersize=5)

    plt.xlabel("time [ms]")
    plt.ylabel("firing rate [Hz]")
    plt.xlim(0, T)
    plt.ylim(-5, ymax)

    ## Now plot the spike count above
    ax = create_axis_at_location(fig, .75, 2., 4., .25)
    for i in range(n):
        # Plot boundaries
        ax.plot([(i + 1) * D, (i + 1) * D], [0, 10], '-k', lw=1)
        ax.text(i * D + D / 3.5, 3, "%d" % Ns[i], fontdict={"size": 9})
    ax.set_xlim(0, T)
    ax.set_ylim(0, 10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.yaxis.labelpad = 30
    ax.set_ylabel("${s_t}$", rotation=0, verticalalignment='center')

    ## Now plot the latent state above that above
    ax = create_axis_at_location(fig, .75, 2.375, 4., .25)
    for i in range(n):
        # Plot boundaries
        ax.add_patch(
            Rectangle([i * D, 0],
                      D,
                      10,
                      color=colors[z[i]],
                      ec="none",
                      alpha=0.5))

        ax.plot([(i + 1) * D, (i + 1) * D], [0, 10], '-k', lw=1)
        ax.text(i * D + D / 3.5,
                3,
                "u" if z[i] == 0 else "d",
                fontdict={"size": 9})
    ax.set_xlim(0, T)
    ax.set_ylim(0, 10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.yaxis.labelpad = 30
    ax.set_ylabel("${z_t}$", rotation=0, verticalalignment='center')

    #fig.savefig(filename + ".pdf")
    fig.savefig(filename, **saveargs)

    plt.close(fig)
Example #26
def plot_pred_ll_vs_time(plls, timestamps, Z=1.0, T_train=None, nbins=4):

    # import seaborn as sns
    # sns.set(style="whitegrid")

    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((4,3))
    ax = fig.add_subplot(111)
    col = harvard_colors()
    plt.grid()

    # Compute the max and min time in seconds
    print "Homog PLL: ", plls['homog']
    # DEBUG
    plls['homog'] = 0.0
    Z = 1.0

    assert "bfgs" in plls and "bfgs" in timestamps
    # t_bfgs = timestamps["bfgs"]
    t_bfgs = 1.0
    t_start = 1.0
    t_stop = 0.0

    if 'svi' in plls and 'svi' in timestamps:
        isreal = ~np.isnan(plls['svi'])
        svis = plls['svi'][isreal]
        t_svi = timestamps['svi'][isreal]
        t_svi = t_bfgs + t_svi - t_svi[0]
        t_stop = max(t_stop, t_svi[-1])
        ax.semilogx(t_svi, (svis - plls['homog'])/Z, color=col[0], label="SVI", lw=1.5)

    if 'vb' in plls and 'vb' in timestamps:
        t_vb = timestamps['vb']
        t_vb = t_bfgs + t_vb
        t_stop = max(t_stop, t_vb[-1])
        ax.semilogx(t_vb, (plls['vb'] - plls['homog'])/Z, color=col[1], label="VB", lw=1.5)

    if 'gibbs' in plls and 'gibbs' in timestamps:
        t_gibbs = timestamps['gibbs']
        t_gibbs = t_bfgs + t_gibbs
        t_stop = max(t_stop, t_gibbs[-1])
        ax.semilogx(t_gibbs, (plls['gibbs'] - plls['homog'])/Z, color=col[2], label="Gibbs", lw=1.5)

    # if 'gibbs_ss' in plls and 'gibbs_ss' in timestamps:
    #     t_gibbs = timestamps['gibbs_ss']
    #     t_gibbs = t_bfgs + t_gibbs
    #     t_stop = max(t_stop, t_gibbs[-1])
    #     ax.semilogx(t_gibbs, (plls['gibbs_ss'] - plls['homog'])/Z, color=col[8], label="Gibbs-SS", lw=1.5)

    # Extend lines to t_stop
    if 'svi' in plls and 'svi' in timestamps:
        final_svi_pll = -np.log(4) + logsumexp(svis[-4:])
        ax.semilogx([t_svi[-1], t_stop],
                    [(final_svi_pll - plls['homog'])/Z,
                     (final_svi_pll - plls['homog'])/Z],
                    '--',
                    color=col[0], lw=1.5)

    if 'vb' in plls and 'vb' in timestamps:
        ax.semilogx([t_vb[-1], t_stop],
                    [(plls['vb'][-1] - plls['homog'])/Z,
                     (plls['vb'][-1] - plls['homog'])/Z],
                    '--',
                    color=col[1], lw=1.5)

    ax.semilogx([t_start, t_stop],
                [(plls['bfgs'] - plls['homog'])/Z, (plls['bfgs'] - plls['homog'])/Z],
                color=col[3], lw=1.5, label="MAP" )

    # Put a legend above
    plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
               ncol=5, mode="expand", borderaxespad=0.,
               prop={'size':9})

    ax.set_xlim(t_start, t_stop)

    # Format the ticks
    # plt.locator_params(nbins=nbins)

    import matplotlib.ticker as ticker
    logxscale = 3
    xticks = ticker.FuncFormatter(lambda x, pos: '{0:.2f}'.format(x/10.**logxscale))
    ax.xaxis.set_major_formatter(xticks)
    ax.set_xlabel('Time ($10^{%d}$ s)' % logxscale)

    logyscale = 4
    yticks = ticker.FuncFormatter(lambda y, pos: '{0:.3f}'.format(y/10.**logyscale))
    ax.yaxis.set_major_formatter(yticks)
    ax.set_ylabel('Pred. LL ($ \\times 10^{%d}$)' % logyscale)

    # ylim = ax.get_ylim()
    # ax.plot([t_bfgs, t_bfgs], ylim, '--k')
    # ax.set_ylim(ylim)
    ylim = (-129980, -129840)
    ax.set_ylim(ylim)


    # plt.tight_layout()
    plt.subplots_adjust(bottom=0.2, left=0.2)
    # plt.title("Predictive Log Likelihood ($T=%d$)" % T_train)
    plt.show()
    fig.savefig('figure2b.pdf')
Example #27
def plot_pca_locations():
    ### Plot the sampled locations for a few neurons

    # Bin the data
    from pyhawkes.utils.utils import convert_continuous_to_discrete
    S_dt = convert_continuous_to_discrete(S, C, 0.25, 0, T)

    # Smooth the data to get a firing rate
    from scipy.ndimage import gaussian_filter1d
    S_smooth = np.array([gaussian_filter1d(s, 4) for s in S_dt.T]).T

    # Run PCA to get an embedding
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    pca.fit(S_smooth)
    Z = pca.components_.T

    # Rotate
    R = compute_optimal_rotation(Z, pfs)
    Z = Z.dot(R)

    fig = create_figure(figsize=(1.4,2.5))
    plt.subplot(211, aspect='equal')

    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(pfs[k,0], pfs[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("True Place Fields")
    plt.xlim(-45, 45)
    # plt.xlabel("$x$")
    plt.xticks([-40, -20, 0, 20, 40], [])
    plt.ylim(-45, 45)
    # plt.ylabel("$y$")
    plt.yticks([-40, -20, 0, 20, 40], [])

    # Now plot the inferred locations


    plt.subplot(212, aspect='equal')

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Z[k,0], Z[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("PCA Locations")
    plt.xlim(-25, 25)
    # plt.xlabel("$x$")
    plt.xticks([-20, 0, 20], [])
    plt.ylim(-25, 25)
    # plt.ylabel("$y$")
    plt.yticks([-20, 0, 20], [])

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "hipp_pca_locations.pdf"))
    plt.show()
Example #28
def plot_locations(result, offset=0):
    ### Plot the sampled locations for a few neurons
    _, _, _, _, Ls = result
    Ls_rot = []
    for L in Ls:
        R = compute_optimal_rotation(L, pfs, scale=False)
        Ls_rot.append(L.dot(R))
    Ls_rot = np.array(Ls_rot)

    fig = create_figure(figsize=(1.4,2.9))
    ax = create_axis_at_location(fig, .3, 1.7, 1, 1)

    # toplot = np.random.choice(np.arange(K), size=4, replace=False)
    toplot = np.linspace(offset, K + offset, 4, endpoint=False).astype(int)
    print(toplot)
    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])
    plot_colors = [wheel_cmap((np.pi+pfs_th[node_perm[j]])/(2*np.pi)) for j in toplot]

    for i,k in enumerate(node_perm):
        # plt.text(pfs[k,0], pfs[k,1], "%d" % i)
        if i not in toplot:
            color = 0.8 * np.ones(3)

            plt.plot(pfs[k,0], pfs[k, 1], 'o',
                     markerfacecolor=color, markeredgecolor=color,
                     markersize=4 + 4 * pf_size[k],
                     alpha=1.0)

    for i,k in enumerate(node_perm):
        # plt.text(pfs[k,0], pfs[k,1], "%d" % i)
        if i in toplot:
            j = np.where(toplot==i)[0][0]
            color = plot_colors[j]

            plt.plot(pfs[k,0], pfs[k, 1], 'o',
                     markerfacecolor=color, markeredgecolor=color,
                     markersize=4 + 4 * pf_size[k])



    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("True Place Fields")
    plt.xlim(-45,45)
    plt.xticks([-40, -20, 0, 20, 40])
    # plt.xlabel("$x$")
    plt.ylim(-45,45)
    plt.yticks([-40, -20, 0, 20, 40])
    # plt.ylabel("$y$")

    # Now plot the inferred locations
    # plt.subplot(212, aspect='equal')
    ax = create_axis_at_location(fig, .3, .2, 1, 1)
    for L in Ls_rot[::2]:
        for j in np.random.permutation(len(toplot)):
            k = node_perm[toplot][j]
            color = plot_colors[j]
            plt.plot(L[k,0], L[k,1], 'o',
                     markerfacecolor=color, markeredgecolor="none",
                     markersize=4, alpha=0.25)

    plt.title("Locations Samples")
    # plt.xlim(-30, 30)
    # plt.xticks([])
    # plt.ylim(-30, 30)
    # plt.yticks([])
    plt.xlim(-3, 3)
    plt.xticks([-2, 0, 2])
    plt.ylim(-3, 3)
    plt.yticks([-2, 0, 2])

    plt.savefig(os.path.join(results_dir, "locations_%d.pdf" % offset))
Example #29
def plot_results(result):
    lls, plls, Weffs, Ps, Ls = result

    ### Colored locations
    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])
    fig = create_figure(figsize=(1.8, 1.8))
    # ax = create_axis_at_location(fig, .1, .1, 1.5, 1.5, box=False)
    ax = create_axis_at_location(fig, .6, .4, 1.1, 1.1)

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        # alpha = pfs_rad[k] / 47
        alpha = 0.7
        ax.add_patch(Circle((pfs[k,0], pfs[k,1]),
                            radius=3+4*pf_size[k],
                            color=color, ec="none",
                            alpha=alpha)
                            )

    plt.title("True place fields")
    # ax.text(0, 45, "True Place Fields",
    #         horizontalalignment="center",
    #         fontdict=dict(size=9))
    plt.xlim(-45,45)
    plt.xticks([-40, -20, 0, 20, 40])
    plt.xlabel("$x$ [cm]")
    plt.ylim(-45,45)
    plt.yticks([-40, -20, 0, 20, 40])
    plt.ylabel("$y$ [cm]")
    plt.savefig(os.path.join(results_dir, "hipp_colored_locations.pdf"))


    # Plot the inferred weighted adjacency matrix
    fig = create_figure(figsize=(1.8, 1.8))
    ax = create_axis_at_location(fig, .4, .4, 1.1, 1.1)

    Weff = np.array(Weffs[N_samples//2:]).mean(0)
    Weff = Weff[np.ix_(node_perm, node_perm)]
    lim = Weff[(1 - np.eye(K)).astype(bool)].max()
    im = ax.imshow(np.kron(Weff, np.ones((20,20))),
                   interpolation="none", cmap="Greys", vmax=lim)
    ax.set_xticks([])
    ax.set_yticks([])

    # node_colors = wheel_cmap()
    node_values = ((np.pi+pfs_th[node_perm])/(2*np.pi))[:,None] *np.ones((K,2))
    yax = create_axis_at_location(fig, .2, .4, .3, 1.1)
    remove_plot_labels(yax)
    yax.imshow(node_values, interpolation="none",
               cmap=wheel_cmap)
    yax.set_xticks([])
    yax.set_yticks([])
    yax.set_ylabel("pre")

    xax = create_axis_at_location(fig, .4, .2, 1.1, .3)
    remove_plot_labels(xax)
    xax.imshow(node_values.T, interpolation="none",
               cmap=wheel_cmap)
    xax.set_xticks([])
    xax.set_yticks([])
    xax.set_xlabel("post")

    cbax = create_axis_at_location(fig, 1.55, .4, .04, 1.1)
    cb = plt.colorbar(im, cax=cbax, ticks=[0, .1, .2, .3])
    cbax.tick_params(labelsize=8, pad=1)
    cb.set_ticklabels(["0", ".1", ".2", ".3"])

    ax.set_title("Inferred Weights")
    plt.savefig(os.path.join(results_dir, "hipp_W.pdf"))

    # # Plot the inferred connection probability
    # plt.figure()
    # plt.imshow(P, interpolation="none", cmap="Greys", vmin=0)
    # plt.colorbar()

    # Plot the inferred connection probability matrix
    fig = create_figure(figsize=(1.8, 1.8))
    ax = create_axis_at_location(fig, .4, .4, 1.1, 1.1)

    P = np.array(Ps[N_samples//2:]).mean(0)
    P = P[np.ix_(node_perm, node_perm)]
    im = ax.imshow(np.kron(P, np.ones((20,20))),
                   interpolation="none", cmap="Greys", vmin=0, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])

    # node_colors = wheel_cmap()
    node_values = ((np.pi+pfs_th[node_perm])/(2*np.pi))[:,None] *np.ones((K,2))
    yax = create_axis_at_location(fig, .2, .4, .3, 1.1)
    remove_plot_labels(yax)
    yax.imshow(node_values, interpolation="none",
               cmap=wheel_cmap)
    yax.set_xticks([])
    yax.set_yticks([])
    yax.set_ylabel("pre")

    xax = create_axis_at_location(fig, .4, .2, 1.1, .3)
    remove_plot_labels(xax)
    xax.imshow(node_values.T, interpolation="none",
               cmap=wheel_cmap)
    xax.set_xticks([])
    xax.set_yticks([])
    xax.set_xlabel("post")

    cbax = create_axis_at_location(fig, 1.55, .4, .04, 1.1)
    cb = plt.colorbar(im, cax=cbax, ticks=[0, .5, 1])
    cbax.tick_params(labelsize=8, pad=1)
    cb.set_ticklabels(["0.0", "0.5", "1.0"])

    ax.set_title("Inferred Probability")
    plt.savefig(os.path.join(results_dir, "hipp_P.pdf"))


    plt.show()
Example #30
def plot_census_results(train, samples, test, test_pis):
    # Extract samples
    train_mus = np.array([s[0] for s in samples])
    train_psis = np.array([s[1][0][0] for s in samples])
    # omegas = np.array([s[1][0][1] for s in samples])

    # Adjust psis by the mean and compute the inferred pis
    train_psis += train_mus[0][None,None,:]
    train_pis = np.array([psi_to_pi(psi_sample) for psi_sample in train_psis])
    train_pi_mean = np.mean(train_pis, axis=0)
    train_pi_std = np.std(train_pis, axis=0)

    # Compute test pi mean and std
    test_pi_mean = np.mean(test_pis, axis=0)
    test_pi_std = np.std(test_pis, axis=0)

    # Compute empirical probabilities
    train_pi_emp = train.data / train.data.sum(axis=1)[:,None]
    test_pi_emp = test.data / test.data.sum(axis=1)[:,None]


    # Plot the temporal trajectories for a few names
    names = ["Scott", "Matthew", "Ethan"]
    states = ["NY", "TX", "WA"]
    linestyles = ["-", "--", ":"]

    fig = create_figure(figsize=(3., 3))
    ax1 = create_axis_at_location(fig, 0.6, 0.5, 2.25, 1.75)
    for name, color in zip(names, colors):
        for state, linestyle in zip(states, linestyles):
            train_state_inds = (train.states == state)
            train_name_ind = np.array(train.names) == name.lower()
            train_years = train.years[train.states == state]
            train_mean_name = train_pi_mean[train_state_inds, train_name_ind]
            train_std_name = train_pi_std[train_state_inds, train_name_ind]

            test_state_inds = (test.states == state)
            test_name_ind = np.array(test.names) == name.lower()
            test_years = test.years[test.states == state]
            test_mean_name = test_pi_mean[test_state_inds, test_name_ind]
            test_std_name = test_pi_std[test_state_inds, test_name_ind]

            years = np.concatenate((train_years, test_years))
            mean_name = np.concatenate((train_mean_name, test_mean_name))
            std_name = np.concatenate((train_std_name, test_std_name))

            # Sausage plot
            sausage_plot(years, mean_name, std_name,
                         color=color, alpha=0.5)

            # Plot inferred mean
            plt.plot(years, mean_name,
                     color=color, label="%s, %s" % (name, state),
                     ls=linestyle, lw=2)

            # Plot empirical probabilities
            plt.plot(train.years[train_state_inds],
                     train_pi_emp[train_state_inds, train_name_ind],
                     color=color,
                     ls="", marker="x", markersize=4)

            plt.plot(test.years[test_state_inds],
                     test_pi_emp[test_state_inds, test_name_ind],
                     color=color,
                     ls="", marker="x", markersize=4)

    # Plot a vertical line to divide train and test
    ylim = plt.gca().get_ylim()
    plt.plot((test.years.min()-0.5) * np.ones(2), ylim, ':k', lw=0.5)
    plt.ylim(ylim)

    # plt.legend(loc="outside right")
    plt.legend(bbox_to_anchor=(0., 1.05, 1., .105), loc=3,
               ncol=len(names), mode="expand", borderaxespad=0.,
               fontsize="x-small")

    plt.xlabel("Year")
    plt.xlim(train.years.min(), test.years.max()+0.1)
    plt.ylabel("Probability")

    # plt.tight_layout()
    fig.savefig("census_gp_rates.pdf")

    plt.show()
    plt.pause(0.1)
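The sausage_plot helper called above is not defined in this snippet; below is a minimal sketch, assuming it simply shades the band mean plus/minus one standard deviation with fill_between (the name and signature are taken from the call above, the body is an assumption):

import matplotlib.pyplot as plt

def sausage_plot(x, mean, std, color='b', alpha=0.5):
    # Assumed behavior: shade the uncertainty band mean +/- std along x.
    plt.fill_between(x, mean - std, mean + std,
                     color=color, alpha=alpha, linewidth=0)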
Example #31
def plot_pred_ll_vs_time(models,
                         results,
                         burnin=0,
                         homog_ll=np.nan,
                         std_ll=np.nan,
                         nlin_ll=np.nan,
                         true_ll=np.nan,
                         baseline=0,
                         normalizer=1,
                         output_dir=".",
                         xlim=None,
                         ylim=None,
                         title=None):
    from hips.plotting.layout import create_figure, create_axis_at_location
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((3, 2))
    ax = create_axis_at_location(fig, 0.7, 0.4, 2.25, 1.35)
    col = harvard_colors()
    ax.grid()

    t_start = 0
    t_stop = 0

    def standardize(x):
        return (x - baseline) / normalizer

    for i, (model, result) in enumerate(zip(models, results)):
        ax.plot(result.timestamps[burnin:],
                standardize(result.test_lls[burnin:]),
                lw=2,
                color=col[i],
                label=model)

        # Update time limits
        t_start = min(t_start, result.timestamps[burnin:].min())
        t_stop = max(t_stop, result.timestamps[burnin:].max())

    if xlim is not None:
        t_start = xlim[0]
        t_stop = xlim[1]

    # plt.legend(loc="outside right")

    # Plot baselines
    ax.plot([t_start, t_stop],
            standardize(homog_ll) * np.ones(2),
            lw=2,
            color='k',
            label="Homog")

    # Plot the standard Hawkes test ll
    ax.plot([t_start, t_stop],
            standardize(std_ll) * np.ones(2),
            lw=2,
            color=col[len(models)],
            label="Std.")

    # Plot the Nonlinear Hawkes test ll
    ax.plot([t_start, t_stop],
            standardize(nlin_ll) * np.ones(2),
            lw=2,
            ls='--',
            color=col[len(models) + 1],
            label="Nonlinear")

    # Plot the true ll
    ax.plot([t_start, t_stop],
            standardize(true_ll) * np.ones(2),
            '--k',
            lw=2,
            label="True")

    ax.set_xlabel("time [sec]")
    ax.set_ylabel("Pred. LL. [nats/event]")

    ax.set_xscale("log")
    if xlim is not None:
        ax.set_xlim(xlim)
    else:
        ax.set_xlim(t_start, t_stop)
    if ylim is not None:
        ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)

    output_file = os.path.join(output_dir, "pred_ll_vs_time.pdf")
    fig.savefig(output_file)
    plt.show()
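A hedged usage sketch for the function above; the model labels, result objects, and baseline log-likelihood values here are hypothetical placeholders, not taken from the original code:

models = ["Std. Hawkes", "Net. Hawkes"]  # hypothetical labels
# `results` is assumed to be a list of objects exposing .timestamps and .test_lls
plot_pred_ll_vs_time(models, results, burnin=10,
                     homog_ll=homog_pll, std_ll=std_pll,
                     nlin_ll=nlin_pll, true_ll=true_pll,
                     baseline=homog_pll, normalizer=n_test_events,
                     output_dir=results_dir, title="Synthetic data")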
Example #32
def plot_locations(result, offset=0):
    ### Plot the sampled locations for a few neurons
    _, _, _, _, Ls = result
    Ls_rot = []
    for L in Ls:
        R = compute_optimal_rotation(L, pfs, scale=False)
        Ls_rot.append(L.dot(R))
    Ls_rot = np.array(Ls_rot)

    fig = create_figure(figsize=(1.4,2.9))
    ax = create_axis_at_location(fig, .3, 1.7, 1, 1)

    # toplot = np.random.choice(np.arange(K), size=4, replace=False)
    toplot = np.linspace(offset, K + offset, 4, endpoint=False).astype(int)
    print(toplot)
    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])
    plot_colors = [wheel_cmap((np.pi+pfs_th[node_perm[j]])/(2*np.pi)) for j in toplot]

    for i,k in enumerate(node_perm):
        # plt.text(pfs[k,0], pfs[k,1], "%d" % i)
        if i not in toplot:
            color = 0.8 * np.ones(3)

            plt.plot(pfs[k,0], pfs[k, 1], 'o',
                     markerfacecolor=color, markeredgecolor=color,
                     markersize=4 + 4 * pf_size[k],
                     alpha=1.0)

    for i,k in enumerate(node_perm):
        # plt.text(pfs[k,0], pfs[k,1], "%d" % i)
        if i in toplot:
            j = np.where(toplot==i)[0][0]
            color = plot_colors[j]

            plt.plot(pfs[k,0], pfs[k, 1], 'o',
                     markerfacecolor=color, markeredgecolor=color,
                     markersize=4 + 4 * pf_size[k])



    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("True Place Fields")
    plt.xlim(-45,45)
    plt.xticks([-40, -20, 0, 20, 40])
    # plt.xlabel("$x$")
    plt.ylim(-45,45)
    plt.yticks([-40, -20, 0, 20, 40])
    # plt.ylabel("$y$")

    # Now plot the inferred locations
    # plt.subplot(212, aspect='equal')
    ax = create_axis_at_location(fig, .3, .2, 1, 1)
    for L in Ls_rot[::2]:
        for j in np.random.permutation(len(toplot)):
            k = node_perm[toplot][j]
            color = plot_colors[j]
            plt.plot(L[k,0], L[k,1], 'o',
                     markerfacecolor=color, markeredgecolor="none",
                     markersize=4, alpha=0.25)

    plt.title("Locations Samples")
    # plt.xlim(-30, 30)
    # plt.xticks([])
    # plt.ylim(-30, 30)
    # plt.yticks([])
    plt.xlim(-3, 3)
    plt.xticks([-2, 0, 2])
    plt.ylim(-3, 3)
    plt.yticks([-2, 0, 2])

    plt.savefig(os.path.join(results_dir, "locations_%d.pdf" % offset))
Example #33
def plot_pred_ll_vs_D(all_results,
                      Ds,
                      Xtrain,
                      Xtest,
                      results_dir,
                      models=None):
    # Create a big matrix of shape (len(Ds) x M x T) for the pred lls
    N = len(Ds)  # Number of dimensions tests
    M = len(all_results[0])  # Number of models tested
    T = len(all_results[0][0].pred_lls)  # Number of MCMC iters
    pred_lls = np.zeros((N, M, T))
    for n in range(N):
        for m in range(M):
            if all_results[n][m].pred_lls.ndim == 2:
                pred_lls[n, m] = all_results[n][m].pred_lls[:, 0]
            else:
                pred_lls[n, m] = all_results[n][m].pred_lls

    # Compute the mean and standard deviation on burned in samples
    burnin = T // 2
    pred_ll_mean = logsumexp(pred_lls[:, :, burnin:],
                             axis=-1) - np.log(T - burnin)

    # Use bootstrap to compute error bars
    pred_ll_std = np.zeros_like(pred_ll_mean)
    for n in range(N):
        for m in range(M):
            samples = np.random.choice(pred_lls[n, m, burnin:],
                                       size=(100, (T - burnin)),
                                       replace=True)
            pll_samples = logsumexp(samples, axis=1) - np.log(T - burnin)
            pred_ll_std[n, m] = pll_samples.std()

    # Get the baseline pred ll
    baseline = 0
    normalizer = 0
    for Xtr, Xte in zip(Xtrain, Xtest):
        pi_emp = Xtr.sum(0) / float(Xtr.sum())
        pi_emp = np.clip(pi_emp, 1e-8, np.inf)
        pi_emp /= pi_emp.sum()
        baseline += Multinomial(weights=pi_emp,
                                K=Xtr.shape[1]).log_likelihood(Xte).sum()
        normalizer += Xte.sum()

    # Make a bar chart with errorbars
    from hips.plotting.layout import create_figure
    fig = create_figure(figsize=(1.25, 2.5), transparent=True)
    fig.set_tight_layout(True)
    ax = fig.add_subplot(111)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()

    width = np.min(np.diff(Ds)) / (M + 1.0) if len(Ds) > 1 else 1.
    for m in range(M):
        ax.bar(Ds + m * width, (pred_ll_mean[:, m] - baseline) / normalizer,
               yerr=pred_ll_std[:, m] / normalizer,
               width=0.9 * width,
               color=colors[m],
               ecolor='k')
        #
        # ax.text(Ds+(m-1)*width, yloc, rankStr, horizontalalignment=align,
        #     verticalalignment='center', color=clr, weight='bold')

    # Plot the zero line
    ax.plot([Ds.min() - width, Ds.max() + (M + 1) * width], np.zeros(2), '-k')

    # Set the tick labels
    ax.set_xlim(Ds.min() - width, Ds.max() + (M + 1) * width)
    # ax.set_xticks(Ds + (M*width)/2.)
    # ax.set_xticklabels(Ds)
    # ax.set_xticks(Ds + width * np.arange(M) + width/2. )
    # ax.set_xticklabels(models, rotation=45)
    ax.set_xticks([])

    # ax.set_xlabel("D")
    ax.set_ylabel("Pred. Log Lkhd. (nats/protein)")
    ax.set_title("DNA")

    plt.savefig(os.path.join(results_dir, "pred_ll_vs_D.pdf"))
Example #34
def plot_pca_locations():
    ### Plot the sampled locations for a few neurons

    # Bin the data
    from pyhawkes.utils.utils import convert_continuous_to_discrete
    S_dt = convert_continuous_to_discrete(S, C, 0.25, 0, T)

    # Smooth the data to get a firing rate
    from scipy.ndimage import gaussian_filter1d
    S_smooth = np.array([gaussian_filter1d(s, 4) for s in S_dt.T]).T

    # Run PCA to get an embedding
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    pca.fit(S_smooth)
    Z = pca.components_.T

    # Rotate
    R = compute_optimal_rotation(Z, pfs)
    Z = Z.dot(R)

    fig = create_figure(figsize=(1.4,2.5))
    plt.subplot(211, aspect='equal')

    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(pfs[k,0], pfs[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("True Place Fields")
    plt.xlim(-45, 45)
    # plt.xlabel("$x$")
    plt.xticks([-40, -20, 0, 20, 40], [])
    plt.ylim(-45, 45)
    # plt.ylabel("$y$")
    plt.yticks([-40, -20, 0, 20, 40], [])

    # Now plot the inferred locations


    plt.subplot(212, aspect='equal')

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Z[k,0], Z[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("PCA Locations")
    plt.xlim(-25, 25)
    # plt.xlabel("$x$")
    plt.xticks([-20, 0, 20], [])
    plt.ylim(-25, 25)
    # plt.ylabel("$y$")
    plt.yticks([-20, 0, 20], [])

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "hipp_pca_locations.pdf"))
    plt.show()
def plot_pred_ll_vs_time(plls, timestamps, Z=1.0, T_train=None, nbins=4):

    # import seaborn as sns
    # sns.set(style="whitegrid")

    from hips.plotting.layout import create_figure
    from hips.plotting.colormaps import harvard_colors

    # Make the ICML figure
    fig = create_figure((4, 3))
    ax = fig.add_subplot(111)
    col = harvard_colors()
    plt.grid()

    # Compute the max and min time in seconds
    print("Homog PLL: ", plls['homog'])
    # DEBUG
    plls['homog'] = 0.0
    Z = 1.0

    assert "bfgs" in plls and "bfgs" in timestamps
    # t_bfgs = timestamps["bfgs"]
    t_bfgs = 1.0
    t_start = 1.0
    t_stop = 0.0

    if 'svi' in plls and 'svi' in timestamps:
        isreal = ~np.isnan(plls['svi'])
        svis = plls['svi'][isreal]
        t_svi = timestamps['svi'][isreal]
        t_svi = t_bfgs + t_svi - t_svi[0]
        t_stop = max(t_stop, t_svi[-1])
        ax.semilogx(t_svi, (svis - plls['homog']) / Z,
                    color=col[0],
                    label="SVI",
                    lw=1.5)

    if 'vb' in plls and 'vb' in timestamps:
        t_vb = timestamps['vb']
        t_vb = t_bfgs + t_vb
        t_stop = max(t_stop, t_vb[-1])
        ax.semilogx(t_vb, (plls['vb'] - plls['homog']) / Z,
                    color=col[1],
                    label="VB",
                    lw=1.5)

    if 'gibbs' in plls and 'gibbs' in timestamps:
        t_gibbs = timestamps['gibbs']
        t_gibbs = t_bfgs + t_gibbs
        t_stop = max(t_stop, t_gibbs[-1])
        ax.semilogx(t_gibbs, (plls['gibbs'] - plls['homog']) / Z,
                    color=col[2],
                    label="Gibbs",
                    lw=1.5)

    # if 'gibbs_ss' in plls and 'gibbs_ss' in timestamps:
    #     t_gibbs = timestamps['gibbs_ss']
    #     t_gibbs = t_bfgs + t_gibbs
    #     t_stop = max(t_stop, t_gibbs[-1])
    #     ax.semilogx(t_gibbs, (plls['gibbs_ss'] - plls['homog'])/Z, color=col[8], label="Gibbs-SS", lw=1.5)

    # Extend lines to t_stop
    if 'svi' in plls and 'svi' in timestamps:
        final_svi_pll = -np.log(4) + logsumexp(svis[-4:])
        ax.semilogx([t_svi[-1], t_stop], [(final_svi_pll - plls['homog']) / Z,
                                          (final_svi_pll - plls['homog']) / Z],
                    '--',
                    color=col[0],
                    lw=1.5)

    if 'vb' in plls and 'vb' in timestamps:
        ax.semilogx([t_vb[-1], t_stop], [(plls['vb'][-1] - plls['homog']) / Z,
                                         (plls['vb'][-1] - plls['homog']) / Z],
                    '--',
                    color=col[1],
                    lw=1.5)

    ax.semilogx([t_start, t_stop], [(plls['bfgs'] - plls['homog']) / Z,
                                    (plls['bfgs'] - plls['homog']) / Z],
                color=col[3],
                lw=1.5,
                label="MAP")

    # Put a legend above
    plt.legend(bbox_to_anchor=(0., 1.02, 1., .102),
               loc=3,
               ncol=5,
               mode="expand",
               borderaxespad=0.,
               prop={'size': 9})

    ax.set_xlim(t_start, t_stop)

    # Format the ticks
    # plt.locator_params(nbins=nbins)

    import matplotlib.ticker as ticker
    logxscale = 3
    xticks = ticker.FuncFormatter(
        lambda x, pos: '{0:.2f}'.format(x / 10.**logxscale))
    ax.xaxis.set_major_formatter(xticks)
    ax.set_xlabel('Time ($10^{%d}$ s)' % logxscale)

    logyscale = 4
    yticks = ticker.FuncFormatter(
        lambda y, pos: '{0:.3f}'.format(y / 10.**logyscale))
    ax.yaxis.set_major_formatter(yticks)
    ax.set_ylabel('Pred. LL ($ \\times 10^{%d}$)' % logyscale)

    # ylim = ax.get_ylim()
    # ax.plot([t_bfgs, t_bfgs], ylim, '--k')
    # ax.set_ylim(ylim)
    ylim = (-129980, -129840)
    ax.set_ylim(ylim)

    # plt.tight_layout()
    plt.subplots_adjust(bottom=0.2, left=0.2)
    # plt.title("Predictive Log Likelihood ($T=%d$)" % T_train)
    plt.show()
    fig.savefig('figure2b.pdf')
Example #36
def plot_mean_and_pca_locations(result):
    ### Plot the sampled locations for a few neurons
    _, _, _, _, Ls = result
    Ls_rot = []
    for L in Ls:
        R = compute_optimal_rotation(L, pfs, scale=False)
        Ls_rot.append(L.dot(R))
    Ls_rot = np.array(Ls_rot)
    Ls_mean = np.mean(Ls_rot, 0)

    # Bin the data
    from pyhawkes.utils.utils import convert_continuous_to_discrete
    S_dt = convert_continuous_to_discrete(S, C, 0.25, 0, T)

    # Smooth the data to get a firing rate
    from scipy.ndimage import gaussian_filter1d
    S_smooth = np.array([gaussian_filter1d(s, 4) for s in S_dt.T]).T

    # Run PCA to get an embedding
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    pca.fit(S_smooth)
    Z = pca.components_.T

    # Rotate
    R = compute_optimal_rotation(Z, pfs, scale=False)
    Z = Z.dot(R)

    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])
    fig = create_figure(figsize=(1.4,2.9))
    # plt.subplot(211, aspect='equal')
    ax = create_axis_at_location(fig, .3, 1.7, 1, 1)


    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Ls_mean[k,0], Ls_mean[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("Mean Locations")
    plt.xlim(-3, 3)
    plt.xticks([-2, 0, 2])
    plt.ylim(-3, 3)
    plt.yticks([-2, 0, 2])


    # plt.subplot(212, aspect='equal')
    ax = create_axis_at_location(fig, .3, .2, 1, 1)

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        plt.plot(Z[k,0], Z[k, 1], 'o',
                 markerfacecolor=color, markeredgecolor=color,
                 markersize=4 + 4 * pf_size[k],
                 alpha=0.7)

    #     plt.gca().add_patch(Circle((0,0), radius=rad, ec='k', fc="none"))
    plt.title("PCA Locations")
    plt.xlim(-.5, .5)
    # plt.xlabel("$x$")
    plt.xticks([-.4, 0, .4])
    plt.ylim(-.5, .5)
    # plt.ylabel("$y$")
    plt.yticks([-.4, 0, .4])

    # plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "hipp_mean_pca_locations.pdf"))
    plt.show()
                    fit_continuous_time_model_gibbs(S_ct, C_ct, N_samples))

        with open(res_file, "wb") as f:  # pickle needs a binary file handle
            cPickle.dump((events_per_bin, dt_times, ct_times), f, protocol=-1)

    events_per_bin = np.array(events_per_bin)
    dt_times = np.array(dt_times)
    ct_times = np.array(ct_times)
    perm = np.argsort(events_per_bin)

    events_per_bin = events_per_bin[perm]
    dt_times = dt_times[perm]
    ct_times = ct_times[perm]

    # Plot the results
    fig = create_figure(figsize=(2.5, 2.5))
    fig.set_tight_layout(True)
    ax = fig.add_subplot(111)

    # Plot DT data
    ax.plot(events_per_bin,
            dt_times,
            'o',
            linestyle="none",
            markerfacecolor=colors[2],
            markeredgecolor=colors[2],
            markersize=4,
            label="Discrete")

    # Plot linear fit
    p_dt = np.poly1d(np.polyfit(events_per_bin, dt_times, deg=1))
Example #38
def plot_correlation_matrix(Sigma,
                            betas,
                            words,
                            results_dir,
                            outname="corr_matrix.pdf",
                            blockify=False,
                            highlight=[]):

    # Get topic names
    topic_names = [np.array(words)[np.argmax(beta)]  for beta in betas.T]

    # Plot the log likelihood
    sz = 5.25/3.  # Three NIPS panels
    fig = create_figure(figsize=(sz, 2.5), transparent=True)
    fig.set_tight_layout(True)
    ax = fig.add_subplot(111)

    C = corr_matrix(Sigma)
    T = C.shape[0]
    lim = abs(C).max()
    cmap = gradient_cmap([colors[1], np.ones(3), colors[0]])

    if blockify:
        perm = find_blockifying_perm(C, k=4, nclusters=4)
        C = C[np.ix_(perm, perm)]

    im = plt.imshow(np.kron(C, np.ones((50, 50))),
                    interpolation="none", vmin=-lim, vmax=lim,
                    cmap=cmap, extent=(1, T+1, T+1, 1))

    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("bottom", size="5%", pad=0.05)
    cbar = plt.colorbar(im, cax=cax,
                        orientation="horizontal",
                        ticks=[-1, -0.5, 0., 0.5, 1.0],
                        label="Topic Correlation")
    # cbar.set_label("Probability", labelpad=10)
    plt.subplots_adjust(left=0.05, bottom=0.1, top=0.9, right=0.85)

    # Highlight some cells
    import string
    from matplotlib.patches import Rectangle
    for i, (j, k) in enumerate(highlight):
        ax.add_patch(Rectangle((k+1, j+1), 1, 1,
                               facecolor="none", edgecolor='k', linewidth=1))

        ax.text(k - 0.5, j + 2, string.ascii_lowercase[i])

        print("")
        print("CC: ", C[j,k])
        print("Topic ", j)
        print(top_k(words, betas[:,j]))
        print("Topic ", k)
        print(top_k(words, betas[:,k]))
        print("")

    # Find the most correlated off diagonal entry
    C_offdiag = np.tril(C,k=-1)
    sorted_pairs = np.argsort(C_offdiag.ravel())
    for i in range(5):
        print("")
        imax, jmax = np.unravel_index(sorted_pairs[-(i+1)], (T, T))
        print("Correlated Topics (%d, %d): " % (imax, jmax))
        print(top_k(words, betas[:,imax]), "\n and \n", top_k(words, betas[:,jmax]))
        print("correlation coeff: ", C[imax, jmax])
        print("-" * 50)
        print("")

    print("-" * 50)
    print("-" * 50)
    print("-" * 50)

    for i in range(5):
        print("")
        imin,jmin = np.unravel_index(sorted_pairs[i], (T,T))
        print("Anticorrelated Topics (%d, %d): " % (imin, jmin))
        # print topic_names[imin], " and ", topic_names[jmin]
        print(top_k(words, betas[:,imin]), "\n and \n", top_k(words, betas[:,jmin]))
        print("correlation coeff: ", C[imin, jmin])
        print("-" * 50)
        print("")


    # Move main axis ticks to top
    ax.xaxis.tick_top()
    # ax.set_title("Topic Correlation", y=1.1)
    fig.savefig(os.path.join(results_dir, outname))

    plt.show()
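corr_matrix is not defined in this snippet; below is a minimal sketch, assuming it just normalizes the covariance Sigma into a correlation matrix (an assumption about the helper, not the original code):

import numpy as np

def corr_matrix(Sigma):
    # Assumed behavior: C_ij = Sigma_ij / sqrt(Sigma_ii * Sigma_jj)
    std = np.sqrt(np.diag(Sigma))
    return Sigma / np.outer(std, std)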
Example #39
def plot_results(result):
    lls, plls, Weffs, Ps, Ls = result

    ### Colored locations
    wheel_cmap = gradient_cmap([colors[0], colors[3], colors[2], colors[1], colors[0]])
    fig = create_figure(figsize=(1.8, 1.8))
    # ax = create_axis_at_location(fig, .1, .1, 1.5, 1.5, box=False)
    ax = create_axis_at_location(fig, .6, .4, 1.1, 1.1)

    for i,k in enumerate(node_perm):
        color = wheel_cmap((np.pi+pfs_th[k])/(2*np.pi))
        # alpha = pfs_rad[k] / 47
        alpha = 0.7
        ax.add_patch(Circle((pfs[k,0], pfs[k,1]),
                            radius=3+4*pf_size[k],
                            color=color, ec="none",
                            alpha=alpha)
                            )

    plt.title("True place fields")
    # ax.text(0, 45, "True Place Fields",
    #         horizontalalignment="center",
    #         fontdict=dict(size=9))
    plt.xlim(-45,45)
    plt.xticks([-40, -20, 0, 20, 40])
    plt.xlabel("$x$ [cm]")
    plt.ylim(-45,45)
    plt.yticks([-40, -20, 0, 20, 40])
    plt.ylabel("$y$ [cm]")
    plt.savefig(os.path.join(results_dir, "hipp_colored_locations.pdf"))


    # Plot the inferred weighted adjacency matrix
    fig = create_figure(figsize=(1.8, 1.8))
    ax = create_axis_at_location(fig, .4, .4, 1.1, 1.1)

    Weff = np.array(Weffs[N_samples//2:]).mean(0)
    Weff = Weff[np.ix_(node_perm, node_perm)]
    lim = Weff[(1 - np.eye(K)).astype(bool)].max()
    im = ax.imshow(np.kron(Weff, np.ones((20,20))),
                   interpolation="none", cmap="Greys", vmax=lim)
    ax.set_xticks([])
    ax.set_yticks([])

    # node_colors = wheel_cmap()
    node_values = ((np.pi+pfs_th[node_perm])/(2*np.pi))[:,None] *np.ones((K,2))
    yax = create_axis_at_location(fig, .2, .4, .3, 1.1)
    remove_plot_labels(yax)
    yax.imshow(node_values, interpolation="none",
               cmap=wheel_cmap)
    yax.set_xticks([])
    yax.set_yticks([])
    yax.set_ylabel("pre")

    xax = create_axis_at_location(fig, .4, .2, 1.1, .3)
    remove_plot_labels(xax)
    xax.imshow(node_values.T, interpolation="none",
               cmap=wheel_cmap)
    xax.set_xticks([])
    xax.set_yticks([])
    xax.set_xlabel("post")

    cbax = create_axis_at_location(fig, 1.55, .4, .04, 1.1)
    plt.colorbar(im, cax=cbax, ticks=[0, .1, .2,  .3])
    cbax.tick_params(labelsize=8, pad=1)
    cbax.set_yticklabels(["0", ".1", ".2", ".3"])

    ax.set_title("Inferred Weights")
    plt.savefig(os.path.join(results_dir, "hipp_W.pdf"))

    # # Plot the inferred connection probability
    # plt.figure()
    # plt.imshow(P, interpolation="none", cmap="Greys", vmin=0)
    # plt.colorbar()

    # Plot the inferred connection probability matrix
    fig = create_figure(figsize=(1.8, 1.8))
    ax = create_axis_at_location(fig, .4, .4, 1.1, 1.1)

    P = np.array(Ps[N_samples//2:]).mean(0)
    P = P[np.ix_(node_perm, node_perm)]
    im = ax.imshow(np.kron(P, np.ones((20,20))),
                   interpolation="none", cmap="Greys", vmin=0, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])

    # node_colors = wheel_cmap()
    node_values = ((np.pi+pfs_th[node_perm])/(2*np.pi))[:,None] *np.ones((K,2))
    yax = create_axis_at_location(fig, .2, .4, .3, 1.1)
    remove_plot_labels(yax)
    yax.imshow(node_values, interpolation="none",
               cmap=wheel_cmap)
    yax.set_xticks([])
    yax.set_yticks([])
    yax.set_ylabel("pre")

    xax = create_axis_at_location(fig, .4, .2, 1.1, .3)
    remove_plot_labels(xax)
    xax.imshow(node_values.T, interpolation="none",
               cmap=wheel_cmap)
    xax.set_xticks([])
    xax.set_yticks([])
    xax.set_xlabel("post")

    cbax = create_axis_at_location(fig, 1.55, .4, .04, 1.1)
    plt.colorbar(im, cax=cbax, ticks=[0, .5, 1])
    cbax.tick_params(labelsize=8, pad=1)
    cbax.set_yticklabels(["0.0", "0.5", "1.0"])

    ax.set_title("Inferred Probability")
    plt.savefig(os.path.join(results_dir, "hipp_P.pdf"))


    plt.show()
Example #40
def plot_census_results(train, samples, test, test_pis):
    # Extract samples
    train_mus = np.array([s[0] for s in samples])
    train_psis = np.array([s[1][0][0] for s in samples])
    # omegas = np.array([s[1][0][1] for s in samples])

    # Adjust psis by the mean and compute the inferred pis
    train_psis += train_mus[0][None,None,:]
    train_pis = np.array([psi_to_pi(psi_sample) for psi_sample in train_psis])
    train_pi_mean = np.mean(train_pis, axis=0)
    train_pi_std = np.std(train_pis, axis=0)

    # Compute test pi mean and std
    test_pi_mean = np.mean(test_pis, axis=0)
    test_pi_std = np.std(test_pis, axis=0)

    # Compute empirical probabilities
    train_pi_emp = train.data / train.data.sum(axis=1)[:,None]
    test_pi_emp = test.data / test.data.sum(axis=1)[:,None]


    # Plot the temporal trajectories for a few names
    names = ["Scott", "Matthew", "Ethan"]
    states = ["NY", "TX", "WA"]
    linestyles = ["-", "--", ":"]

    fig = create_figure(figsize=(3., 3))
    ax1 = create_axis_at_location(fig, 0.6, 0.5, 2.25, 1.75)
    for name, color in zip(names, colors):
        for state, linestyle in zip(states, linestyles):
            train_state_inds = (train.states == state)
            train_name_ind = np.array(train.names) == name.lower()
            train_years = train.years[train.states == state]
            train_mean_name = train_pi_mean[train_state_inds, train_name_ind]
            train_std_name = train_pi_std[train_state_inds, train_name_ind]

            test_state_inds = (test.states == state)
            test_name_ind = np.array(test.names) == name.lower()
            test_years = test.years[test.states == state]
            test_mean_name = test_pi_mean[test_state_inds, test_name_ind]
            test_std_name = test_pi_std[test_state_inds, test_name_ind]

            years = np.concatenate((train_years, test_years))
            mean_name = np.concatenate((train_mean_name, test_mean_name))
            std_name = np.concatenate((train_std_name, test_std_name))

            # Sausage plot
            sausage_plot(years, mean_name, std_name,
                         color=color, alpha=0.5)

            # Plot inferred mean
            plt.plot(years, mean_name,
                     color=color, label="%s, %s" % (name, state),
                     ls=linestyle, lw=2)

            # Plot empirical probabilities
            plt.plot(train.years[train_state_inds],
                     train_pi_emp[train_state_inds, train_name_ind],
                     color=color,
                     ls="", marker="x", markersize=4)

            plt.plot(test.years[test_state_inds],
                     test_pi_emp[test_state_inds, test_name_ind],
                     color=color,
                     ls="", marker="x", markersize=4)

    # Plot a vertical line to divide train and test
    ylim = plt.gca().get_ylim()
    plt.plot((test.years.min()-0.5) * np.ones(2), ylim, ':k', lw=0.5)
    plt.ylim(ylim)

    # plt.legend(loc="outside right")
    plt.legend(bbox_to_anchor=(0., 1.05, 1., .105), loc=3,
               ncol=len(names), mode="expand", borderaxespad=0.,
               fontsize="x-small")

    plt.xlabel("Year")
    plt.xlim(train.years.min(), test.years.max()+0.1)
    plt.ylabel("Probability")

    # plt.tight_layout()
    fig.savefig("census_gp_rates.pdf")

    plt.show()
    plt.pause(0.1)
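psi_to_pi is used in the census examples above but not defined here; below is a minimal sketch, assuming a stick-breaking logistic transform that maps a real-valued psi of length K-1 to a length-K probability vector. The body is an assumption about the helper, not code from the original:

import numpy as np

def psi_to_pi(psi):
    # Assumed stick-breaking map: each psi_k claims a logistic fraction
    # of the remaining probability mass; the last entry takes what is left.
    stick = np.ones(psi.shape[:-1])
    pis = []
    for k in range(psi.shape[-1]):
        p = stick / (1. + np.exp(-psi[..., k]))
        pis.append(p)
        stick = stick - p
    pis.append(stick)
    return np.stack(pis, axis=-1)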