def main():
    """Plot reward-zone licking for the saline/muscimol crossover groups.

    Builds the experiment set from the metadata XML, loads the two crossover
    groups ('saline to muscimol' and 'muscimol to saline'), and tags each
    experiment with a ``drug_condition`` attribute of 'learning' or
    'reversal' depending on which drug it received within its group. It then
    draws a swarm + bar plot of the fraction of licks in the reward zone,
    split by that condition (presumably read back through the
    'X_drug_condition' plotby key), and saves the figure to
    ``save_dir``/``filename``.
    """
    metadata_xml = os.path.join(df.metadata_path, 'expt_metadata.xml')
    expts = lab.ExperimentSet(
        metadata_xml,
        behaviorDataPath=os.path.join(df.data_path, 'behavior'),
        dataPath=os.path.join(df.data_path, 'imaging'))

    grp_cls = lab.classes.HiddenRewardExperimentGroup
    saline_first = grp_cls.from_json(
        sal_json, expts, label='saline to muscimol')
    muscimol_first = grp_cls.from_json(
        mus_json, expts, label='muscimol to saline')

    fig = plt.figure(figsize=(8.5, 11))
    grid = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.1, right=0.4)
    ax = fig.add_subplot(grid[0, 0])

    # (group, (phase for saline expts, phase for muscimol expts)).
    # The first drug a group received is tagged 'learning' and the second
    # 'reversal'. The muscimol-first group is labeled before the saline-first
    # group, matching the original ordering.
    phase_table = (
        (muscimol_first, ('reversal', 'learning')),
        (saline_first, ('learning', 'reversal')),
    )
    for grp, (saline_phase, muscimol_phase) in phase_table:
        for expt in grp:
            drug = expt.get('drug')
            # 'saline' is tested first, so it wins if both substrings appear.
            if 'saline' in drug:
                expt.attrib['drug_condition'] = saline_phase
            elif 'muscimol' in drug:
                expt.attrib['drug_condition'] = muscimol_phase

    plotting.plot_metric(
        ax,
        [saline_first, muscimol_first],
        metric_fn=ra.fraction_licks_in_reward_zone,
        label_groupby=False,
        plotby=['X_drug_condition'],
        plot_method='swarm',
        rotate_labels=False,
        activity_label='Fraction of licks in reward zone',
        colors=sns.color_palette('deep'),
        plot_bar=True)
    ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4])
    ax.set_ylim(top=0.4)
    ax.set_xticklabels(['Days 1-3', 'Day 4'])

    sns.despine(fig)
    ax.set_title('')
    ax.set_xlabel('')

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Build the WT vs Df place-cell heatmap / reward-proximity figure.

    Layout:
      * top band: place-cell position heatmaps for Condition C ('III'),
        day index '0' (Day 1) and day index '2' (Day 3), one row per
        genotype, each row with its own colorbar;
      * bottom-left: fraction of place cells near the reward per condition
        and day (line plot, dashed reference line at 1/8);
      * bottom-right: that fraction against the fraction of licks in the
        reward zone, per experiment in condition C, with a regression fit.

    When ``MALES_ONLY`` is set, both groups are first filtered to
    experiments whose parent record has sex 'M'. The figure is saved to
    ``save_dir``/``filename``.
    """
    grps_by_name = df.loadExptGrps('GOL')

    WT_expt_grp = grps_by_name['WT_place_set']
    Df_expt_grp = grps_by_name['Df_place_set']
    expt_grps = [WT_expt_grp, Df_expt_grp]
    if MALES_ONLY:
        for grp in expt_grps:
            grp.filter(lambda expt: expt.parent.get('sex') == 'M')

    WT_label = WT_expt_grp.label()
    Df_label = Df_expt_grp.label()

    fig = plt.figure(figsize=(8.5, 11))

    # Every grid in the heatmap band shares the same vertical extent.
    band_kw = {'top': 0.90, 'bottom': 0.67, 'hspace': 0.2}
    day_1_grid = plt.GridSpec(2, 5, left=0.1, right=0.3, **band_kw)
    day_3_grid = plt.GridSpec(2, 5, left=0.3, right=0.5, **band_kw)
    WT_1_heatmap_ax = fig.add_subplot(day_1_grid[0, :-1])
    WT_3_heatmap_ax = fig.add_subplot(day_3_grid[0, :-1])
    Df_1_heatmap_ax = fig.add_subplot(day_1_grid[1, :-1])
    Df_3_heatmap_ax = fig.add_subplot(day_3_grid[1, :-1])

    # Finer grid over the day-3 columns; its last column holds the colorbars.
    cbar_grid = plt.GridSpec(2, 10, left=0.3, right=0.5, **band_kw)
    WT_colorbar_ax = fig.add_subplot(cbar_grid[0, -1])
    Df_colorbar_ax = fig.add_subplot(cbar_grid[1, -1])

    lower_grid = plt.GridSpec(1, 10, left=0.1, right=0.5, top=0.6,
                              bottom=0.45)
    pf_close_fraction_ax = fig.add_subplot(lower_grid[0, :4])
    pf_close_behav_corr_ax = fig.add_subplot(lower_grid[0, 5:])

    corr_xlim = (-0.051, 0.551)
    corr_ylim = (-0.051, 0.551)

    def heatmap_panel(expt_grp, frame, day, session, roi_filter, ax, cmap,
                      cax=None):
        """Draw one Condition-C heatmap panel; return the sub-group plotted.

        Keeps the experiments of ``expt_grp`` whose row in ``frame`` has
        X_condition 'C' and the given X_day / X_session strings. A colorbar
        is drawn into ``cax`` only when one is given.
        """
        rows = frame[(frame['X_condition'] == 'C')
                     & (frame['X_day'] == day)
                     & (frame['X_session'] == session)]
        sub_grp = expt_grp.subGroup(list(rows['expt']))
        heatmap_kw = {
            'roi_filter': roi_filter,
            'ax': ax,
            'norm': 'individual',
            'cbar_visible': cax is not None,
            'cmap': cmap,
            'plotting_order': 'place_cells_only',
            'show_belt': False,
            'reward_in_middle': True,
        }
        if cax is not None:
            heatmap_kw['cax'] = cax
        place.plotPositionHeatmap(sub_grp, **heatmap_kw)
        fix_heatmap_ax(ax, sub_grp)
        return sub_grp

    #
    # Heatmaps
    #

    WT_cmap = sns.light_palette(WT_color, as_cmap=True)
    WT_frame = lab.ExperimentGroup.dataframe(
        WT_expt_grp, include_columns=['X_condition', 'X_day', 'X_session'])

    heatmap_panel(WT_expt_grp, WT_frame, '0', '0', WT_filter,
                  WT_1_heatmap_ax, WT_cmap)
    WT_1_heatmap_ax.set_title(r'Condition $\mathrm{III}$: Day 1')
    WT_1_heatmap_ax.set_ylabel(WT_label)
    WT_1_heatmap_ax.set_xlabel('')

    heatmap_panel(WT_expt_grp, WT_frame, '2', '0', WT_filter,
                  WT_3_heatmap_ax, WT_cmap, cax=WT_colorbar_ax)
    WT_3_heatmap_ax.set_title(r'Condition $\mathrm{III}$: Day 3')
    WT_3_heatmap_ax.set_ylabel('')
    WT_3_heatmap_ax.set_xlabel('')
    WT_colorbar_ax.set_yticklabels(['Min', 'Max'])

    Df_cmap = sns.light_palette(Df_color, as_cmap=True)
    Df_frame = lab.ExperimentGroup.dataframe(
        Df_expt_grp, include_columns=['X_condition', 'X_day', 'X_session'])

    # NOTE(review): the Df Day-1 example uses X_session '2' while every other
    # panel uses '0'; presumably a deliberate example choice, so confirm.
    heatmap_panel(Df_expt_grp, Df_frame, '0', '2', Df_filter,
                  Df_1_heatmap_ax, Df_cmap)
    Df_1_heatmap_ax.set_ylabel(Df_label)

    heatmap_panel(Df_expt_grp, Df_frame, '2', '0', Df_filter,
                  Df_3_heatmap_ax, Df_cmap, cax=Df_colorbar_ax)
    Df_3_heatmap_ax.set_ylabel('')
    Df_colorbar_ax.set_yticklabels(['Min', 'Max'])

    #
    # Fraction of PCs near reward
    #

    near_reward_kwargs = {
        'method': 'resultant_vector',
        'positions': 'reward',
        'pcs_only': True,
        'threshold': np.pi / 8,
    }
    near_reward_label = 'Fraction of place cells near reward'
    behavior_fn = ra.fraction_licks_in_reward_zone
    behavior_kwargs = {}
    behavior_label = 'Fraction of licks in reward zone'

    plotting.plot_metric(
        pf_close_fraction_ax,
        expt_grps,
        metric_fn=place.centroid_to_position_threshold,
        roi_filters=roi_filters,
        groupby=[['expt', 'X_condition', 'X_day']],
        plotby=['X_condition', 'X_day'],
        plot_abs=False,
        plot_method='line',
        activity_kwargs=near_reward_kwargs,
        rotate_labels=False,
        activity_label=near_reward_label,
        label_every_n=1,
        colors=colors,
        markers=markers,
        markersize=5,
        return_full_dataframes=False,
        linestyles=linestyles)
    # Dashed reference at 1/8, presumably the chance fraction for the pi/8
    # threshold; confirm against centroid_to_position_threshold.
    pf_close_fraction_ax.axhline(1 / 8., linestyle='--', color='k')
    pf_close_fraction_ax.set_title('')
    sns.despine(ax=pf_close_fraction_ax)
    pf_close_fraction_ax.set_xlabel('Day in Condition')
    day_number_only_label(pf_close_fraction_ax)
    label_conditions(pf_close_fraction_ax)
    pf_close_fraction_ax.legend(loc='upper left', fontsize=6)
    pf_close_fraction_ax.set_ylim(0, 0.40)
    pf_close_fraction_ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4])

    # Color points by (group label, condition 'C').
    colorby_list = [(grp.label(), 'C') for grp in expt_grps]
    pf_close_behav_corr_ax.set_xlim(corr_xlim)
    pf_close_behav_corr_ax.set_ylim(corr_ylim)
    plotting.plot_paired_metrics(
        expt_grps,
        first_metric_fn=place.centroid_to_position_threshold,
        second_metric_fn=behavior_fn,
        roi_filters=roi_filters,
        groupby=(('expt', ), ),
        colorby=('expt_grp', 'X_condition'),
        filter_fn=lambda frame: frame['X_condition'] == 'C',
        filter_columns=['X_condition'],
        first_metric_kwargs=near_reward_kwargs,
        second_metric_kwargs=behavior_kwargs,
        first_metric_label=near_reward_label,
        second_metric_label=behavior_label,
        shuffle_colors=False,
        fit_reg=True,
        plot_method='regplot',
        colorby_list=colorby_list,
        colors=colors,
        markers=markers,
        ax=pf_close_behav_corr_ax,
        scatter_kws={'s': 5},
        truncate=False,
        linestyles=linestyles)
    # Limits are re-applied after plotting.
    pf_close_behav_corr_ax.set_xlim(corr_xlim)
    pf_close_behav_corr_ax.set_ylim(corr_ylim)
    pf_close_behav_corr_ax.tick_params(direction='in')
    # Hide the legend produced by the plotting helper, then draw a new one.
    pf_close_behav_corr_ax.get_legend().set_visible(False)
    pf_close_behav_corr_ax.legend(loc='upper left', fontsize=6)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Build the stability / cue-remapping figure comparing RF and GOL tasks.

    Panels:
      * across_ctx_ax: stability between condition-A and condition-B
        sessions of the GOL ('hidden') groups (swarm + bar).
      * pie axes: counts of 'cue', 'position' and 'neither' cells for WT,
        Df and the pooled shuffle.
      * cue_cell_bar_ax: per-mouse ratio of cue to position cells, with the
        shuffle mean drawn as a dashed line.
      * acute_stability_ax (+ inset): CDF and swarm of stability between
        consecutive RF ('acute') day groups, restricted to 'SameAll' pairs.
      * task_compare_ax: RF vs GOL stability (box and line) per mouse.

    The stability metric, labels and axis limits come from the
    ``params_cent_shift_all`` configuration. The other ``params_*`` dicts are
    alternative configurations that are not used unless swapped into
    ``params.update``. The figure is saved to ``save_dir``/``filename``.
    """
    hidden_grps = df.loadExptGrps('GOL')

    WT_expt_grp_hidden = hidden_grps['WT_place_set']
    Df_expt_grp_hidden = hidden_grps['Df_place_set']
    expt_grps_hidden = [WT_expt_grp_hidden, Df_expt_grp_hidden]

    acute_grps = df.loadExptGrps('RF')

    WT_expt_grp_acute = acute_grps['WT_place_set'].unpair()
    Df_expt_grp_acute = acute_grps['Df_place_set'].unpair()
    expt_grps_acute = [WT_expt_grp_acute, Df_expt_grp_acute]

    WT_label = WT_expt_grp_hidden.label()
    Df_label = Df_expt_grp_hidden.label()
    labels = [WT_label, Df_label]

    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.1, right=0.20)
    across_ctx_ax = fig.add_subplot(gs1[0, 0])

    gs2 = plt.GridSpec(3, 1, top=0.9, bottom=0.7, left=0.25, right=0.35)
    wt_pie_ax = fig.add_subplot(gs2[0, 0])
    df_pie_ax = fig.add_subplot(gs2[1, 0])
    shuffle_pie_ax = fig.add_subplot(gs2[2, 0])
    pie_axs = (wt_pie_ax, df_pie_ax, shuffle_pie_ax)

    gs3 = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.4, right=0.5)
    cue_cell_bar_ax = fig.add_subplot(gs3[0, 0])

    gs5 = plt.GridSpec(1, 1, top=0.5, bottom=0.3, left=0.1, right=0.3)
    acute_stability_ax = fig.add_subplot(gs5[0, 0])

    acute_stability_inset_ax = fig.add_axes([0.23, 0.32, 0.05, 0.08])

    gs6 = plt.GridSpec(1, 1, top=0.5, bottom=0.3, left=0.4, right=0.5)
    task_compare_ax = fig.add_subplot(gs6[0, 0])

    #
    # RF Compare
    #
    params = {}
    params['filename'] = filename

    # Alternative metric configurations; only params_cent_shift_all is used.
    params_cent_shift_pc = {}
    params_cent_shift_pc['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_pc['stability_kwargs'] = {
        'activity_filter': 'pc_both',
        'circ_var_pcs': False,
        'units': 'norm',
        'shuffle': True
    }
    params_cent_shift_pc['stability_label'] = \
        'Centroid shift (fraction of belt)'

    params_cent_shift_all = {}
    params_cent_shift_all['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_all['stability_kwargs'] = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'norm',
        'shuffle': True
    }
    params_cent_shift_all['stability_label'] = \
        'Centroid shift (fraction of belt)'
    params_cent_shift_all['stability_inset_ylim'] = (0.15, 0.30)
    params_cent_shift_all['stability_cdf_range'] = (0.15, 0.35)
    params_cent_shift_all['stability_cdf_ticks'] = \
        (0.15, 0.20, 0.25, 0.30, 0.35)
    params_cent_shift_all['stability_compare_ylim'] = (0.15, 0.27)
    params_cent_shift_all['stability_compare_yticks'] = (0.15, 0.20, 0.25)
    params_cent_shift_all['ctx_compare_ylim'] = (0.10, 0.30)
    params_cent_shift_all['ctx_compare_yticks'] = \
        (0.10, 0.15, 0.20, 0.25, 0.30)

    params_cent_shift_cm = {}
    params_cent_shift_cm['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_cm['stability_kwargs'] = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'cm',
        'shuffle': True
    }
    params_cent_shift_cm['stability_label'] = 'Centroid shift (cm)'

    params_pop_vect_corr = {}
    params_pop_vect_corr['stability_fn'] = place.population_vector_correlation
    params_pop_vect_corr['stability_kwargs'] = {
        'method': 'corr',
        'activity_filter': 'pc_both',
        'min_pf_density': 0.05,
        'circ_var_pcs': False
    }
    params_pop_vect_corr['stability_label'] = 'Population vector correlation'

    params_pf_corr = {}
    params_pf_corr['stability_fn'] = place.place_field_correlation
    params_pf_corr['stability_kwargs'] = {'activity_filter': 'pc_either'}
    params_pf_corr['stability_label'] = 'Place field correlation'
    params_pf_corr['stability_inset_ylim'] = (0, 0.50)
    params_pf_corr['stability_cdf_range'] = (0.15, 0.55)
    params_pf_corr['stability_cdf_ticks'] = (0.15, 0.25, 0.35, 0.45, 0.55)
    params_pf_corr['stability_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['stability_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)
    params_pf_corr['hidden_ctx_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['hidden_ctx_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)
    params_pf_corr['ctx_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['ctx_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)

    params.update(params_cent_shift_all)

    day_paired_grps_acute = [
        grp.pair('consecutive groups', groupby=['day_in_df'])
        for grp in expt_grps_acute
    ]
    paired_grps_acute = day_paired_grps_acute
    paired_grps_hidden = [
        grp.pair('consecutive groups', groupby=['X_condition', 'X_day'])
        for grp in expt_grps_hidden
    ]

    filter_fn = lambda df: (df['expt_pair_label'] == 'SameAll')
    filter_columns = ['expt_pair_label']

    acute_stability = plotting.plot_metric(
        acute_stability_ax,
        paired_grps_acute,
        metric_fn=params['stability_fn'],
        groupby=[['expt_pair_label', 'second_expt']],
        plotby=None,
        plot_method='cdf',
        plot_abs=True,
        roi_filters=roi_filters,
        activity_kwargs=params['stability_kwargs'],
        plot_shuffle=True,
        shuffle_plotby=False,
        pool_shuffle=True,
        activity_label=params['stability_label'],
        colors=colors,
        rotate_labels=False,
        filter_fn=filter_fn,
        filter_columns=filter_columns,
        return_full_dataframes=False,
        linestyles=linestyles)
    acute_stability_ax.set_xlabel(params['stability_label'])
    acute_stability_ax.set_title('')
    sns.despine(ax=acute_stability_ax)
    acute_stability_ax.set_xlim(params['stability_cdf_range'])
    acute_stability_ax.set_xticks(params['stability_cdf_ticks'])
    acute_stability_ax.legend(loc='upper left', fontsize=6)

    plotting.plot_metric(acute_stability_inset_ax,
                         paired_grps_acute,
                         metric_fn=params['stability_fn'],
                         groupby=[['second_expt'], ['second_mouseID']],
                         plotby=None,
                         plot_method='swarm',
                         plot_abs=True,
                         roi_filters=roi_filters,
                         activity_kwargs=params['stability_kwargs'],
                         plot_shuffle=True,
                         shuffle_plotby=False,
                         pool_shuffle=True,
                         activity_label=params['stability_label'],
                         colors=colors,
                         rotate_labels=False,
                         filter_fn=filter_fn,
                         filter_columns=filter_columns,
                         linewidth=0.2,
                         edgecolor='gray',
                         plot_shuffle_as_hline=True)
    acute_stability_inset_ax.get_legend().set_visible(False)
    sns.despine(ax=acute_stability_inset_ax)
    acute_stability_inset_ax.set_title('')
    acute_stability_inset_ax.set_ylabel('')
    acute_stability_inset_ax.set_xlabel('')
    acute_stability_inset_ax.tick_params(bottom=False, labelbottom=False)
    acute_stability_inset_ax.set_ylim(params['stability_inset_ylim'])
    acute_stability_inset_ax.set_yticks(params['stability_inset_ylim'])

    # plot_metric is used here only for the dataframes it returns; the
    # throwaway figure is discarded immediately.
    tmp_fig = plt.figure()
    tmp_ax = tmp_fig.add_subplot(111)
    hidden_stability = plotting.plot_metric(
        tmp_ax,
        paired_grps_hidden,
        metric_fn=params['stability_fn'],
        groupby=[['expt_pair_label', 'second_expt']],
        plotby=('expt_pair_label', ),
        plot_method='line',
        plot_abs=True,
        roi_filters=roi_filters,
        activity_kwargs=params['stability_kwargs'],
        plot_shuffle=True,
        shuffle_plotby=False,
        pool_shuffle=True,
        activity_label=params['stability_label'],
        colors=colors,
        rotate_labels=False,
        filter_fn=filter_fn,
        filter_columns=filter_columns,
        return_full_dataframes=False)
    plt.close(tmp_fig)

    wt_acute = acute_stability[WT_label]['dataframe']
    wt_acute_shuffle = acute_stability[WT_label]['shuffle']
    df_acute = acute_stability[Df_label]['dataframe']
    df_acute_shuffle = acute_stability[Df_label]['shuffle']

    wt_hidden = hidden_stability[WT_label]['dataframe']
    wt_hidden_shuffle = hidden_stability[WT_label]['shuffle']
    df_hidden = hidden_stability[Df_label]['dataframe']
    df_hidden_shuffle = hidden_stability[Df_label]['shuffle']

    for dataframe in (wt_acute, wt_acute_shuffle, df_acute, df_acute_shuffle):
        dataframe['task'] = 'RF'

    for dataframe in (wt_hidden, wt_hidden_shuffle, df_hidden,
                      df_hidden_shuffle):
        dataframe['task'] = 'GOL'

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    WT_data = pd.concat([wt_acute, wt_hidden], ignore_index=True)
    Df_data = pd.concat([df_acute, df_hidden], ignore_index=True)

    WT_shuffle = pd.concat([wt_acute_shuffle, wt_hidden_shuffle],
                           ignore_index=True)
    Df_shuffle = pd.concat([df_acute_shuffle, df_hidden_shuffle],
                           ignore_index=True)

    filter_columns = ('expt_pair_label', )
    filter_fn = lambda df: (df['expt_pair_label'] == 'SameAll')

    # Explicit x ordering: RF first, then GOL.
    order_dict = {'RF': 0, 'GOL': 1}
    WT_data['order'] = WT_data['task'].map(order_dict)
    Df_data['order'] = Df_data['task'].map(order_dict)
    line_kwargs = {'markersize': 4}
    plotting.plot_dataframe(task_compare_ax, [WT_data, Df_data],
                            [WT_shuffle, Df_shuffle],
                            labels=labels,
                            activity_label='',
                            groupby=[['task', 'second_mouseID']],
                            plotby=('task', ),
                            plot_method='box_and_line',
                            colors=colors,
                            filter_fn=filter_fn,
                            filter_columns=filter_columns,
                            plot_shuffle=True,
                            shuffle_plotby=False,
                            pool_shuffle=True,
                            orderby='order',
                            notch=False,
                            plot_shuffle_as_hline=True,
                            markers=markers,
                            linestyles=linestyles,
                            line_kwargs=line_kwargs,
                            flierprops={
                                'markersize': 3,
                                'marker': 'o'
                            },
                            whis='range')
    task_compare_ax.set_title('')
    sns.despine(ax=task_compare_ax)
    task_compare_ax.set_ylim(params['stability_compare_ylim'])
    task_compare_ax.set_yticks(params['stability_compare_yticks'])
    task_compare_ax.set_xlabel('')
    task_compare_ax.set_ylabel(params['stability_label'])
    task_compare_ax.legend(loc='upper right', fontsize=6)

    #
    # Stability across transition
    #
    # NOTE(review): the per-mouse key here is 'second_mouse', whereas the
    # other panels use 'second_mouseID'; confirm both columns exist.
    groupby = [['second_expt'], ['second_mouse']]
    filter_fn = lambda df: (df['X_first_condition'] == 'A') \
        & (df['X_second_condition'] == 'B')
    filter_columns = ('X_first_condition', 'X_second_condition')
    plotting.plot_metric(across_ctx_ax,
                         paired_grps_hidden,
                         metric_fn=params['stability_fn'],
                         groupby=groupby,
                         plotby=None,
                         plot_method='swarm',
                         activity_kwargs=params['stability_kwargs'],
                         plot_shuffle=True,
                         shuffle_plotby=False,
                         pool_shuffle=True,
                         colors=colors,
                         activity_label=params['stability_label'],
                         rotate_labels=False,
                         filter_fn=filter_fn,
                         filter_columns=filter_columns,
                         plot_shuffle_as_hline=True,
                         return_full_dataframes=False,
                         plot_bar=True,
                         roi_filters=roi_filters)
    sns.despine(ax=across_ctx_ax)
    across_ctx_ax.set_ylim(0.0, 0.3)
    across_ctx_ax.set_yticks([0.0, 0.1, 0.2, 0.3])
    across_ctx_ax.set_xticklabels([])
    across_ctx_ax.set_xlabel('')
    across_ctx_ax.set_title('')
    across_ctx_ax.get_legend().set_visible(False)
    plotting.stackedText(across_ctx_ax, labels, colors=colors, loc=2, size=10)

    #
    # Cue remapping
    #

    # near_threshold passed to cue_cell_remapping: 5% of a full turn (radians).
    THRESHOLD = 0.05 * 2 * np.pi
    # Cells with cueness > 1 - CUENESS_THRESHOLD are counted as 'cue' cells,
    # and cells with cueness < CUENESS_THRESHOLD as 'position' cells.
    CUENESS_THRESHOLD = 0.33

    def first_cue_position(row):
        """Return the centre of ``row['cue']`` in the first expt as a unit complex.

        The cue's start/stop come from the belt cue table read with
        ``normalized=True``. Their midpoint is scaled by 2*pi and mapped onto
        the unit circle.
        """
        expt = row['first_expt']
        cue = row['cue']
        cues = expt.belt().cues(normalized=True)
        # .loc replaces the .ix indexer removed in pandas 1.0. .item()
        # extracts the scalar and requires exactly one matching cue row.
        first_cue = cues.loc[cues['cue'] == cue]
        pos = ((first_cue['start'] + first_cue['stop']) / 2).item()
        theta = pos * 2 * np.pi
        # Builtin complex replaces np.complex, which was removed in NumPy 1.24.
        return complex(np.cos(theta), np.sin(theta))

    def dotproduct(v1, v2):
        return sum((a * b) for a, b in zip(v1, v2))

    def length(v):
        return math.sqrt(dotproduct(v, v))

    def angle(v1, v2):
        # Round the cosine so float error cannot push it outside acos's domain.
        return math.acos(
            np.round(dotproduct(v1, v2) / (length(v1) * length(v2)), 3))

    def distance_to_first_cue(row):
        """Return the angle (radians) between the second centroid and the first cue."""
        centroid = row['second_centroid']
        pos = row['first_cue_position']
        return angle((pos.real, pos.imag), (centroid.real, centroid.imag))

    # Drop condition-C experiments, then pair consecutive day groups and
    # pair those again across consecutive conditions.
    WT_copy = copy(WT_expt_grp_hidden)
    WT_copy.filterby(lambda df: ~df['X_condition'].str.contains('C'),
                     ['X_condition'])
    WT_paired = WT_copy.pair('consecutive groups',
                             groupby=['X_condition',
                                      'X_day']).pair('consecutive groups',
                                                     groupby=['X_condition'])

    Df_copy = copy(Df_expt_grp_hidden)
    Df_copy.filterby(lambda df: ~df['X_condition'].str.contains('C'),
                     ['X_condition'])
    Df_paired = Df_copy.pair('consecutive groups',
                             groupby=['X_condition',
                                      'X_day']).pair('consecutive groups',
                                                     groupby=['X_condition'])

    WT_df, WT_shuffle_df = place.cue_cell_remapping(
        WT_paired,
        roi_filter=WT_filter,
        near_threshold=THRESHOLD,
        activity_filter='active_both',
        circ_var_pcs=False,
        shuffle=True)
    Df_df, Df_shuffle_df = place.cue_cell_remapping(
        Df_paired,
        roi_filter=Df_filter,
        near_threshold=THRESHOLD,
        activity_filter='active_both',
        circ_var_pcs=False,
        shuffle=True)

    shuffle_df = pd.concat([WT_shuffle_df, Df_shuffle_df], ignore_index=True)

    # One entry per group (WT, Df, shuffle). Each entry holds per-mouse
    # values; cueness_fraction wraps its values in one extra list level.
    cueness_fraction = []
    cue_n, place_n, neither_n = [], [], []

    for grp_df in (WT_df, Df_df, shuffle_df):

        grp_df['first_cue_position'] = grp_df.apply(first_cue_position, axis=1)

        grp_df['second_distance_to_first_cue_position'] = grp_df.apply(
            distance_to_first_cue, axis=1)

        # cueness = d_cue / (value + d_cue), in [0, 1] for non-negative inputs.
        grp_df['cueness'] = grp_df['second_distance_to_first_cue_position'] / \
            (grp_df['value'] + grp_df['second_distance_to_first_cue_position'])

        plotting.prepare_dataframe(grp_df, ['first_mouse'])
        cueness_fraction.append([[]])
        cue_n.append([])
        place_n.append([])
        neither_n.append([])

        for mouse, mouse_df in grp_df.groupby('first_mouse'):
            cue_n[-1].append((mouse_df['cueness'] >
                              (1 - CUENESS_THRESHOLD)).sum())
            place_n[-1].append((mouse_df['cueness'] < CUENESS_THRESHOLD).sum())
            neither_n[-1].append(mouse_df.shape[0] - cue_n[-1][-1] -
                                 place_n[-1][-1])
            # NOTE(review): raises ZeroDivisionError if a mouse has no
            # position cells.
            cueness_fraction[-1][0].append(cue_n[-1][-1] /
                                           float(place_n[-1][-1]))

    cue_labels = labels + ['shuffle']

    plotting.swarm_plot(cue_cell_bar_ax,
                        cueness_fraction[:2],
                        condition_labels=labels,
                        colors=colors,
                        plot_bar=True)
    # Dashed line: mean per-mouse ratio in the pooled shuffle.
    cue_cell_bar_ax.axhline(np.mean(cueness_fraction[-1][0]),
                            ls='--',
                            color='k')
    sns.despine(ax=cue_cell_bar_ax)
    cue_cell_bar_ax.set_ylim(0, 1.5)
    cue_cell_bar_ax.set_yticks([0, 0.5, 1.0, 1.5])
    cue_cell_bar_ax.set_xticklabels([])
    cue_cell_bar_ax.set_xlabel('')
    cue_cell_bar_ax.set_ylabel('Cue-to-position ratio')
    cue_cell_bar_ax.get_legend().set_visible(False)
    plotting.stackedText(cue_cell_bar_ax,
                         labels,
                         colors=colors,
                         loc=2,
                         size=10)

    # Three shades per group (palette indices 6, 4, 2).
    WT_colors = sns.light_palette(WT_color, 7)[:-6:-2]
    Df_colors = sns.light_palette(Df_color, 7)[:-6:-2]
    shuffle_colors = sns.light_palette('k', 7)[:-6:-2]
    pie_colors = (WT_colors, Df_colors, shuffle_colors)
    pie_labels = ['cue', 'position', 'neither']
    orig_size = mpl.rcParams.get('xtick.labelsize')
    mpl.rcParams['xtick.labelsize'] = 5
    try:
        for (grp_ax, grp_label, grp_cue_n, grp_place_n, grp_neither_n,
             p_cs) in zip(pie_axs, cue_labels, cue_n, place_n, neither_n,
                          pie_colors):
            grp_ax.pie([sum(grp_cue_n),
                        sum(grp_place_n),
                        sum(grp_neither_n)],
                       autopct='%1.0f%%',
                       shadow=False,
                       frame=False,
                       labels=pie_labels,
                       colors=p_cs,
                       textprops={'fontsize': 5})
            grp_ax.set_title(grp_label)
            plotting.square_axis(grp_ax)
    finally:
        # Always restore the global rcParam, even if drawing a pie fails.
        mpl.rcParams['xtick.labelsize'] = orig_size

    misc.save_figure(fig, params['filename'], save_dir=save_dir)

    plt.close('all')
def main():
    """Simulate three EnrichmentModel2 variants and plot their parameters.

    The recurrence, offset and variance variants (from ``emt``) are built
    from the fitted WT parameters. They all start from the initial mask and
    positions of a freshly initialized WT model (1000 cells). Each variant
    is run for 8 steps, ``n_runs`` times. Enrichment is then computed per
    variant and shown in its own column of a 4x3 grid, titled 'Stable
    recurrence', 'Shift towards reward' and 'Stable position'. The figure is
    saved to ``save_dir``/``filename``.
    """
    # Pickles must be opened in binary mode (text mode fails on Python 3).
    # The context manager also closes the handle a bare open() would leak.
    with open(params_path, 'rb') as params_file:
        WT_params = pkl.load(params_file)

    WT_model = em.EnrichmentModel2(**WT_params)
    recur_model = emt.EnrichmentModel2_recur(kappa=1,
                                             span=0.8,
                                             mean_recur=0.4,
                                             **WT_params)
    offset_model = emt.EnrichmentModel2_offset(alpha=0.25, **WT_params)
    var_model = emt.EnrichmentModel2_var(kappa=1,
                                         alpha=10,
                                         mean_k=3,
                                         **WT_params)

    WT_model.initialize(n_cells=1000, flat_tol=1e-6)
    recur_model.initialize_like(WT_model)
    offset_model.initialize_like(WT_model)
    var_model.initialize_like(WT_model)
    initial_mask = WT_model.mask
    initial_positions = WT_model.positions

    recur_masks, recur_positions = [], []
    offset_masks, offset_positions = [], []
    var_masks, var_positions = [], []

    n_runs = 100

    color = sns.xkcd_rgb['forest green']

    # Keep the interleaved initialize/run order across the three models so
    # the sequence of model calls is unchanged.
    for _ in range(n_runs):
        recur_model.initialize(initial_mask=initial_mask,
                               initial_positions=initial_positions)
        offset_model.initialize(initial_mask=initial_mask,
                                initial_positions=initial_positions)
        var_model.initialize(initial_mask=initial_mask,
                             initial_positions=initial_positions)

        recur_model.run(8)
        offset_model.run(8)
        var_model.run(8)

        recur_masks.append(recur_model._masks)
        recur_positions.append(recur_model._positions)

        offset_masks.append(offset_model._masks)
        offset_positions.append(offset_model._positions)

        var_masks.append(var_model._masks)
        var_positions.append(var_model._positions)

    recur_enrich = emp.calc_enrichment(recur_positions, recur_masks)
    offset_enrich = emp.calc_enrichment(offset_positions, offset_masks)
    var_enrich = emp.calc_enrichment(var_positions, var_masks)

    fig, axs = plt.subplots(4,
                            3,
                            figsize=(8.5, 11),
                            gridspec_kw={
                                'bottom': 0.3,
                                'wspace': 0.4
                            })

    show_parameters(axs[:, 0], recur_model, recur_enrich, color=color)
    show_parameters(axs[:, 1], offset_model, offset_enrich, color=color)
    show_parameters(axs[:, 2], var_model, var_enrich, color=color)

    # y labels only in the first column. x labels are cleared everywhere
    # except the bottom two rows of the middle column.
    for ax in axs[:, 1:].flat:
        ax.set_ylabel('')
    for ax in axs[:2, :].flat:
        ax.set_xlabel('')
    for ax in axs[:, 0]:
        ax.set_xlabel('')
    for ax in axs[:, 2]:
        ax.set_xlabel('')

    axs[0, 0].set_title('Stable recurrence')
    axs[0, 1].set_title('Shift towards reward')
    axs[0, 2].set_title('Stable position')

    save_figure(fig, filename, save_dir=save_dir)

    # Release the figure once it has been saved.
    plt.close('all')
def main():
    """Plot enrichment for the WT and flat models with parameters swapped.

    A WT enrichment model is built from the pickled WT parameters and a
    flattened copy of it is made. Each model is then simulated ``n_runs``
    times from the same initial cell mask/positions, either unchanged
    ("No swap") or with one parameter group interpolated from the other
    model (``recur``, ``shift_k`` or ``shift_b``).

    Layout (5x2 grid, bottom row hidden): column 0 shows enrichment for
    WT and flat models, column 1 the final distributions; each row is one
    swap condition.
    """

    fig, axs = plt.subplots(
        5, 2, figsize=(8.5, 11), gridspec_kw={
            'wspace': 0.4, 'hspace': 0.35, 'right': 0.75, 'top': 0.9,
            'bottom': 0.1})

    #
    # Run the model
    #

    n_cells = 1000
    n_runs = 100
    tol = 1e-4

    # Pickles must be read in binary mode; `with` closes the handle.
    with open(WT_params_path, 'rb') as params_file:
        WT_params = pkl.load(params_file)

    model_cls = em.EnrichmentModel2

    WT_model = model_cls(**WT_params)
    flat_model = WT_model.copy()
    flat_model.flatten()

    WT_model.initialize(n_cells=n_cells, flat_tol=tol)
    flat_model.initialize_like(WT_model)
    initial_mask = WT_model.mask
    initial_positions = WT_model.positions

    def run_model(model1, model2=None, initial_positions=None,
                  initial_mask=None, n_runs=100, **interp_kwargs):
        """Simulate a copy of ``model1`` ``n_runs`` times for 8 iterations.

        If ``model2`` is given, the copy is first interpolated towards it
        with ``interp_kwargs``. Returns ``(positions, masks)``, one entry
        per run, taken from the model's ``_positions`` / ``_masks``.
        """
        masks, positions = [], []

        # Work on a copy so the caller's model is never interpolated.
        model = model1.copy()

        if model2 is not None:
            model.interpolate(model2, **interp_kwargs)

        for _ in range(n_runs):
            model.initialize(
                initial_mask=initial_mask, initial_positions=initial_positions)

            model.run(8)

            masks.append(model._masks)
            positions.append(model._positions)

        return positions, masks

    # One entry per plotted row: (right-hand row label, interpolation
    # kwargs). An empty kwargs dict means the models are run unswapped.
    conditions = [
        ('No swap', {}),
        ('Swap recurrence', {'recur': 1}),
        ('Swap shift variance', {'shift_k': 1}),
        ('Swap shift offset', {'shift_b': 1}),
    ]

    enrichments, final_dists = [], []
    for _, interp_kwargs in conditions:
        swap = bool(interp_kwargs)
        WT_pos, WT_masks = run_model(
            WT_model, flat_model if swap else None,
            initial_positions=initial_positions, initial_mask=initial_mask,
            n_runs=n_runs, **interp_kwargs)
        flat_pos, flat_masks = run_model(
            flat_model, WT_model if swap else None,
            initial_positions=initial_positions, initial_mask=initial_mask,
            n_runs=n_runs, **interp_kwargs)

        enrichments.append((
            emp.calc_enrichment(WT_pos, WT_masks),
            emp.calc_enrichment(flat_pos, flat_masks)))
        final_dists.append([
            emp.calc_final_distributions(WT_pos, WT_masks),
            emp.calc_final_distributions(flat_pos, flat_masks)])

    #
    # Distance to reward
    #

    # Only this row keeps the y-label (presumably set by plot_enrichment),
    # so it is shown once for the whole column.
    ylabel_row = 1
    for row, ((label, _), (WT_enrich, flat_enrich)) in enumerate(
            zip(conditions, enrichments)):
        emp.plot_enrichment(
            axs[row, 0], WT_enrich, WT_color, '', rad=False)
        emp.plot_enrichment(
            axs[row, 0], flat_enrich, flat_color, '', rad=False)
        plotting.right_label(axs[row, 1], label)
        if row != ylabel_row:
            axs[row, 0].set_ylabel('')

    plotting.stackedText(
        axs[0, 0], [WT_label, flat_label], colors=colors, loc=2, size=10)

    #
    # Final distribution
    #

    for row, dists in enumerate(final_dists):
        emp.plot_final_distributions(
            axs[row, 1], dists, colors, labels=None, title='', rad=False)

    # Only the last visible row (row 3) keeps its x-labels.
    for ax in axs[:3, :].flat:
        ax.set_xlabel('')
    for ax in axs[4]:
        ax.set_visible(False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Plot two-session place-field shifts and mean shift variance by genotype.

    Only row 0 of a 3x3 grid is used (rows 1-2 are hidden):

    * ``axs[0, 0]`` / ``axs[0, 1]``: WT / Df shift between the first and
      third session centroids vs. initial distance from reward. Cells with
      no second-session centroid are overlaid as crosses.
    * ``axs[0, 2]``: grouped bars, per genotype, of the reciprocal of the
      mean fitted ``k`` for each bootstrap, for several session pairings.
    """

    def load_params(path):
        # Pickles must be read in binary mode; `with` closes the handle.
        with open(path, 'rb') as params_file:
            return pkl.load(params_file)

    def wrap_to_pi(shifts):
        # Wrap circular differences into [-pi, pi).
        shifts[shifts < -np.pi] += 2 * np.pi
        shifts[shifts >= np.pi] -= 2 * np.pi
        return shifts

    def two_session_shifts(data, prev_imaged):
        # Rows with both a first and a third centroid, plus their wrapped
        # third - first shift.
        shift_data = emd.tripled_activity_centroid_distance_to_reward(
            data, prev_imaged=prev_imaged)
        shift_data = shift_data.dropna(subset=['first', 'third'])
        shifts = wrap_to_pi(shift_data['third'] - shift_data['first'])
        return shift_data, shifts

    def plot_two_session_stability(ax, data, color, marker, npc_color,
                                   title):
        # One genotype's scatter of two-session shift vs. initial position.
        shift_data, shifts = two_session_shifts(data, prev_imaged=False)
        shift_data_prev, shifts_prev = two_session_shifts(
            data, prev_imaged=True)
        # Cells with no centroid in the middle session (presumably not a
        # place cell there); plotted first, as larger crosses.
        npc = np.isnan(shift_data_prev['second'])

        sns.regplot(shift_data_prev['first'][npc],
                    shifts_prev[npc],
                    ax=ax,
                    color=npc_color,
                    fit_reg=False,
                    scatter_kws={'s': 7},
                    marker='x')
        sns.regplot(shift_data['first'],
                    shifts,
                    ax=ax,
                    color=color,
                    fit_reg=False,
                    scatter_kws={'s': 3},
                    marker=marker)
        ax.axvline(ls='--', color='0.4', lw=0.5)
        ax.axhline(ls='--', color='0.4', lw=0.5)
        # Line of shifts that would move a field exactly onto the reward.
        ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
        ax.tick_params(length=3, pad=1, top=False)
        ax.set_xlabel('Initial distance from reward\n(fraction of belt)')
        ax.set_ylabel(r'Two-session $\Delta$ position' +
                      '\n(fraction of belt)')
        # Data are in radians; ticks are labelled as fractions of the belt.
        ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
        tick_labels = ['-0.50', '-0.25', '0', '0.25', '0.50']
        ax.set_xlim(-np.pi, np.pi)
        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels)
        ax.set_ylim(-np.pi, np.pi)
        ax.set_yticks(ticks)
        ax.set_yticklabels(tick_labels)
        ax.set_title(title)

    WT_params = load_params(WT_params_path)
    Df_params = load_params(Df_params_path)

    _, WT_data = emd.load_data('wt', root=data_root_path)
    _, Df_data = emd.load_data('df', root=data_root_path)

    fig, axs = plt.subplots(3,
                            3,
                            figsize=(11, 8.5),
                            gridspec_kw={'wspace': 0.4})

    for ax in axs[1:, :].flat:
        ax.set_visible(False)

    #
    # 2 session stability
    #

    plot_two_session_stability(
        axs[0, 0], WT_data, WT_color, WT_marker, 'm', WT_label)
    plot_two_session_stability(
        axs[0, 1], Df_data, Df_color, Df_marker, 'c', Df_label)

    #
    # Place field stability k
    #

    x_vals = np.linspace(-np.pi, np.pi, 1000)
    # Keys under params['position_stability'], in bar-plot order.
    pairings = ['all_pairs', 'skip_one', 'skip_one_npc', 'two_iter']

    def mean_variance_dfs(params, label):
        # One DataFrame per pairing: 1 / mean(k) for every bootstrap fit.
        stability = params['position_stability']
        # Every pairing is evaluated on the 'all_pairs' spline basis.
        N = splines.CyclicSpline(
            stability['all_pairs']['knots']).design_matrix(x_vals)
        return [
            pd.DataFrame({
                'value': [
                    1. / splines.get_k(theta_k, N).mean()
                    for theta_k in stability[key]['boots_theta_k']
                ],
                'genotype': label
            }) for key in pairings
        ]

    order_dict = {WT_label: 0, Df_label: 1}
    pairing_dfs = []
    for WT_df, Df_df in zip(mean_variance_dfs(WT_params, WT_label),
                            mean_variance_dfs(Df_params, Df_label)):
        pairing_df = pd.concat([WT_df, Df_df])
        pairing_df['order'] = pairing_df['genotype'].apply(order_dict.get)
        pairing_dfs.append(pairing_df)

    plotting.plot_dataframe(
        axs[0, 2], pairing_dfs,
        labels=[
            '1 ses elapsed', '2 ses elapsed (all)', '2 ses elapsed (nPC)',
            '2 iterations (model)'
        ],
        activity_label='Mean shift variance',
        colors=sns.color_palette('deep'),
        plotby=['genotype'],
        orderby='order',
        plot_method='grouped_bar',
        plot_shuffle=False,
        shuffle_plotby=False,
        pool_shuffle=False,
        agg_fn=np.mean,
        markers=None,
        label_groupby=True,
        z_score=False,
        normalize=False,
        error_bars='std')
    axs[0, 2].set_xlabel('')
    axs[0, 2].set_ylim(0, 1.5)
    # Tick labels are in (fraction of belt)^2; values are in radians^2.
    y_ticks = np.array(['0', '0.01', '0.02', '0.03'])
    axs[0, 2].set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    axs[0, 2].set_yticklabels(y_ticks)
    axs[0, 2].tick_params(top=False, bottom=False)
    axs[0, 2].set_title('')

    save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Plot example tuning maps and the WT recurrence/stability model fits.

    The top row shows spatial-tuning overlays for two sessions of one
    example field of view (jz135). A belt-position colour bar between them
    marks the reward window with dotted lines.

    The lower 2x2 grid shows four panels against initial distance from
    reward:

    * recurrence probability, as the spline fit with a bootstrap 5-95% band;
    * centroid shift (data and fitted offset);
    * fitted shift offset ``b`` with a bootstrap band;
    * fitted shift variance ``1/k`` with a bootstrap band.
    """

    def format_distance_xaxis(ax):
        # Positions are in radians; ticks are labelled as belt fractions.
        ax.set_xlim(-np.pi, np.pi)
        ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
        ax.set_xticklabels(['-0.50', '-0.25', '0', '0.25', '0.50'])

    _, data = emd.load_data('wt',
                            session_filter='C',
                            root=os.path.join(df.data_path,
                                              'enrichment_model'))
    expts = lab.ExperimentSet(os.path.join(df.metadata_path,
                                           'expt_metadata.xml'),
                              behaviorDataPath=os.path.join(
                                  df.data_path, 'behavior'),
                              dataPath=os.path.join(df.data_path, 'imaging'))

    # Pickles must be read in binary mode; `with` closes the handle.
    with open(params_path, 'rb') as params_file:
        params = pickle.load(params_file)

    fig = plt.figure(figsize=(8.5, 11))
    gs1 = plt.GridSpec(1,
                       2,
                       left=0.1,
                       bottom=0.65,
                       right=0.9,
                       top=0.9,
                       wspace=0.05)
    fov1_ax = fig.add_subplot(gs1[0, 0])
    fov2_ax = fig.add_subplot(gs1[0, 1])
    cmap_ax = fig.add_axes([0.49, 0.65, 0.02, 0.25])
    gs2 = plt.GridSpec(2,
                       2,
                       left=0.1,
                       bottom=0.3,
                       right=0.5,
                       top=0.6,
                       wspace=0.5,
                       hspace=0.5)
    recur_ax = fig.add_subplot(gs2[0, 0])
    shift_ax = fig.add_subplot(gs2[0, 1])
    shift_compare_ax = fig.add_subplot(gs2[1, 0])
    var_compare_ax = fig.add_subplot(gs2[1, 1])

    #
    # Tuning maps
    #

    e1 = expts.grabExpt('jz135', '2015-10-12-14h33m47s')
    e2 = expts.grabExpt('jz135', '2015-10-12-15h34m38s')

    cmap = mpl.colors.ListedColormap(sns.color_palette("husl", 256))

    for ax, expt in ((fov1_ax, e1), (fov2_ax, e2)):
        place.plot_spatial_tuning_overlay(ax,
                                          lab.classes.pcExperimentGroup(
                                              [expt], imaging_label='soma'),
                                          labels_visible=False,
                                          alpha=0.9,
                                          lw=0.1,
                                          cmap=cmap)
        plot_ROI_outlines(ax,
                          expt,
                          channel='Ch2',
                          label='soma',
                          roi_filter=None,
                          ls='-',
                          color='k',
                          lw=0.1)
        # Add a 50-um scale bar
        plotting.add_scalebar(
            ax=ax,
            matchx=False,
            matchy=False,
            sizey=0,
            sizex=50 / expt.imagingParameters()['micronsPerPixel']['XAxis'],
            bar_color='w',
            bar_thickness=3)

    fov1_ax.set_title('Session 1')
    fov2_ax.set_title('Session 2')

    gradient = np.linspace(0, 1, 256)
    gradient = np.vstack((gradient, gradient)).T
    cmap_ax.imshow(gradient, aspect='auto', cmap=cmap)
    sns.despine(ax=cmap_ax, top=True, left=True, right=True, bottom=True)
    cmap_ax.tick_params(left=False,
                        labelleft=False,
                        bottom=False,
                        labelbottom=False)
    cmap_ax.set_ylabel('belt position')

    # Figure out the reward window width (as a fraction of track length),
    # averaged over the two example sessions.
    reward_poss, windows = [], []
    for expt in [e1, e2]:
        reward_poss.append(expt.rewardPositions(units='normalized')[0])
        track_length = expt[0].behaviorData()['trackLength']
        window = float(expt.get('operantSpatialWindow'))
        windows.append(window / track_length)
    reward_pos = np.mean(reward_poss)
    window = np.mean(windows)

    # Add reward zone (drawn in axes coordinates, so 0-1 spans the bar)
    cmap_ax.plot([0, 1], [reward_pos, reward_pos],
                 transform=cmap_ax.transAxes,
                 color='k',
                 ls=':')
    cmap_ax.plot([0, 1], [reward_pos + window, reward_pos + window],
                 transform=cmap_ax.transAxes,
                 color='k',
                 ls=':')
    cmap_ax.set_ylim(0, 256)

    #
    # Recurrence by position
    #

    recur_x_vals = np.linspace(-np.pi, np.pi, 1000)

    recur_data = emd.recurrence_by_position(data, method='cv')
    recur_knots = np.linspace(-np.pi, np.pi,
                              params['position_recurrence']['n_knots'])
    recur_splines = splines.CyclicSpline(recur_knots)
    recur_n = recur_splines.design_matrix(recur_x_vals)

    recur_fit = splines.prob(params['position_recurrence']['theta'], recur_n)

    recur_boots_fits = [
        splines.prob(boot, recur_n)
        for boot in params['position_recurrence']['boots_theta']
    ]
    recur_ci_up_fit = np.percentile(recur_boots_fits, 95, axis=0)
    recur_ci_low_fit = np.percentile(recur_boots_fits, 5, axis=0)

    recur_ax.plot(recur_x_vals, recur_fit, color=WT_color)
    recur_ax.fill_between(recur_x_vals,
                          recur_ci_low_fit,
                          recur_ci_up_fit,
                          facecolor=WT_color,
                          alpha=0.5)
    sns.regplot(recur_data[:, 0],
                recur_data[:, 1],
                ax=recur_ax,
                color=WT_color,
                y_jitter=0.2,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=WT_marker)
    recur_ax.axvline(ls='--', color='0.4', lw=0.5)
    format_distance_xaxis(recur_ax)
    recur_ax.set_ylim(-0.3, 1.3)
    recur_ax.set_yticks([0, 0.5, 1])
    recur_ax.tick_params(length=3, pad=1, top=False)
    recur_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    recur_ax.set_ylabel('Place cell recurrence probability')
    recur_ax.set_title('')
    # Right-hand axis labels the 0/1 outcomes of the jittered scatter.
    recur_ax_2 = recur_ax.twinx()
    recur_ax_2.tick_params(length=3, pad=1, top=False)
    recur_ax_2.set_ylim(-0.3, 1.3)
    recur_ax_2.set_yticks([0, 1])
    recur_ax_2.set_yticklabels(['non-recur', 'recur'])

    #
    # Place field stability
    #

    # Spline fits shared by this panel and the two comparison panels below.
    shift_params = params['position_stability']['all_pairs']
    shift_x_vals = np.linspace(-np.pi, np.pi, 1000)

    shift_spline = splines.CyclicSpline(shift_params['knots'])
    shift_n = shift_spline.design_matrix(shift_x_vals)
    shift_b_fit = np.dot(shift_n, shift_params['theta_b'])
    shift_k_fit = splines.get_k(shift_params['theta_k'], shift_n)
    shift_fit_var = 1. / shift_k_fit

    shift_data = emd.paired_activity_centroid_distance_to_reward(data)
    shift_data = shift_data.dropna()
    shifts = shift_data['second'] - shift_data['first']
    # Wrap circular differences into [-pi, pi).
    shifts[shifts < -np.pi] += 2 * np.pi
    shifts[shifts >= np.pi] -= 2 * np.pi

    shift_ax.plot(shift_x_vals, shift_b_fit, color=WT_color)
    # NOTE(review): the band is offset +/- variance (1/k), not +/- std;
    # units differ from the offset. Confirm this is intended.
    shift_ax.fill_between(shift_x_vals,
                          shift_b_fit - shift_fit_var,
                          shift_b_fit + shift_fit_var,
                          facecolor=WT_color,
                          alpha=0.5)
    sns.regplot(shift_data['first'],
                shifts,
                ax=shift_ax,
                color=WT_color,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=WT_marker)

    shift_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_ax.axhline(ls='--', color='0.4', lw=0.5)
    # Line of shifts that would move a field exactly onto the reward.
    shift_ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
    shift_ax.tick_params(length=3, pad=1, top=False)
    shift_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    shift_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')
    format_distance_xaxis(shift_ax)
    shift_ax.set_ylim(-np.pi, np.pi)
    shift_ax.set_yticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    shift_ax.set_yticklabels(['-0.50', '-0.25', '0', '0.25', '0.50'])
    shift_ax.set_title('')

    #
    # Stability by distance to reward
    #

    shift_boots_b_fit = [
        np.dot(shift_n, boot) for boot in shift_params['boots_theta_b']
    ]
    shift_b_ci_up_fit = np.percentile(shift_boots_b_fit, 95, axis=0)
    shift_b_ci_low_fit = np.percentile(shift_boots_b_fit, 5, axis=0)

    shift_compare_ax.plot(shift_x_vals, shift_b_fit, color=WT_color)
    shift_compare_ax.fill_between(shift_x_vals,
                                  shift_b_ci_low_fit,
                                  shift_b_ci_up_fit,
                                  facecolor=WT_color,
                                  alpha=0.5)

    shift_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.axhline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.tick_params(length=3, pad=1, top=False)
    format_distance_xaxis(shift_compare_ax)
    shift_compare_ax.set_ylim(-0.10 * 2 * np.pi, 0.10 * 2 * np.pi)
    # Tick labels in belt fractions, placed at the equivalent radians.
    y_ticks = np.array(['-0.10', '-0.05', '0', '0.05', '0.10'])
    shift_compare_ax.set_yticks(y_ticks.astype('float') * 2 * np.pi)
    shift_compare_ax.set_yticklabels(y_ticks)
    shift_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    shift_compare_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')

    shift_boots_k_fit = [
        splines.get_k(boot, shift_n) for boot in shift_params['boots_theta_k']
    ]
    shift_k_ci_up_fit = np.percentile(shift_boots_k_fit, 95, axis=0)
    shift_k_ci_low_fit = np.percentile(shift_boots_k_fit, 5, axis=0)

    var_compare_ax.plot(shift_x_vals, 1. / shift_k_fit, color=WT_color)
    var_compare_ax.fill_between(shift_x_vals,
                                1. / shift_k_ci_low_fit,
                                1. / shift_k_ci_up_fit,
                                facecolor=WT_color,
                                alpha=0.5)

    var_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    var_compare_ax.tick_params(length=3, pad=1, top=False)
    format_distance_xaxis(var_compare_ax)
    # Tick labels in (belt fraction)^2, placed at the equivalent radians^2.
    y_ticks = np.array(['0.005', '0.010', '0.015', '0.020'])
    var_compare_ax.set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    var_compare_ax.set_yticklabels(y_ticks)
    var_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    var_compare_ax.set_ylabel(r'$\Delta$ position variance')

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Build the hidden-reward behavior figure and save it.

    Panels:
      * example licktograms (``nPositionBins=20``) for one WT mouse (jz101)
        and one Df mouse (jz106) on days 1 and 3 of conditions I-III;
      * fraction of licks in the reward zone by condition and day
        (``reward_zone_ax``);
      * the same metric by session within the day, one small panel per
        condition. Rows whose ``X_session`` is ``'1'`` are filtered out and
        the x ticks are labelled 'first'/'last'.

    Two layouts are available via the local ``HORIZONTAL`` switch. Both
    create every axis that is plotted into further down.

    Relies on names defined elsewhere in this file: ``MALES_ONLY``,
    ``WT_color``, ``Df_color``, ``colors``, ``linestyles``, ``markers``,
    ``filename``, ``save_dir``, ``right_label``, ``day_number_only_label``,
    ``label_conditions`` and ``plot_metric``.
    """
    all_grps = df.loadExptGrps('GOL')
    expts = lab.ExperimentSet(
        os.path.join(df.metadata_path, 'expt_metadata.xml'),
        behaviorDataPath=os.path.join(df.data_path, 'behavior'),
        dataPath=os.path.join(df.data_path, 'imaging'))

    WT_expt_grp = all_grps['WT_hidden_behavior_set']
    Df_expt_grp = all_grps['Df_hidden_behavior_set']
    expt_grps = [WT_expt_grp, Df_expt_grp]

    if MALES_ONLY:
        for expt_grp in expt_grps:
            expt_grp.filter(lambda expt: expt.parent.get('sex') == 'M')
    labels = [expt_grp.label() for expt_grp in expt_grps]

    # Labels for the six example days, shared by both layouts.
    day_labels = [
        r'Condition $\mathrm{I}$' + '\nDay 1',
        r'Condition $\mathrm{I}$' + '\nDay 3',
        r'Condition $\mathrm{II}$' + '\nDay 1',
        r'Condition $\mathrm{II}$' + '\nDay 3',
        r'Condition $\mathrm{III}$' + '\nDay 1',
        r'Condition $\mathrm{III}$' + '\nDay 3'
    ]
    n_example_days = len(day_labels)

    fig = plt.figure(figsize=(8.5, 11))

    HORIZONTAL = False

    if HORIZONTAL:
        # Example days run left to right: WT on row 0, Df on row 1.
        gs1 = plt.GridSpec(8, 6)
        wt_lick_axs = [fig.add_subplot(gs1[0, col])
                       for col in range(n_example_days)]
        df_lick_axs = [fig.add_subplot(gs1[1, col])
                       for col in range(n_example_days)]

        gs2 = plt.GridSpec(4, 2, hspace=0.5, wspace=0.2)
        reward_zone_ax = fig.add_subplot(gs2[1, 0])
        # The session-in-day panels below need their axes in this layout too;
        # without them this branch fails with a NameError. Place them in the
        # otherwise empty right half of the reward-zone row.
        gs3 = plt.GridSpec(4, 6, hspace=0.5, wspace=0.2)
        fraction_licks_by_session_A_ax = fig.add_subplot(gs3[1, 3])
        fraction_licks_by_session_B_ax = fig.add_subplot(gs3[1, 4])
        fraction_licks_by_session_C_ax = fig.add_subplot(gs3[1, 5])
    else:
        # Example days run top to bottom: WT in column 0, Df in column 1.
        gs1 = plt.GridSpec(10, 6)
        wt_lick_axs = [fig.add_subplot(gs1[row, 0])
                       for row in range(n_example_days)]
        df_lick_axs = [fig.add_subplot(gs1[row, 1])
                       for row in range(n_example_days)]

        gs2 = plt.GridSpec(10, 1, hspace=0.5, wspace=0.8, left=0.47, right=0.9)
        reward_zone_ax = fig.add_subplot(gs2[0:4, :])
        gs3 = plt.GridSpec(10, 3, hspace=0.5, wspace=0.1, left=0.47, right=0.9)
        fraction_licks_by_session_A_ax = fig.add_subplot(gs3[5:7, 0])
        fraction_licks_by_session_B_ax = fig.add_subplot(gs3[5:7, 1])
        fraction_licks_by_session_C_ax = fig.add_subplot(gs3[5:7, 2])

    #
    # Lick plots
    #

    wt_lick_expts = [
        expts.grabExpt('jz101', '2014-11-06-23h37m54s'),
        expts.grabExpt('jz101', '2014-11-08-22h53m27s'),
        expts.grabExpt('jz101', '2014-11-09-23h06m56s'),
        expts.grabExpt('jz101', '2014-11-11-23h13m16s'),
        expts.grabExpt('jz101', '2014-11-12-19h29m41s'),
        expts.grabExpt('jz101', '2014-11-14-19h59m09s')
    ]
    df_lick_expts = [
        expts.grabExpt('jz106', '2014-12-11-17h06m49s'),
        expts.grabExpt('jz106', '2014-12-13-19h00m01s'),
        expts.grabExpt('jz106', '2014-12-14-17h17m17s'),
        expts.grabExpt('jz106', '2014-12-16-17h43m05s'),
        expts.grabExpt('jz106', '2014-12-17-17h57m51s'),
        expts.grabExpt('jz106', '2014-12-19-17h13m52s')
    ]

    shade_color = sns.xkcd_rgb['light green']
    for lick_axs, lick_expts, lick_color in (
            (wt_lick_axs, wt_lick_expts, WT_color),
            (df_lick_axs, df_lick_expts, Df_color)):
        for ax, expt in zip(lick_axs, lick_expts):
            expt.licktogram(ax=ax,
                            plot_belt=False,
                            nPositionBins=20,
                            color=lick_color,
                            linewidth=0,
                            shade_reward=True,
                            shade_color=shade_color)

    for ax in wt_lick_axs + df_lick_axs:
        ax.set_ylim(0, 0.6)
        ax.set_yticks([0, 0.3, 0.6])
        ax.set_xticks([0, 0.5, 1])
        ax.set_xticklabels(['0.0', '0.5', '1.0'])
        sns.despine(ax=ax)
        ax.set_title('')

    if HORIZONTAL:
        for ax in wt_lick_axs:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')

        for ax in df_lick_axs[1:]:
            ax.set_xlabel('')

        for ax, label in zip(wt_lick_axs, day_labels):
            ax.set_title(label)

        for ax in it.chain(wt_lick_axs[1:], df_lick_axs[1:]):
            ax.set_ylabel('')
            sns.despine(ax=ax, left=True, top=True, right=True)

        for ax in wt_lick_axs + df_lick_axs:
            ax.spines['bottom'].set_linewidth(0.5)

        wt_lick_axs[0].tick_params(labelbottom=False)
        for ax in wt_lick_axs[1:]:
            ax.tick_params(labelleft=False, left=False, labelbottom=False)
        for ax in df_lick_axs[1:]:
            ax.tick_params(labelleft=False, left=False)

        right_label(wt_lick_axs[-1], labels[0])
        right_label(df_lick_axs[-1], labels[1])

        df_lick_axs[0].set_yticks([0, 0.6])
        wt_lick_axs[0].set_yticks([0, 0.6])
        df_lick_axs[0].set_ylabel('Fraction of licks')
        wt_lick_axs[0].set_ylabel('')
        df_lick_axs[0].spines['left'].set_linewidth(0.5)
        wt_lick_axs[0].spines['left'].set_linewidth(0.5)
    else:
        for ax in wt_lick_axs + df_lick_axs:
            ax.spines['bottom'].set_linewidth(0.5)
        for ax in wt_lick_axs + df_lick_axs[1:]:
            sns.despine(ax=ax, left=True, top=True, right=True)
        # The first Df panel keeps its right spine; it carries the shared
        # y axis for the whole column (ticks moved to the right below).
        sns.despine(ax=df_lick_axs[0], left=True, top=True, right=False)

        for ax in wt_lick_axs[:-1] + df_lick_axs[:-1]:
            ax.set_xlabel('')

        for ax in df_lick_axs:
            ax.set_ylabel('')

        # Day labels sit to the left of the WT column as horizontal text.
        for ax, label in zip(wt_lick_axs, day_labels):
            ax.set_ylabel(label,
                          rotation='horizontal',
                          ha='right',
                          multialignment='center',
                          labelpad=3,
                          va='center')

        for ax in wt_lick_axs[:-1] + df_lick_axs[:-1]:
            ax.tick_params(labelleft=False, left=False, labelbottom=False)

        for ax in (wt_lick_axs[-1], df_lick_axs[-1]):
            ax.tick_params(labelleft=False, left=False)

        wt_lick_axs[0].set_title(labels[0])
        df_lick_axs[0].set_title(labels[1])

        df_lick_axs[0].yaxis.tick_right()
        df_lick_axs[0].yaxis.set_label_position("right")
        df_lick_axs[0].set_yticks([0, 0.6])
        df_lick_axs[0].tick_params(axis='y', length=2, pad=2, direction='in')
        df_lick_axs[0].set_ylabel('Fraction of licks')
        df_lick_axs[0].spines['right'].set_linewidth(0.5)

    #
    # Fraction of licks in reward zone, by condition and day
    #

    behavior_fn = ra.fraction_licks_in_reward_zone
    behavior_kwargs = {}
    activity_label = 'Fraction of licks in reward zone'

    plot_metric(reward_zone_ax,
                expt_grps,
                metric_fn=behavior_fn,
                activity_kwargs=behavior_kwargs,
                groupby=[['expt'], ['mouseID', 'X_condition', 'X_day']],
                plotby=['X_condition', 'X_day'],
                plot_method='line',
                activity_label=activity_label,
                colors=colors,
                linestyles=linestyles,
                label_every_n=1,
                label_groupby=False,
                markers=markers,
                markersize=5,
                rotate_labels=False,
                filter_fn=None,
                filter_columns=None,
                return_full_dataframes=False)
    reward_zone_ax.set_yticks([0, .1, .2, .3, .4])
    sns.despine(ax=reward_zone_ax)
    reward_zone_ax.set_xlabel('Day in Condition')
    reward_zone_ax.set_title('')
    day_number_only_label(reward_zone_ax)
    label_conditions(reward_zone_ax)
    reward_zone_ax.legend(loc='lower left', fontsize=8)

    #
    # Fraction of licks in reward zone, by session within the day
    #

    def _session_filter(condition):
        """Return a row filter keeping one condition, excluding session '1'."""
        # A factory (not a lambda in the loop) so each filter binds its own
        # condition rather than the loop variable's final value.
        return lambda frame: ((frame['X_session'] != '1') &
                              (frame['X_condition'] == condition))

    groupby = [['expt']]
    plotby = ['X_condition', 'X_session']
    line_kwargs = {'markersize': 4}

    # (axes, condition key, roman-numeral label, x-axis label)
    session_panels = [
        (fraction_licks_by_session_A_ax, 'A', r'$\mathrm{I}$', ''),
        (fraction_licks_by_session_B_ax, 'B', r'$\mathrm{II}$',
         'Session in day'),
        (fraction_licks_by_session_C_ax, 'C', r'$\mathrm{III}$', ''),
    ]
    for panel_idx, (ax, condition, numeral, xlabel) in enumerate(
            session_panels):
        # Only the first panel keeps its y axis and legend; the others share
        # its scale and hide both.
        is_first = panel_idx == 0
        plot_metric(ax,
                    expt_grps,
                    metric_fn=behavior_fn,
                    activity_kwargs=behavior_kwargs,
                    groupby=groupby,
                    plotby=plotby,
                    plot_method='box_and_line',
                    activity_label=activity_label,
                    colors=colors,
                    notch=False,
                    label_every_n=1,
                    label_groupby=False,
                    markers=markers,
                    rotate_labels=False,
                    line_kwargs=line_kwargs,
                    linestyles=linestyles,
                    filter_fn=_session_filter(condition),
                    filter_columns=['X_session', 'X_condition'],
                    flierprops={
                        'markersize': 2,
                        'marker': 'o'
                    },
                    box_width=0.4,
                    box_spacing=0.2,
                    return_full_dataframes=False,
                    whis='range')
        sns.despine(ax=ax, top=True, right=True, left=not is_first)
        if not is_first:
            ax.tick_params(left=False, labelleft=False)
        ax.set_xticklabels(['first', 'last'])
        ax.set_xlabel(xlabel)
        if not is_first:
            ax.set_ylabel('')
        ax.set_ylim(-0.02, 0.6)
        if is_first:
            ax.set_yticks([0, 0.2, 0.4, 0.6])
        ax.set_title('')
        if is_first:
            ax.legend(loc='upper left', fontsize=6)
        else:
            ax.get_legend().set_visible(False)
        ax.text(0.5,
                .95,
                numeral,
                ha='center',
                va='center',
                transform=ax.transAxes,
                fontsize=12)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Plot place-cell recurrence and stability for the WT/Df place sets.

    The upper row of panels covers recurrence probability:
      * a CDF with a small per-mouse swarm inset;
      * a session-to-session ('S-S') vs. day-to-day ('D-D') comparison;
      * a regression against the selected behavior metric.

    The lower row shows the same four panels for the selected stability
    metric.

    Which behavior/stability metric is used is chosen by merging preset
    dicts into ``params`` in the "Select parameters" step. The figure is
    written with ``misc.save_figure`` and all figures are then closed.

    Relies on names defined elsewhere in this file: ``MALES_ONLY``,
    ``circ_var_pcs``, ``roi_filters``, ``colors``, ``markers``,
    ``linestyles``, ``filename`` and ``save_dir``.
    """
    all_grps = df.loadExptGrps('GOL')
    expt_grps = [all_grps['WT_place_set'], all_grps['Df_place_set']]
    if MALES_ONLY:
        for grp in expt_grps:
            grp.filter(lambda expt: expt.parent.get('sex') == 'M')

    fig = plt.figure(figsize=(8.5, 11))

    def add_panel_row(bottom, inset_bottom):
        """Add one row of axes: CDF, inset, day/session comparison, regression."""
        return (fig.add_axes([0.1, bottom, 0.22, 0.20]),
                fig.add_axes([0.26, inset_bottom, 0.05, 0.08]),
                fig.add_axes([0.41, bottom, 0.15, 0.20]),
                fig.add_axes([0.65, bottom, 0.27, 0.20]))

    (recur_cdf_ax, recur_inset_ax, recur_day_sess_comp_ax,
     recur_behav_corr_ax) = add_panel_row(0.70, 0.72)
    (stability_over_time_ax, stability_inset_ax, stab_day_sess_comp_ax,
     stability_behav_corr_ax) = add_panel_row(0.40, 0.42)

    params = {
        'recur_range': (-0.05, 0.85),
        'recur_xticks': [0.0, 0.2, 0.4, 0.6, 0.8],
    }

    #
    # Parameter presets (behavior metrics)
    #
    fraction_in_params = {
        'behavior_fn': ra.fraction_licks_in_reward_zone,
        'behavior_kwargs': {},
        'behavior_label': 'Fraction of licks in reward zone',
        'behav_range': (-0.05, 0.65),
    }
    anticipatory_licking_params = {
        'behavior_fn': ra.fraction_licks_near_rewards,
        'behavior_kwargs': {
            'pre_window_cm': 5,
            'exclude_reward': True
        },
        'behavior_label': 'Anticipatory licking fraction',
        'behav_range': (-0.05, 0.95),
    }

    #
    # Parameter presets (stability metrics); later presets override keys
    # of earlier ones when several are merged into ``params``.
    #
    activity_centroid_params = {
        'stability_fn': place.activity_centroid_shift,
        'stability_kwargs': {
            'activity_filter': 'active_both',
            'circ_var_pcs': circ_var_pcs,
            'units': 'rad'
        },
        'stability_label': 'Centroid shift (fraction of belt)',
        'stab_range': (0.90, 2),
    }
    act_cent_norm_params = {
        'stability_kwargs': {
            'activity_filter': 'active_both',
            'circ_var_pcs': circ_var_pcs,
            'units': 'norm'
        },
        'stab_range': (0.15, 0.31),
        'stab_xticks': (0.15, 0.20, 0.25, 0.30),
        'stab_day_ses_ylim': (0.10, 0.3),
        'stab_day_ses_yticks': (0.1, 0.15, 0.20, 0.25, 0.30),
        'stab_inset_ylim': (0.15, 0.3),
        'stab_inset_yticks': (0.15, 0.3),
    }
    act_cent_pc_parms = {
        'stability_kwargs': {
            'activity_filter': 'pc_both',
            'circ_var_pcs': circ_var_pcs,
            'units': 'rad'
        },
    }
    act_cent_cm_params = {
        'stability_kwargs': {
            'activity_filter': 'active_both',
            'circ_var_pcs': circ_var_pcs,
            'units': 'cm'
        },
        'stability_label': 'Centroid shift (cm)',
        'stab_range': (25, 65),
        'stab_inset_ylim': (0, 50),
        'stab_inset_yticks': (0, 50),
        'stab_day_ses_ylim': (30, 50),
    }
    pf_corr_params = {
        'stability_fn': place.place_field_correlation,
        'stability_kwargs': {'activity_filter': 'pc_either'},
        'stability_label': 'Place field correlation',
        'stab_range': (-0.1, 0.5),
        'stab_xticks': (-0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5),
        'stab_inset_ylim': (0, 0.30),
        'stab_inset_yticks': (0, 0.30),
        'stab_day_ses_ylim': (0, 0.7),
        'stab_day_ses_yticks': (0, 0.2, 0.4, 0.6),
    }
    pop_vec_corr_params = {
        'stability_fn': place.population_vector_correlation,
        'stability_kwargs': {
            'method': 'corr',
            'activity_filter': 'pc_both',
            'min_pf_density': 0.05
        },
        'stability_label': 'Population vector correlation',
    }

    #
    # Select parameters
    #
    for preset in (fraction_in_params, activity_centroid_params,
                   act_cent_norm_params):
        params.update(preset)
    # params.update(pf_corr_params)  # For Supplemental Figure 5

    #
    # Experiment pairings: consecutive days (sessions within a day first
    # grouped together) and consecutive sessions.
    #
    day_paired_grps = [
        grp.pair('same group', groupby=['X_session']).pair(
            'consecutive groups', groupby=['X_condition', 'X_day'])
        for grp in expt_grps
    ]
    session_paired_grps = [
        grp.pair('consecutive groups',
                 groupby=['X_condition', 'X_day', 'X_session'])
        for grp in expt_grps
    ]

    # Grouping levels reused by several panels.
    per_expt = (('second_expt', ), )
    per_mouse = (('second_expt', 'second_mouseID'), ('second_mouseID', ))
    per_elapsed = (('elapsed_days_int', 'second_expt', 'second_mouseID'), )

    # Keyword arguments shared by every plot_metric call below.
    shuffle_kwargs = dict(roi_filters=roi_filters,
                          plot_shuffle=True,
                          shuffle_plotby=False,
                          pool_shuffle=True,
                          colors=colors,
                          rotate_labels=False)
    inset_kwargs = dict(shuffle_kwargs,
                        plot_shuffle_as_hline=True,
                        linewidth=0.2,
                        edgecolor='gray')
    line_kwargs = {'markersize': 4}
    box_kwargs = dict(shuffle_kwargs,
                      markers=markers,
                      plot_shuffle_as_hline=True,
                      box_width=0.4,
                      box_spacing=0.2,
                      return_full_dataframes=False,
                      whis='range',
                      linestyles=linestyles,
                      notch=False,
                      line_kwargs=line_kwargs)

    def tidy_inset(ax, ylim, yticks):
        """Strip the swarm inset down to a bare y axis."""
        ax.get_legend().set_visible(False)
        sns.despine(ax=ax)
        ax.set_title('')
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.tick_params(bottom=False, labelbottom=False)
        ax.set_ylim(ylim)
        ax.set_yticks(yticks)

    def tidy_day_session(ax, ylim, yticks, ylabel=None):
        """Label the session-vs-day comparison panel."""
        sns.despine(ax=ax)
        ax.legend(loc='upper right', fontsize=6)
        ax.set_title('')
        ax.set_xlabel('')
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        ax.set_xticklabels(['S-S', 'D-D'])
        ax.set_ylim(ylim)
        ax.set_yticks(yticks)

    def plot_behavior_correlation(ax, first_fn, first_kwargs, first_label,
                                  xlim):
        """Regress the selected behavior metric on a place-cell metric."""
        scatter_kws = {'s': 5}
        colorby_list = [(grp.label(), ) for grp in expt_grps]
        # Limits are set before plotting as well as after; presumably so the
        # regression lines (truncate=False) span the intended x range.
        ax.set_xlim(xlim)
        ax.set_ylim(params['behav_range'])
        plotting.plot_paired_metrics(
            day_paired_grps,
            first_metric_fn=first_fn,
            second_metric_fn=params['behavior_fn'],
            roi_filters=roi_filters,
            groupby=per_expt,
            colorby=['expt_grp'],
            filter_fn=None,
            filter_columns=None,
            first_metric_kwargs=first_kwargs,
            second_metric_kwargs=params['behavior_kwargs'],
            first_metric_label=first_label,
            second_metric_label=params['behavior_label'],
            shuffle_colors=False,
            fit_reg=True,
            plot_method='regplot',
            colorby_list=colorby_list,
            colors=colors,
            markers=markers,
            ax=ax,
            scatter_kws=scatter_kws,
            truncate=False,
            linestyles=linestyles)
        ax.set_xlim(xlim)
        ax.set_ylim(params['behav_range'])
        ax.tick_params(direction='in')

    #
    # Recurrence
    #
    plotting.plot_metric(recur_cdf_ax,
                         day_paired_grps,
                         metric_fn=place.recurrence_probability,
                         groupby=per_expt,
                         plot_method='cdf',
                         activity_kwargs={'circ_var_pcs': circ_var_pcs},
                         activity_label='Recurrence probability',
                         return_full_dataframes=False,
                         **shuffle_kwargs)
    recur_cdf_ax.legend(loc='upper left', fontsize=6)
    recur_cdf_ax.set_title('')
    recur_cdf_ax.set_xlim(params['recur_range'])
    recur_cdf_ax.set_xticks(params['recur_xticks'])

    plotting.plot_metric(recur_inset_ax,
                         day_paired_grps,
                         metric_fn=place.recurrence_probability,
                         groupby=per_mouse,
                         plot_method='swarm',
                         activity_kwargs={'circ_var_pcs': circ_var_pcs},
                         activity_label='Recurrence probability',
                         **inset_kwargs)
    tidy_inset(recur_inset_ax, (0, 1.), [0, 1.])

    plot_behavior_correlation(recur_behav_corr_ax,
                              place.recurrence_probability, None,
                              'Recurrence probability', params['recur_range'])
    recur_behav_corr_ax.legend(loc='upper left', fontsize=6)

    #
    # Stability
    #
    plotting.plot_metric(stability_over_time_ax,
                         day_paired_grps,
                         metric_fn=params['stability_fn'],
                         groupby=per_expt,
                         plotby=None,
                         plot_method='cdf',
                         activity_kwargs=params['stability_kwargs'],
                         activity_label=params['stability_label'],
                         filter_fn=None,
                         filter_columns=None,
                         return_full_dataframes=False,
                         **shuffle_kwargs)
    stability_over_time_ax.legend(loc='upper left', fontsize=6)
    stability_over_time_ax.set_title('')
    stability_over_time_ax.set_xlabel(params['stability_label'])
    stability_over_time_ax.set_xlim(params['stab_range'])
    stability_over_time_ax.set_xticks(params['stab_xticks'])

    plotting.plot_metric(stability_inset_ax,
                         day_paired_grps,
                         metric_fn=params['stability_fn'],
                         groupby=per_mouse,
                         plotby=None,
                         plot_method='swarm',
                         activity_kwargs=params['stability_kwargs'],
                         activity_label=params['stability_label'],
                         filter_fn=None,
                         filter_columns=None,
                         **inset_kwargs)
    tidy_inset(stability_inset_ax, params['stab_inset_ylim'],
               params['stab_inset_yticks'])

    plot_behavior_correlation(stability_behav_corr_ax,
                              params['stability_fn'],
                              params['stability_kwargs'],
                              params['stability_label'],
                              params['stab_range'])
    stability_behav_corr_ax.set_xticks(params['stab_xticks'])
    stability_behav_corr_ax.set_xlabel(params['stability_label'])
    stability_behav_corr_ax.legend(loc='upper right', fontsize=6)

    #
    # Day vs. session elapsed comparison
    #
    plotting.plot_metric(recur_day_sess_comp_ax,
                         session_paired_grps,
                         metric_fn=place.recurrence_probability,
                         groupby=per_elapsed,
                         plotby=('elapsed_days_int', ),
                         plot_method='box_and_line',
                         activity_kwargs={'circ_var_pcs': circ_var_pcs},
                         activity_label='Recurrence probability',
                         flierprops={
                             'markersize': 2,
                             'marker': 'o'
                         },
                         **box_kwargs)
    tidy_day_session(recur_day_sess_comp_ax, (0.0, 1.0),
                     [0.00, 0.25, 0.50, 0.75, 1.])

    plotting.plot_metric(stab_day_sess_comp_ax,
                         session_paired_grps,
                         metric_fn=params['stability_fn'],
                         groupby=per_elapsed,
                         plotby=('elapsed_days_int', ),
                         plot_method='box_and_line',
                         activity_kwargs=params['stability_kwargs'],
                         activity_label=params['stability_label'],
                         filter_fn=None,
                         filter_columns=None,
                         flierprops={
                             'markersize': 2,
                             'marker': 'o'
                         },
                         **box_kwargs)
    tidy_day_session(stab_day_sess_comp_ax, params['stab_day_ses_ylim'],
                     params['stab_day_ses_yticks'],
                     ylabel=params['stability_label'])

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Plot running and licking behavior for the WT and Df place-cell sets.

    The 4x3 grid holds one metric per row: velocity, lap rate and lick rate,
    each restricted to rows whose ``condition`` column equals 'A'.
    Column 0 shows the metric aggregated per mouse. Column 1 splits it by
    ``session_in_day``. The bottom row and right-hand column are hidden.

    Relies on module-level ``colors``, ``WT_label``, ``Df_label``,
    ``filename`` and ``save_dir``.
    """
    all_grps = df.loadExptGrps('GOL')
    expt_grps = [all_grps['WT_place_set'], all_grps['Df_place_set']]

    fig, axs = plt.subplots(
        4, 3, figsize=(8.5, 11), gridspec_kw={'wspace': 0.4, 'hspace': 0.3})

    def plot_row(row, metric_fn, activity_kwargs, activity_label, unit,
                 agg_fn=None, session_extra=None):
        """Draw one metric: per-mouse swarm in column 0, per-session in 1."""

        def base_kwargs():
            # Build fresh containers for every call, so plot_metric never
            # sees an object shared with another panel.
            kwargs = {
                'metric_fn': metric_fn,
                'plot_method': 'swarm',
                'activity_kwargs': (None if activity_kwargs is None
                                    else dict(activity_kwargs)),
                'filter_fn': lambda frame: frame['condition'] == 'A',
                'filter_columns': ['condition'],
                'activity_label': activity_label,
                'colors': colors,
                'plot_bar': True,
                'edgecolor': 'k',
                'linewidth': 0.5,
            }
            if agg_fn is not None:
                kwargs['agg_fn'] = list(agg_fn)
            return kwargs

        per_mouse = base_kwargs()
        per_mouse['groupby'] = [[unit], ['mouseID']]
        per_mouse['plotby'] = None
        per_mouse['return_full_dataframes'] = False
        plotting.plot_metric(axs[row, 0], expt_grps, **per_mouse)

        per_session = base_kwargs()
        per_session['groupby'] = [[unit], ['session_in_day', 'mouseID']]
        per_session['plotby'] = ['session_in_day']
        per_session['rotate_labels'] = False
        per_session.update(session_extra or {})
        plotting.plot_metric(axs[row, 1], expt_grps, **per_session)

    plot_row(0, lab.ExperimentGroup.velocity_dataframe, None,
             'Velocity (cm/s)', 'expt')
    plot_row(1, lab.ExperimentGroup.number_of_laps, {'rate': True},
             'Lap rate (1/min)', 'expt', agg_fn=[np.sum, np.mean])
    plot_row(2, lab.ExperimentGroup.behavior_dataframe,
             {'key': 'licking', 'rate': True}, 'Lick rate (Hz)', 'trial',
             agg_fn=[np.sum, np.mean], session_extra={'roi_filters': None})

    # Only the top-left 3x2 block carries data.
    for ax in list(axs[3, :]) + list(axs[:, 2]):
        ax.set_visible(False)

    sns.despine(fig=fig)

    for ax in axs.flat:
        ax.set_title('')

    for per_mouse_ax, per_session_ax in axs[:, :2]:
        per_mouse_ax.set_xlabel('')
        per_mouse_ax.tick_params(bottom=False, labelbottom=False, length=3,
                                 pad=2)
        per_session_ax.set_xlabel('')
        per_session_ax.set_ylabel('')
        per_session_ax.set_xticklabels(['1', '2', '3'])
        per_session_ax.tick_params(length=3, pad=2)

    for ax in axs[:3, :2].flat:
        ax.get_legend().set_visible(False)

    plotting.stackedText(axs[0, 0], [WT_label, Df_label], colors=colors,
                         loc=2, size=10)

    axs[2, 1].set_xlabel('Session in day')

    # Shared y-range per metric row; the lap-rate row also gets fixed ticks.
    row_ylims = [(0, 13), (0, 2.0), (0, 3)]
    for row, ylim in enumerate(row_ylims):
        for ax in axs[row, :]:
            ax.set_ylim(*ylim)
            if row == 1:
                ax.set_yticks([0, 0.5, 1.0, 1.5, 2.0])

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Compare behavior with place-cell recurrence and stability by condition.

    Draws a 2x3 grid. The top row shows the raw values, plotted per
    condition (tick-labelled I, II, III):
      * fraction of licks in the reward zone (hidden-reward behavior sets);
      * recurrence probability (paired place-cell sets);
      * centroid shift (paired place-cell sets).
    The bottom row shows the same data after ``z_score_value``. Centroid
    shift uses ``invert=True``, so the axis reads ``-1 * z-score``.

    Relies on module-level ``labels``, ``colors``, ``markers``,
    ``linestyles``, ``z_score_value``, ``filename`` and ``save_dir``.
    """
    all_grps = df.loadExptGrps('GOL')

    WT_expt_grp = all_grps['WT_hidden_behavior_set']
    Df_expt_grp = all_grps['Df_hidden_behavior_set']
    expt_grps = [WT_expt_grp, Df_expt_grp]

    WT_pc_expt_grp = all_grps['WT_place_set']
    Df_pc_expt_grp = all_grps['Df_place_set']
    pc_expt_grps = [WT_pc_expt_grp, Df_pc_expt_grp]

    # Two chained pairings: 'consecutive groups' over condition_day_session,
    # then 'same group' over condition. Presumably this keeps consecutive
    # sessions within one condition -- confirm against ExperimentGroup.pair.
    paired_grps = [
        grp.pair('consecutive groups',
                 groupby=['condition_day_session']).pair(
                     'same group', groupby=['condition'])
        for grp in pc_expt_grps
    ]

    behavior_fn = ra.fraction_licks_in_reward_zone
    behavior_kwargs = {}
    behavior_label = 'Fraction of licks in reward zone'

    recurrence_fn = place.recurrence_probability
    recurrence_kwargs = {'circ_var_pcs': False}
    recurrence_label = 'Recurrence probability'

    stability_fn = place.activity_centroid_shift
    stability_kwargs = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'norm'
    }
    stability_label = 'Centroid shift\n(fraction of belt)'

    roman_ticks = [r'$\mathrm{I}$', r'$\mathrm{II}$', r'$\mathrm{III}$']
    z_ticks = [-1, -0.5, 0, 0.5, 1]

    # Styling shared by every panel.
    line_style = dict(labels=labels, plot_method='line', colors=colors,
                      label_groupby=False, markers=markers, markersize=5,
                      linestyles=linestyles)
    # Shuffle handling shared by the paired (recurrence / stability) panels.
    shuffle_style = dict(plot_shuffle=True, shuffle_plotby=False,
                         pool_shuffle=True)
    paired_groupby = [[
        'second_mouse', 'second_condition', 'second_condition_day_session'
    ]]

    def finish(ax, ylim, yticks):
        """Set y-range/ticks, blank title and x-label, label conditions."""
        ax.set_ylim(*ylim)
        ax.set_yticks(yticks)
        ax.set_xlabel('')
        ax.set_title('')
        ax.set_xticklabels(list(roman_ticks))

    def data_and_shuffles(metric_fn, metric_kwargs):
        """Run a paired metric per group; split its (data, shuffle) results."""
        data, shuffles = [], []
        for expt_grp in paired_grps:
            values, shuffle = metric_fn(expt_grp, **metric_kwargs)
            data.append(values)
            shuffles.append(shuffle)
        return data, shuffles

    fig = plt.figure(figsize=(8.5, 11))
    gs = plt.GridSpec(2,
                      3,
                      left=0.1,
                      bottom=0.5,
                      right=0.89,
                      top=0.9,
                      wspace=0.4)
    behav_ax = fig.add_subplot(gs[0, 0])
    recurrence_ax = fig.add_subplot(gs[0, 1])
    stability_ax = fig.add_subplot(gs[0, 2])
    behav_z_ax = fig.add_subplot(gs[1, 0])
    recurrence_z_ax = fig.add_subplot(gs[1, 1])
    stability_z_ax = fig.add_subplot(gs[1, 2])

    #
    # Raw values
    #
    lick_data = [
        behavior_fn(expt_grp, **behavior_kwargs) for expt_grp in expt_grps
    ]
    plotting.plot_dataframe(
        behav_ax,
        lick_data,
        groupby=[['expt'], ['mouseID', 'condition', 'condition_day_session']],
        plotby=['condition'],
        activity_label=behavior_label,
        **line_style)
    finish(behav_ax, (0, 0.4), [0, .1, .2, .3, .4])

    recur_data, recur_shuffles = data_and_shuffles(recurrence_fn,
                                                   recurrence_kwargs)
    plotting.plot_dataframe(recurrence_ax,
                            recur_data,
                            recur_shuffles,
                            groupby=paired_groupby,
                            plotby=['second_condition'],
                            activity_label=recurrence_label,
                            **dict(line_style, **shuffle_style))
    finish(recurrence_ax, (0.0, 0.6), [0.0, 0.2, 0.4, 0.6])

    stability_data, stability_shuffles = data_and_shuffles(
        stability_fn, stability_kwargs)
    plotting.plot_dataframe(stability_ax,
                            stability_data,
                            stability_shuffles,
                            groupby=paired_groupby,
                            plotby=['second_condition'],
                            activity_label=stability_label,
                            **dict(line_style, **shuffle_style))
    finish(stability_ax, (0.15, 0.25), [0.15, 0.17, 0.19, 0.21, 0.23, 0.25])

    #
    # z-scored values
    #
    lick_z_data = z_score_value(lick_data)
    plotting.plot_dataframe(behav_z_ax,
                            lick_z_data,
                            plotby=['condition'],
                            activity_label=behavior_label + '\n(z-score)',
                            **line_style)
    finish(behav_z_ax, (-1, 1), z_ticks)

    recur_z_data = z_score_value(recur_data)
    plotting.plot_dataframe(recurrence_z_ax,
                            recur_z_data,
                            plotby=['second_condition'],
                            activity_label=recurrence_label + '\n(z-score)',
                            **line_style)
    finish(recurrence_z_ax, (-1, 1), z_ticks)

    # Inverted so that "up" means more stable (smaller centroid shift).
    stability_z_data = z_score_value(stability_data, invert=True)
    plotting.plot_dataframe(stability_z_ax,
                            stability_z_data,
                            plotby=['second_condition'],
                            activity_label=stability_label +
                            '\n(-1 * z-score)',
                            **line_style)
    finish(stability_z_ax, (-1, 1), z_ticks)

    sns.despine(fig)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Draw the place-field shift model figure.

    The 3x2 grid leaves the top-left cell empty and holds these panels:
      * a schematic of the von Mises density used for new place-field
        centroids, with a dashed reference line at ``pos1``;
      * enrichment across simulated iterations for the WT and 'flat' models
        (the flat model is read from the ``Df_*`` simulation keys);
      * the final centroid distributions of both no-swap models;
      * final enrichment for each model under the no-swap simulation and
        the three parameter-swap simulations;
      * the WT model's final distributions under each swap.

    Simulation output is loaded from the pickle at module-level
    ``simulations_path``. Also relies on module-level ``WT_label``,
    ``flat_label``, ``WT_color``, ``flat_color``, ``colors``, ``labels``,
    ``filename`` and ``save_dir``.
    """
    fig = plt.figure(figsize=(8.5, 11))

    gs = plt.GridSpec(
        3, 2, left=0.1, bottom=0.4, right=0.5, top=0.9, hspace=0.4, wspace=0.4)
    shift_schematic_ax = fig.add_subplot(gs[0, 1])
    model_enrich_by_time_ax = fig.add_subplot(gs[1, 0])
    model_final_enrich_ax = fig.add_subplot(gs[1, 1])
    model_swap_compare_enrich_ax = fig.add_subplot(gs[2, 0])
    model_WT_swap_final_enrich_ax = fig.add_subplot(gs[2, 1])

    #
    # Shift schematic
    #

    mu = 1
    k = 2.
    pos1 = 1.5

    def vms(x, mu=1, k=2.):
        """Return the von Mises density (mean ``mu``, concentration ``k``).

        Evaluated at ``x`` (radians). Works element-wise on NumPy arrays.
        """
        return np.exp(k * np.cos(x - mu)) / (2 * np.pi * i0(k))

    xr = np.linspace(-np.pi, np.pi, 1000)

    # Pass mu/k explicitly so the curve and the dotted peak marker below are
    # always drawn from the same parameters.
    shift_schematic_ax.plot(xr, vms(xr, mu=mu, k=k), color='k')
    shift_schematic_ax.axvline(pos1, color='r', ls='--')
    shift_schematic_ax.plot(
        [mu, mu], [0, vms(mu, mu=mu, k=k)], ls=':', color='0.3')
    sns.despine(ax=shift_schematic_ax, top=True, right=True)
    shift_schematic_ax.tick_params(labelleft=True, left=True, direction='out')
    # The x-axis is in radians; label the ticks as fractions of the belt.
    shift_schematic_ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    shift_schematic_ax.set_xticklabels(['-0.50', '-0.25', '0', '0.25', '0.50'])
    shift_schematic_ax.set_xlim(-np.pi, np.pi)
    shift_schematic_ax.set_yticks([0, 0.2, 0.4, 0.6])
    shift_schematic_ax.set_ylim(0, 0.6)
    shift_schematic_ax.set_xlabel('Distance from reward (fraction of belt)')
    shift_schematic_ax.set_ylabel(
        'New place field centroid\nprobability density')

    #
    # Distance to reward
    #

    # Binary mode is required by pickle on Python 3. The context manager
    # closes the handle. pickle.load can execute arbitrary code, so only
    # point simulations_path at trusted simulation output.
    with open(simulations_path, 'rb') as sim_file:
        m = pickle.load(sim_file)

    WT_enrich = emp.calc_enrichment(m['WT_no_swap_pos'], m['WT_no_swap_masks'])
    flat_enrich = emp.calc_enrichment(m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_enrichment(
        model_enrich_by_time_ax, WT_enrich, WT_color, title='', rad=False)
    emp.plot_enrichment(
        model_enrich_by_time_ax, flat_enrich, flat_color, title='', rad=False)

    model_enrich_by_time_ax.set_xlabel("Iteration ('session' #)")
    plotting.stackedText(
        model_enrich_by_time_ax, [WT_label, flat_label], colors=colors, loc=2,
        size=10)

    #
    # Calc final distributions
    #

    def final_dist(model, variant):
        """Final centroid distribution for one model/simulation pair in m."""
        return emp.calc_final_distributions(
            m['{}_{}_pos'.format(model, variant)],
            m['{}_{}_masks'.format(model, variant)])

    WT_no_swap_final_dist = final_dist('WT', 'no_swap')
    flat_no_swap_final_dist = final_dist('Df', 'no_swap')

    WT_swap_recur_final_dist = final_dist('WT', 'swap_recur')
    flat_swap_recur_final_dist = final_dist('Df', 'swap_recur')

    WT_swap_shift_b_final_dist = final_dist('WT', 'swap_shift_b')
    flat_swap_shift_b_final_dist = final_dist('Df', 'swap_shift_b')

    WT_swap_shift_k_final_dist = final_dist('WT', 'swap_shift_k')
    flat_swap_shift_k_final_dist = final_dist('Df', 'swap_shift_k')

    #
    # Final distribution
    #

    emp.plot_final_distributions(
        model_final_enrich_ax,
        [WT_no_swap_final_dist, flat_no_swap_final_dist],
        colors, labels=labels, title='', rad=False)

    plotting.stackedText(
        model_final_enrich_ax, [WT_label, flat_label], colors=colors, loc=2,
        size=10)
    #
    # Compare enrichment
    #

    WT_bars = [
        WT_no_swap_final_dist, WT_swap_recur_final_dist,
        WT_swap_shift_k_final_dist, WT_swap_shift_b_final_dist]
    flat_bars = [
        flat_no_swap_final_dist, flat_swap_recur_final_dist,
        flat_swap_shift_k_final_dist, flat_swap_shift_b_final_dist]

    # Enrichment = pi/2 - |distance|. Distances are presumably radians from
    # reward, so larger values mean centroids closer to the reward.
    WT_bars = [np.pi / 2 - np.abs(bar) for bar in WT_bars]
    flat_bars = [np.pi / 2 - np.abs(bar) for bar in flat_bars]

    plotting.grouped_bar(
        model_swap_compare_enrich_ax, [WT_bars, flat_bars],
        [WT_label, flat_label],
        ['No\nswap', 'Swap\n' + r'$P_{recur}$',
         'Swap\nvariance', 'Swap\noffset'], [WT_color, flat_color],
        error_bars=None)

    sns.despine(ax=model_swap_compare_enrich_ax)
    model_swap_compare_enrich_ax.tick_params(length=3, pad=2, direction='out')
    # The data are in radians, but the ticks below sit at fraction * 2*pi and
    # are labelled with the fraction. The axis label therefore says
    # "fraction of belt", matching the schematic's x-axis.
    model_swap_compare_enrich_ax.set_ylabel(
        'Final enrichment\n(fraction of belt)')
    model_swap_compare_enrich_ax.get_legend().set_visible(False)
    model_swap_compare_enrich_ax.set_ylim(0, 0.08 * 2 * np.pi)
    y_ticks = np.array(['0', '0.02', '0.04', '0.06', '0.08'])
    model_swap_compare_enrich_ax.set_yticks(
        y_ticks.astype('float') * 2 * np.pi)
    model_swap_compare_enrich_ax.set_yticklabels(y_ticks)
    plotting.stackedText(
        model_swap_compare_enrich_ax, [WT_label, flat_label], colors=colors,
        loc=2, size=10)

    #
    # WT swap final distributions
    #

    hist_colors = iter(sns.color_palette())

    emp.plot_final_distributions(
        model_WT_swap_final_enrich_ax,
        [WT_no_swap_final_dist,
         WT_swap_recur_final_dist, WT_swap_shift_k_final_dist,
         WT_swap_shift_b_final_dist], colors=hist_colors,
        labels=['No swap', r'Swap $P_{recur}$',
                'Swap shift variance', 'Swap shift offset'], rad=False)
    model_WT_swap_final_enrich_ax.legend(
        loc='lower right', fontsize=5, frameon=False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    all_grps = df.loadExptGrps('GOL')
    expts = lab.ExperimentSet(os.path.join(df.metadata_path,
                                           'expt_metadata.xml'),
                              behaviorDataPath=os.path.join(
                                  df.data_path, 'behavior'),
                              dataPath=os.path.join(df.data_path, 'imaging'))

    WT_expt_grp = all_grps['WT_place_set']
    Df_expt_grp = all_grps['Df_place_set']
    expt_grps = [WT_expt_grp, Df_expt_grp]

    WT_label = WT_expt_grp.label()
    Df_label = Df_expt_grp.label()

    fig = plt.figure(figsize=(8.5, 11))

    gs2 = plt.GridSpec(48, 1, right=.6)
    wt_trace_ax = fig.add_subplot(gs2[12:16, 0])
    wt_position_ax = fig.add_subplot(gs2[16:18, 0])
    df_trace_ax = fig.add_subplot(gs2[18:22, 0])
    df_position_ax = fig.add_subplot(gs2[22:24, 0])

    gs3 = plt.GridSpec(2, 2, hspace=0.3, left=.62, bottom=0.5, top=0.7)
    wt_transients_ax = fig.add_subplot(gs3[0, 0], polar=True)
    df_transients_ax = fig.add_subplot(gs3[1, 0], polar=True)
    wt_vector_ax = fig.add_subplot(gs3[0, 1], polar=True)
    df_vector_ax = fig.add_subplot(gs3[1, 1], polar=True)

    gs4 = plt.GridSpec(1, 4, top=0.48, bottom=0.35, wspace=0.3)
    pf_fraction_ax = fig.add_subplot(gs4[0, 0])
    pf_per_cell_ax = fig.add_subplot(gs4[0, 1])
    pf_width_ax = fig.add_subplot(gs4[0, 2])
    circ_var_ax = fig.add_subplot(gs4[0, 3])

    pf_fraction_inset_ax = fig.add_axes([0.22, 0.36, 0.04, 0.06])
    pf_width_inset_ax = fig.add_axes([0.63, 0.41, 0.04, 0.06])
    circ_var_inset_ax = fig.add_axes([0.84, 0.36, 0.04, 0.06])

    #
    # PC Examples
    #

    wt_expt = expts.grabExpt('jz121', '2015-02-21-16h06m30s')
    df_expt = expts.grabExpt('jz098', '2014-11-08-16h14m02s')
    pc_expt_grp = place.pcExperimentGroup([wt_expt, df_expt],
                                          imaging_label='soma')
    wt_id = '0422-0349'
    df_id = '0074-0339'
    wt_idx = wt_expt.roi_ids().index(wt_id)
    df_idx = df_expt.roi_ids().index(df_id)

    wt_imaging_data = wt_expt.imagingData(channel='Ch2',
                                          label='soma',
                                          dFOverF='from_file')
    df_imaging_data = df_expt.imagingData(channel='Ch2',
                                          label='soma',
                                          dFOverF='from_file')
    wt_transients = wt_expt.transientsData(threshold=95,
                                           channel='Ch2',
                                           label='soma')
    df_transients = df_expt.transientsData(threshold=95,
                                           channel='Ch2',
                                           label='soma')

    place.plotImagingData(roi_tSeries=wt_imaging_data[wt_idx, :, 0],
                          ax=wt_trace_ax,
                          roi_transients=wt_transients[wt_idx][0],
                          position=None,
                          imaging_interval=wt_expt.frame_period(),
                          placeField=None,
                          xlabel_visible=False,
                          ylabel_visible=True,
                          right_label=True,
                          placeFieldColor=None,
                          title='',
                          rasterized=False,
                          color='.4',
                          transients_color=WT_color)
    sns.despine(ax=wt_trace_ax, top=True, left=False, bottom=True, right=True)
    wt_trace_ax.set_ylabel(WT_label, rotation='horizontal', ha='right')
    wt_trace_ax.tick_params(bottom=False, labelbottom=False)
    wt_trace_ax.tick_params(axis='y', direction='in', length=3, pad=3)
    wt_trace_ax.spines['left'].set_linewidth(1)
    wt_trace_ax.spines['left'].set_position(('outward', 5))

    place.plotImagingData(roi_tSeries=df_imaging_data[df_idx, :, 0],
                          ax=df_trace_ax,
                          roi_transients=df_transients[df_idx][0],
                          position=None,
                          imaging_interval=df_expt.frame_period(),
                          placeField=None,
                          xlabel_visible=False,
                          ylabel_visible=True,
                          right_label=True,
                          placeFieldColor=None,
                          title='',
                          rasterized=False,
                          color='.4',
                          transients_color=Df_color)
    sns.despine(ax=df_trace_ax, top=True, left=False, bottom=True, right=True)
    df_trace_ax.set_ylabel(Df_label, rotation='horizontal', ha='right')
    df_trace_ax.tick_params(bottom=False, labelbottom=False)
    df_trace_ax.tick_params(axis='y', direction='in', length=3, pad=3)
    df_trace_ax.spines['left'].set_linewidth(1)
    df_trace_ax.spines['left'].set_position(('outward', 5))

    y_min = min(wt_trace_ax.get_ylim()[0], df_trace_ax.get_ylim()[0])
    y_max = max(wt_trace_ax.get_ylim()[1], df_trace_ax.get_ylim()[1])
    wt_trace_ax.set_ylim(y_min, y_max)
    df_trace_ax.set_ylim(y_min, y_max)
    wt_trace_ax.set_yticks([0, y_max])
    df_trace_ax.set_yticks([0, y_max])
    wt_trace_ax.set_yticklabels(['0', '{:0.1f}'.format(y_max)])
    df_trace_ax.set_yticklabels(['0', '{:0.1f}'.format(y_max)])
    wt_trace_ax.set_xlim(0, 600)
    df_trace_ax.set_xlim(0, 600)

    place.plotPosition(wt_expt.find('trial'),
                       ax=wt_position_ax,
                       rasterized=False,
                       position_kwargs={'color': 'k'})
    sns.despine(ax=wt_position_ax,
                top=True,
                left=True,
                bottom=True,
                right=True)
    wt_position_ax.set_ylabel('')
    wt_position_ax.set_xlabel('')
    wt_position_ax.tick_params(left=False,
                               bottom=False,
                               top=False,
                               right=False,
                               labelleft=False,
                               labelbottom=False,
                               labelright=False,
                               labeltop=False)
    plotting.add_scalebar(wt_position_ax,
                          matchx=False,
                          matchy=False,
                          hidex=False,
                          hidey=False,
                          sizex=60,
                          labelx='1 min',
                          bar_thickness=.02,
                          pad=0,
                          loc=4)
    wt_position_ax.set_yticks([0, 1])
    wt_position_ax.set_ylim(0, 1)

    place.plotPosition(df_expt.find('trial'),
                       ax=df_position_ax,
                       rasterized=False,
                       position_kwargs={'color': 'k'})
    sns.despine(ax=df_position_ax, top=True, bottom=True, right=True)
    df_position_ax.tick_params(bottom=False,
                               top=False,
                               right=False,
                               labelbottom=False,
                               labelright=False,
                               labeltop=False,
                               direction='in',
                               length=3,
                               pad=3)
    df_position_ax.spines['left'].set_linewidth(1)
    df_position_ax.spines['left'].set_position(('outward', 5))
    df_position_ax.set_ylabel('Position')
    df_position_ax.set_xlabel('')
    df_position_ax.set_yticks([0, 1])
    df_position_ax.set_ylim(0, 1)

    trans_kwargs = {
        'color': WT_color,
        'marker': 'o',
        'linestyle': 'None',
        'markersize': 3
    }
    wt_pf = [pc_expt_grp.pfs_n()[wt_expt][wt_idx]]
    place.plotPosition(wt_expt.find('trial'),
                       ax=wt_transients_ax,
                       polar=True,
                       placeFields=wt_pf,
                       placeFieldColors=[WT_color],
                       trans_roi_filter=lambda roi: roi.id == wt_id,
                       rasterized=False,
                       running_trans_only=True,
                       demixed=False,
                       position_kwargs={'color': '0.5'},
                       trans_kwargs=trans_kwargs)
    wt_transients_ax.set_xlabel('')
    wt_transients_ax.set_ylabel('')
    wt_transients_ax.set_rticks([])
    prep_polar_ax(wt_transients_ax)

    trans_kwargs['color'] = Df_color
    df_pf = [pc_expt_grp.pfs_n()[df_expt][df_idx]]
    place.plotPosition(df_expt.find('trial'),
                       ax=df_transients_ax,
                       polar=True,
                       placeFields=df_pf,
                       placeFieldColors=[Df_color],
                       trans_roi_filter=lambda roi: roi.id == df_id,
                       rasterized=False,
                       running_trans_only=True,
                       demixed=False,
                       position_kwargs={'color': '0.5'},
                       trans_kwargs=trans_kwargs)
    df_transients_ax.set_xlabel('')
    df_transients_ax.set_ylabel('')
    df_transients_ax.set_rticks([])
    prep_polar_ax(df_transients_ax)

    place.plotTransientVectors(place.pcExperimentGroup([wt_expt],
                                                       imaging_label='soma'),
                               wt_idx,
                               wt_vector_ax,
                               mean_zorder=99,
                               color=WT_color,
                               mean_color='g')
    place.plotTransientVectors(place.pcExperimentGroup([df_expt],
                                                       imaging_label='soma'),
                               df_idx,
                               df_vector_ax,
                               mean_zorder=99,
                               color=Df_color,
                               mean_color='g')

    #
    # Stats
    #

    groupby = [['expt'], ['mouseID']]

    plotting.plot_metric(pf_fraction_ax,
                         expt_grps,
                         metric_fn=place.place_cell_percentage,
                         groupby=None,
                         plotby=None,
                         colorby=None,
                         plot_method='cdf',
                         roi_filters=roi_filters,
                         activity_kwargs=None,
                         colors=colors,
                         activity_label='Place cell fraction',
                         rotate_labels=False,
                         return_full_dataframes=False,
                         linestyles=linestyles)
    pf_fraction_ax.legend(loc='upper left', fontsize=6)
    pf_fraction_ax.set_title('')
    pf_fraction_ax.set_ylabel('Cumulative fraction')
    pf_fraction_ax.set_xticks([0, .2, .4, .6, .8])
    pf_fraction_ax.set_xlim(0, .8)
    pf_fraction_ax.spines['left'].set_linewidth(1)
    pf_fraction_ax.spines['bottom'].set_linewidth(1)

    plotting.plot_metric(pf_fraction_inset_ax,
                         expt_grps,
                         metric_fn=place.place_cell_percentage,
                         groupby=groupby,
                         plotby=None,
                         colorby=None,
                         plot_method='swarm',
                         roi_filters=roi_filters,
                         activity_kwargs=None,
                         colors=colors,
                         activity_label='Place cell fraction',
                         rotate_labels=False,
                         linewidth=0.2,
                         edgecolor='gray')
    pf_fraction_inset_ax.set_title('')
    pf_fraction_inset_ax.set_ylabel('')
    pf_fraction_inset_ax.set_xlabel('')
    pf_fraction_inset_ax.set_yticks([0, 0.5])
    pf_fraction_inset_ax.set_ylim([0, 0.5])
    pf_fraction_inset_ax.get_legend().set_visible(False)
    sns.despine(ax=pf_fraction_inset_ax)
    pf_fraction_inset_ax.tick_params(bottom=False, labelbottom=False)
    pf_fraction_inset_ax.spines['left'].set_linewidth(1)
    pf_fraction_inset_ax.spines['bottom'].set_linewidth(1)
    pf_fraction_inset_ax.set_xlim(-0.6, 0.6)

    n_pf_kwargs = {'per_mouse_fractions': True, 'max_n_place_fields': 3}
    plotting.plot_metric(pf_per_cell_ax,
                         expt_grps,
                         metric_fn=place.n_place_fields,
                         groupby=None,
                         plotby=['number'],
                         plot_method='swarm',
                         roi_filters=roi_filters,
                         activity_kwargs=n_pf_kwargs,
                         colors=colors,
                         activity_label='Fraction of place cells',
                         rotate_labels=False,
                         plot_bar=True,
                         edgecolor='k',
                         linewidth=0.5,
                         size=3)
    sns.despine(ax=pf_per_cell_ax)
    pf_per_cell_ax.set_title('')
    pf_per_cell_ax.set_xlabel('Place fields per cell')
    pf_per_cell_ax.set_ylabel('Fraction of place cells')
    pf_per_cell_ax.set_xticklabels(['1', '2', '3+'])
    pf_per_cell_ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])

    plotting.plot_metric(pf_width_ax,
                         expt_grps,
                         metric_fn=place.place_field_width,
                         groupby=[['roi_id', 'expt']],
                         plotby=None,
                         plot_method='hist',
                         roi_filters=roi_filters,
                         activity_kwargs=None,
                         activity_label='Place field width (cm)',
                         normed=True,
                         plot_mean=True,
                         bins=20,
                         range=(0, 120),
                         colors=colors,
                         rotate_labels=False,
                         filled=False,
                         mean_kwargs={'ls': ':'},
                         return_full_dataframes=False,
                         linestyles=linestyles)
    pf_width_ax.set_title('')
    pf_width_ax.legend(loc='lower right', fontsize=6)
    pf_width_ax.set_xticks([0, 40, 80, 120])
    pf_width_ax.set_yticks([0, 0.02, 0.04, 0.06, 0.08])
    pf_width_ax.set_ylim(0, 0.08)
    pf_width_ax.set_ylabel('Normalized density')
    pf_width_ax.spines['left'].set_linewidth(1)
    pf_width_ax.spines['bottom'].set_linewidth(1)

    plotting.plot_metric(pf_width_inset_ax,
                         expt_grps,
                         metric_fn=place.place_field_width,
                         groupby=[['roi_id', 'expt'], ['expt'], ['mouseID']],
                         plotby=None,
                         plot_method='swarm',
                         roi_filters=roi_filters,
                         activity_kwargs=None,
                         activity_label='Place field width (cm)',
                         colors=colors,
                         rotate_labels=False,
                         linewidth=0.2,
                         edgecolor='gray')
    pf_width_inset_ax.set_title('')
    pf_width_inset_ax.set_ylabel('')
    pf_width_inset_ax.set_xlabel('')
    pf_width_inset_ax.get_legend().set_visible(False)
    sns.despine(ax=pf_width_inset_ax)
    pf_width_inset_ax.tick_params(bottom=False, labelbottom=False)
    pf_width_inset_ax.set_ylim(25, 40)
    pf_width_inset_ax.set_yticks([25, 40])
    pf_width_inset_ax.spines['left'].set_linewidth(1)
    pf_width_inset_ax.spines['bottom'].set_linewidth(1)
    pf_width_inset_ax.set_xlim(-0.6, 0.6)

    plotting.plot_metric(circ_var_ax,
                         expt_grps,
                         metric_fn=place.circular_variance,
                         groupby=[['roi_id', 'expt']],
                         plotby=None,
                         plot_method='cdf',
                         roi_filters=roi_filters,
                         activity_kwargs=None,
                         activity_label='Circular variance',
                         colors=colors,
                         rotate_labels=False,
                         return_full_dataframes=False,
                         linestyles=linestyles)
    circ_var_ax.set_title('')
    circ_var_ax.legend(loc='upper left', fontsize=6)
    circ_var_ax.set_ylabel('Cumulative fraction')
    circ_var_ax.set_xlim(-0.1, 1)
    circ_var_ax.set_xticks([0, 0.5, 1])
    circ_var_ax.spines['left'].set_linewidth(1)
    circ_var_ax.spines['bottom'].set_linewidth(1)

    plotting.plot_metric(circ_var_inset_ax,
                         expt_grps,
                         metric_fn=place.circular_variance,
                         groupby=groupby,
                         plotby=None,
                         plot_method='swarm',
                         roi_filters=roi_filters,
                         activity_kwargs=None,
                         activity_label='Circular variance',
                         colors=colors,
                         rotate_labels=False,
                         linewidth=0.2,
                         edgecolor='gray')
    circ_var_inset_ax.set_title('')
    circ_var_inset_ax.set_ylabel('')
    circ_var_inset_ax.set_xlabel('')
    circ_var_inset_ax.get_legend().set_visible(False)
    sns.despine(ax=circ_var_inset_ax)
    circ_var_inset_ax.tick_params(bottom=False, labelbottom=False)
    circ_var_inset_ax.set_ylim(0, 0.6)
    circ_var_inset_ax.set_yticks([0, 0.6])
    circ_var_inset_ax.spines['left'].set_linewidth(1)
    circ_var_inset_ax.spines['bottom'].set_linewidth(1)
    circ_var_inset_ax.set_xlim(-0.6, 0.6)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
예제 #14
0
def main():
    """Plot stability near fabric transitions and the cue-free (burlap) data.

    Panels:
      * ``ax2`` -- mean centroid shift between consecutive paired sessions,
        split by where the first-session centroid lies relative to the
        nearest fabric transition ('before' / 'middle' / 'after'), for each
        group in ``df.labels`` plus a pooled shuffle.
      * ``trans_axs`` -- polar transient plots for the six ROIs with the
        lowest circular variance in one example burlap experiment.
      * ``pf_fraction_ax`` -- place cell fraction, acute cue-rich WT group
        vs. cue-free group.
      * ``circ_var_ax`` -- circular variance CDF for the same two groups.

    Depends on module-level names defined elsewhere in this file (e.g.
    ``roi_filters``, ``WT_filter``, ``cue_free_json``, ``complex_to_norm``,
    ``prep_polar_ax``, ``filename``, ``save_dir``).
    """
    all_expt_grps = df.loadExptGrps('GOL')

    WT_expt_grp_hidden = all_expt_grps['WT_place_set']
    Df_expt_grp_hidden = all_expt_grps['Df_place_set']
    expt_grps = [WT_expt_grp_hidden, Df_expt_grp_hidden]

    paired_grps = [grp.pair(
        'consecutive groups', groupby=['condition_day']) for grp in expt_grps]

    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(
        3, 8, top=0.9, bottom=0.7, left=0.1, right=0.9, wspace=0.4)
    ax2 = fig.add_subplot(gs1[:, :2])
    pf_fraction_ax = fig.add_subplot(gs1[:, 2:4])
    circ_var_ax = fig.add_subplot(gs1[:, 4:6])

    # Six polar axes in grid columns 6-7, created row by row:
    # (0, 6), (0, 7), (1, 6), (1, 7), (2, 6), (2, 7).
    trans_axs = [fig.add_subplot(gs1[row, col], polar=True)
                 for row in range(3) for col in (6, 7)]

    #
    # Stability by distance to fabric transitions
    #

    filter_columns = ['second_condition_day_session']
    label_order = ['before', 'middle', 'after']

    def norm_diff(n1, n2):
        """Return n1 - n2 wrapped into [-0.5, 0.5).

        Assumes both inputs are normalized belt positions in [0, 1).
        """
        d = n1 - n2
        d = d + 1.0 if d < -0.5 else d
        d = d - 1.0 if d >= 0.5 else d
        return d

    def categorize(row):
        """Label a row by its signed distance to the closest transition.

        Strictly within 1/9 of the belt on the negative side -> 'before',
        on the positive side -> 'after'; anything else (including exactly
        0) -> 'middle'.
        """
        if row['closest'] < 0 and -1 / 9. < row['closest']:
            row['category'] = 'before'
        elif row['closest'] > 0 and 1 / 9. > row['closest']:
            row['category'] = 'after'
        else:
            row['category'] = 'middle'
        return row

    data_to_plot = [[], [], []]
    shuffles = []
    for expt_grp, roi_filter in zip(paired_grps, roi_filters):

        fabric_map = {expt: expt.belt().fabric_transitions(
            units='normalized') for expt in expt_grp}

        def closest_transition(row):
            """Store the signed distance to the nearest fabric transition."""
            expt = row['first_expt']
            centroid = complex_to_norm(row['first_centroid'])
            positions = fabric_map[expt]['position']
            distances = [norm_diff(centroid, t) for t in positions]
            row['closest'] = distances[np.argmin(np.abs(distances))]
            return row

        data, shuffle = place.activity_centroid_shift(
            expt_grp, roi_filter=roi_filter, activity_filter='active_both',
            circ_var_pcs=False, units='norm', shuffle=True)

        plotting.prepare_dataframe(data, include_columns=filter_columns)
        data = data[data['second_condition_day_session'] == 'B_0_0']

        # NOTE(review): the shuffle gets the same column added but is NOT
        # restricted to 'B_0_0' the way `data` is; confirm this is intended.
        plotting.prepare_dataframe(shuffle, include_columns=filter_columns)

        data = data.apply(closest_transition, axis=1)
        shuffle = shuffle.apply(closest_transition, axis=1)

        data = data.apply(categorize, axis=1)
        shuffle = shuffle.apply(categorize, axis=1)

        # Successive averaging levels (currently a single level).
        mean_groupby = [
            ['second_condition_day_session', 'second_mouse', 'category']]
        for gb in mean_groupby:
            plotting.prepare_dataframe(data, include_columns=gb)
            plotting.prepare_dataframe(shuffle, include_columns=gb)
            data = data.groupby(gb, as_index=False).mean()
            shuffle = shuffle.groupby(gb, as_index=False).mean()

        # Group by the column name (not a one-element list) so the keys are
        # plain strings: pandas >= 2.0 yields 1-tuples for list groupers,
        # which label_order.index() would not find.
        for category, group in data.groupby('category'):
            data_to_plot[label_order.index(category)].append(group['value'])

        shuffles.append(shuffle)

    shuffle_df = pd.concat(shuffles, ignore_index=True)
    for category, group in shuffle_df.groupby('category'):
        data_to_plot[label_order.index(category)].append(group['value'])

    plotting.grouped_bar(
        ax2, data_to_plot, condition_labels=label_order,
        cluster_labels=df.labels + ('shuffle',),
        bar_colors=sns.color_palette('deep')[3:], scatter_points=False,
        scatterbar_colors=None, jitter_x=False, loc='best', error_bars='sem')
    sns.despine(ax=ax2)
    ax2.set_yticks([0, 0.1, 0.2, 0.3])
    ax2.set_ylabel('Centroid shift (fraction of belt)')

    #
    # Burlap belt
    #

    expts = lab.ExperimentSet(
        os.path.join(df.metadata_path, 'expt_metadata.xml'),
        behaviorDataPath=os.path.join(df.data_path, 'behavior'),
        dataPath=os.path.join(df.data_path, 'imaging'))

    burlap_expt_grp = lab.classes.pcExperimentGroup.from_json(
        cue_free_json, expts, imaging_label=df.IMAGING_LABEL, label='cue-free')

    acute_grps = df.loadExptGrps('RF')

    WT_expt_grp_acute = acute_grps['WT_place_set'].unpair()
    WT_expt_grp_acute.label('cue-rich')

    burlap_colors = ('k', '0.9')
    example_expt = expts.grabExptByPath('/jz128/TSeries-07262015-burlap-000')
    cv = place.circular_variance_p(burlap_expt_grp)
    cv = cv[cv['expt'] == example_expt]
    # Ascending, so the first rows are the lowest circular variances.
    cv = cv.sort_values(by=['value'])

    trans_kwargs = {
        'color': '0.9', 'marker': 'o', 'linestyle': 'None',
        'markersize': 3}
    for ax, (_, row) in zip(trans_axs, cv.iloc[:6].iterrows()):
        expt = row['expt']
        roi_idx = expt.rois().index(row['roi'])
        # Resolve the ROI id once and bind it as a default argument so the
        # filter does not depend on loop variables (late binding).
        roi_id = expt.roi_ids()[roi_idx]
        place.plotPosition(
            expt.find('trial'), ax=ax, polar=True,
            placeFields=None, placeFieldColors=['0.9'],
            trans_roi_filter=lambda roi, roi_id=roi_id: roi.id == roi_id,
            rasterized=False, running_trans_only=True, demixed=False,
            position_kwargs={'color': '0.5'}, trans_kwargs=trans_kwargs)
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_rticks([])
        prep_polar_ax(ax)
    # Only the last polar axis keeps its angular tick labels.
    for ax in trans_axs[:-1]:
        ax.set_xticklabels(['', '', '', ''])

    activity_kwargs = {'circ_var': True}

    plotting.plot_metric(
        pf_fraction_ax, [WT_expt_grp_acute, burlap_expt_grp],
        metric_fn=place.place_cell_percentage,
        groupby=None, plotby=None, colorby=None, plot_method='swarm',
        roi_filters=[WT_filter, WT_filter], activity_kwargs=activity_kwargs,
        colors=burlap_colors, activity_label='Place cell fraction',
        rotate_labels=False, plot_bar=True)
    pf_fraction_ax.set_title('')
    pf_fraction_ax.set_xlabel('')
    sns.despine(ax=pf_fraction_ax)
    pf_fraction_ax.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])

    plotting.plot_metric(
        circ_var_ax, [WT_expt_grp_acute, burlap_expt_grp],
        metric_fn=place.circular_variance,
        groupby=[['roi_id', 'expt']], plotby=None, plot_method='cdf',
        roi_filters=[WT_filter, WT_filter], activity_kwargs=None,
        activity_label='Circular variance', colors=burlap_colors,
        rotate_labels=False)
    circ_var_ax.set_title('')
    circ_var_ax.get_legend().set_visible(False)
    circ_var_ax.set_xticks([0, 0.5, 1])

    save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
예제 #15
0
def main():
    """Plot enrichment-model simulations for three conditions per genotype.

    Loads simulations A, B and C from ``<df.data_path>/enrichment_model``.
    For WT and Df separately, it plots enrichment across iterations (top row)
    and the final distance-from-reward distribution (bottom row), one color
    per condition (labelled I, II, III). It then saves and closes the figure.
    """
    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(2, 2, left=0.1, right=0.7, top=0.9, bottom=0.5,
                       hspace=0.4, wspace=0.4)
    WT_enrich_ax = fig.add_subplot(gs1[0, 0])
    Df_enrich_ax = fig.add_subplot(gs1[0, 1])
    WT_final_dist_ax = fig.add_subplot(gs1[1, 0])
    Df_final_dist_ax = fig.add_subplot(gs1[1, 1])

    sim_dir = os.path.join(df.data_path, 'enrichment_model')
    sim_files = ('WT_Df_enrichment_model_simulation_A.pkl',
                 'WT_Df_enrichment_model_simulation_B.pkl',
                 'WT_Df_enrichment_model_simulation_C.pkl')
    simulations = []
    for sim_file in sim_files:
        # Pickles must be read in binary mode; 'with' closes the handle.
        # NOTE: pickle can execute code on load -- only use trusted files.
        with open(os.path.join(sim_dir, sim_file), 'rb') as f:
            simulations.append(pickle.load(f))

    WT_colors = sns.light_palette(WT_color, 7)[2::2]
    Df_colors = sns.light_palette(Df_color, 7)[2::2]

    condition_labels = [
        r'Condition $\mathrm{I}$', r'Condition $\mathrm{II}$',
        r'Condition $\mathrm{III}$'
    ]

    WT_final_dists, Df_final_dists = [], []

    for m, WT_c, Df_c in zip(simulations, WT_colors, Df_colors):

        WT_enrich = emp.calc_enrichment(m['WT_no_swap_pos'],
                                        m['WT_no_swap_masks'])
        Df_enrich = emp.calc_enrichment(m['Df_no_swap_pos'],
                                        m['Df_no_swap_masks'])

        WT_final_dists.append(
            emp.calc_final_distributions(m['WT_no_swap_pos'],
                                         m['WT_no_swap_masks']))
        Df_final_dists.append(
            emp.calc_final_distributions(m['Df_no_swap_pos'],
                                         m['Df_no_swap_masks']))

        emp.plot_enrichment(WT_enrich_ax, WT_enrich, WT_c, title='', rad=False)
        emp.plot_enrichment(Df_enrich_ax, Df_enrich, Df_c, title='', rad=False)

    for ax, colors in ((WT_enrich_ax, WT_colors), (Df_enrich_ax, Df_colors)):
        ax.set_xlabel("Iteration ('session' #)")
        plotting.stackedText(ax, condition_labels, colors=colors, loc=2,
                             size=8)

    final_panels = ((WT_final_dist_ax, WT_final_dists, WT_colors),
                    (Df_final_dist_ax, Df_final_dists, Df_colors))
    for ax, final_dists, colors in final_panels:
        emp.plot_final_distributions(ax,
                                     final_dists,
                                     colors,
                                     labels=condition_labels,
                                     title='',
                                     rad=False)

    for ax, _, colors in final_panels:
        ax.set_xlabel('Distance from reward\n(fraction of belt)')
        plotting.stackedText(ax, condition_labels, colors=colors, loc=2,
                             size=8)
        ax.set_yticks([0, 0.1, 0.2, 0.3])

    save_figure(fig, filename, save_dir=save_dir)

    # Release the figure, consistent with the other figure scripts here.
    plt.close('all')
def main():
    """Plot per-mouse learning curves of the fraction of licks in reward zone.

    For the WT and Df hidden-reward behavior sets, draws one series per mouse
    across condition days. WT goes in the top-left axis and Df in the
    top-right; the remaining axes of the 4x2 grid are hidden. The figure is
    then saved and closed.
    """
    all_grps = df.loadExptGrps('GOL')

    WT_expt_grp = all_grps['WT_hidden_behavior_set']
    Df_expt_grp = all_grps['Df_hidden_behavior_set']

    behavior_fn = ra.fraction_licks_in_reward_zone
    behavior_kwargs = {}
    activity_label = 'Fraction of licks in reward zone'

    WT_colors = sns.light_palette(WT_color, 8)[::-1]
    Df_colors = sns.light_palette(Df_color, 7)[::-1]
    markers = ('o', 'v', '^', 'D', '*', 's')

    fig, axs = plt.subplots(4, 2, figsize=(8.5, 11))

    sns.despine(fig)

    wt_ax = axs[0, 0]
    df_ax = axs[0, 1]

    # Only the first row of the grid is used.
    for ax in list(axs.flat)[2:]:
        ax.set_visible(False)

    def split_by_mouse(expt_grp):
        """Return one sub-group per mouse, labelled with its mouseID."""
        mouse_frame = expt_grp.dataframe(
            expt_grp, include_columns=['mouseID'])
        return [expt_grp.subGroup(list(expts['expt']), label=mouse)
                for mouse, expts in mouse_frame.groupby('mouseID')]

    def plot_learning_curves(ax, expt_grp, mouse_grps, colors):
        """Plot the per-mouse groups on ``ax`` and apply shared formatting."""
        plotting.plot_metric(ax,
                             mouse_grps,
                             metric_fn=behavior_fn,
                             activity_kwargs=behavior_kwargs,
                             groupby=[['expt'], ['condition_day']],
                             plotby=['condition_day'],
                             plot_method='line',
                             ms=5,
                             activity_label=activity_label,
                             colors=colors,
                             markers=markers,
                             label_every_n=1,
                             label_groupby=False,
                             rotate_labels=False)
        ax.set_xlabel('Day in Condition')
        ax.set_title(expt_grp.label())
        ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5])
        # Day labels 1-3, repeated three times along the x axis.
        ax.set_xticklabels(['1', '2', '3', '1', '2', '3', '1', '2', '3'])
        label_conditions(ax)
        ax.get_legend().set_visible(False)
        ax.tick_params(length=3, pad=2)

    wt_expt_grps = split_by_mouse(WT_expt_grp)
    df_expt_grps = split_by_mouse(Df_expt_grp)

    plot_learning_curves(wt_ax, WT_expt_grp, wt_expt_grps, WT_colors)
    plot_learning_curves(df_ax, Df_expt_grp, df_expt_grps, Df_colors)

    misc.save_figure(fig, filename, save_dir=save_dir)

    # Release the figure, consistent with the other figure scripts here.
    plt.close('all')
예제 #17
0
def main():
    """Plot the fitted Df enrichment-model components and simulation output.

    Panels (3x2 grid):
      [0, 0] place cell recurrence probability vs. initial distance from
             reward: spline fit, 5th-95th percentile bootstrap band, and
             jittered raw data.
      [0, 1] position shift vs. initial distance from reward: mean-shift
             spline fit +/- fitted variance (1 / k), plus raw shifts.
      [1, 0] mean-shift fit with its bootstrap band (zoomed y axis).
      [1, 1] shift variance (1 / k) with its bootstrap band.
      [2, 0] enrichment across simulated iterations.
      [2, 1] final simulated distribution.

    Positions are angles spanning [-pi, pi]; tick labels show them as a
    fraction of the belt (pi -> 0.50).
    """
    _, Df_data = emd.load_data(
        'df', root=os.path.join(df.data_path, 'enrichment_model'))

    # Pickles must be read in binary mode; 'with' closes the handle.
    with open(params_path, 'rb') as f:
        Df_params = pickle.load(f)

    fig = plt.figure(figsize=(8.5, 11))
    gs = plt.GridSpec(3, 2, left=0.1, bottom=0.5, right=0.5, top=0.9,
                      wspace=0.3, hspace=0.3)
    Df_recur_ax = fig.add_subplot(gs[0, 0])
    Df_shift_ax = fig.add_subplot(gs[0, 1])
    shift_compare_ax = fig.add_subplot(gs[1, 0])
    var_compare_ax = fig.add_subplot(gs[1, 1])
    enrichment_ax = fig.add_subplot(gs[2, 0])
    final_enrichment_ax = fig.add_subplot(gs[2, 1])

    # Angle ticks labelled as fraction of belt.
    belt_ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    belt_tick_labels = ['-0.50', '-0.25', '0', '0.25', '0.50']

    #
    # Recurrence by position
    #

    recur_x_vals = np.linspace(-np.pi, np.pi, 1000)
    recur_params = Df_params['position_recurrence']

    Df_recur_data = emd.recurrence_by_position(Df_data, method='cv')
    Df_recur_knots = np.linspace(-np.pi, np.pi, recur_params['n_knots'])
    Df_recur_N = splines.CyclicSpline(Df_recur_knots).design_matrix(
        recur_x_vals)

    Df_recur_fit = splines.prob(recur_params['theta'], Df_recur_N)
    Df_recur_boots_fits = [splines.prob(boot, Df_recur_N)
                           for boot in recur_params['boots_theta']]
    Df_recur_ci_up_fit = np.percentile(Df_recur_boots_fits, 95, axis=0)
    Df_recur_ci_low_fit = np.percentile(Df_recur_boots_fits, 5, axis=0)

    Df_recur_ax.plot(recur_x_vals, Df_recur_fit, color=Df_color)
    Df_recur_ax.fill_between(recur_x_vals,
                             Df_recur_ci_low_fit,
                             Df_recur_ci_up_fit,
                             facecolor=Df_color,
                             alpha=0.5)
    # x/y as keywords: seaborn >= 0.12 rejects them positionally.
    sns.regplot(x=Df_recur_data[:, 0],
                y=Df_recur_data[:, 1],
                ax=Df_recur_ax,
                color=Df_color,
                y_jitter=0.2,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=Df_marker)
    Df_recur_ax.set_xlim(-np.pi, np.pi)
    Df_recur_ax.set_xticks(belt_ticks)
    Df_recur_ax.set_xticklabels(belt_tick_labels)
    Df_recur_ax.set_ylim(-0.3, 1.3)
    Df_recur_ax.set_yticks([0, 0.5, 1])
    Df_recur_ax.tick_params(length=3, pad=1, top=False)
    Df_recur_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    Df_recur_ax.set_ylabel('Place cell recurrence probability')
    Df_recur_ax.set_title('')
    Df_recur_ax_2 = Df_recur_ax.twinx()
    Df_recur_ax_2.set_ylim(-0.3, 1.3)
    Df_recur_ax_2.set_yticks([0, 1])
    Df_recur_ax_2.set_yticklabels(['non-recur', 'recur'])
    Df_recur_ax_2.tick_params(length=3, pad=1, top=False)

    #
    # Shared position-stability fits (used by the next three panels)
    #

    shift_x_vals = np.linspace(-np.pi, np.pi, 1000)
    shift_params = Df_params['position_stability']['all_pairs']

    Df_shift_N = splines.CyclicSpline(shift_params['knots']).design_matrix(
        shift_x_vals)
    # Mean shift fit (b) and k fit; 1 / k is plotted as the shift variance.
    Df_shift_b_fit = np.dot(Df_shift_N, shift_params['theta_b'])
    Df_shift_k_fit = splines.get_k(shift_params['theta_k'], Df_shift_N)
    Df_shift_fit_var = 1. / Df_shift_k_fit

    #
    # Place field stability
    #

    Df_shift_data = emd.paired_activity_centroid_distance_to_reward(Df_data)
    Df_shift_data = Df_shift_data.dropna()
    Df_shifts = Df_shift_data['second'] - Df_shift_data['first']
    # Wrap shifts into [-pi, pi).
    Df_shifts[Df_shifts < -np.pi] += 2 * np.pi
    Df_shifts[Df_shifts >= np.pi] -= 2 * np.pi

    Df_shift_ax.plot(shift_x_vals, Df_shift_b_fit, color=Df_color)
    Df_shift_ax.fill_between(shift_x_vals,
                             Df_shift_b_fit - Df_shift_fit_var,
                             Df_shift_b_fit + Df_shift_fit_var,
                             facecolor=Df_color,
                             alpha=0.5)
    sns.regplot(x=Df_shift_data['first'],
                y=Df_shifts,
                ax=Df_shift_ax,
                color=Df_color,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=Df_marker)

    Df_shift_ax.axvline(ls='--', color='0.4', lw=0.5)
    Df_shift_ax.axhline(ls='--', color='0.4', lw=0.5)
    Df_shift_ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
    Df_shift_ax.tick_params(length=3, pad=1, top=False)
    Df_shift_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    Df_shift_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')
    Df_shift_ax.set_xlim(-np.pi, np.pi)
    Df_shift_ax.set_xticks(belt_ticks)
    Df_shift_ax.set_xticklabels(belt_tick_labels)
    Df_shift_ax.set_ylim(-np.pi, np.pi)
    Df_shift_ax.set_yticks(belt_ticks)
    Df_shift_ax.set_yticklabels(belt_tick_labels)
    Df_shift_ax.set_title('')

    #
    # Stability by distance to reward
    #

    Df_shift_boots_b_fit = [np.dot(Df_shift_N, boot)
                            for boot in shift_params['boots_theta_b']]
    Df_shift_b_ci_up_fit = np.percentile(Df_shift_boots_b_fit, 95, axis=0)
    Df_shift_b_ci_low_fit = np.percentile(Df_shift_boots_b_fit, 5, axis=0)

    shift_compare_ax.plot(shift_x_vals, Df_shift_b_fit, color=Df_color)
    shift_compare_ax.fill_between(shift_x_vals,
                                  Df_shift_b_ci_low_fit,
                                  Df_shift_b_ci_up_fit,
                                  facecolor=Df_color,
                                  alpha=0.5)

    shift_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.axhline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.tick_params(length=3, pad=1, top=False)
    shift_compare_ax.set_xlim(-np.pi, np.pi)
    shift_compare_ax.set_xticks(belt_ticks)
    shift_compare_ax.set_xticklabels(belt_tick_labels)
    shift_compare_ax.set_ylim(-0.10 * 2 * np.pi, 0.10 * 2 * np.pi)
    # Ticks given in fraction of belt, placed in radians.
    y_ticks = np.array(['-0.10', '-0.05', '0', '0.05', '0.10'])
    shift_compare_ax.set_yticks(y_ticks.astype('float') * 2 * np.pi)
    shift_compare_ax.set_yticklabels(y_ticks)
    shift_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    shift_compare_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')

    Df_shift_boots_k_fit = [splines.get_k(boot, Df_shift_N)
                            for boot in shift_params['boots_theta_k']]
    Df_shift_k_ci_up_fit = np.percentile(Df_shift_boots_k_fit, 95, axis=0)
    Df_shift_k_ci_low_fit = np.percentile(Df_shift_boots_k_fit, 5, axis=0)

    var_compare_ax.plot(shift_x_vals, 1. / Df_shift_k_fit, color=Df_color)
    var_compare_ax.fill_between(shift_x_vals,
                                1. / Df_shift_k_ci_low_fit,
                                1. / Df_shift_k_ci_up_fit,
                                facecolor=Df_color,
                                alpha=0.5)

    var_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    var_compare_ax.tick_params(length=3, pad=1, top=False)
    var_compare_ax.set_xlim(-np.pi, np.pi)
    var_compare_ax.set_xticks(belt_ticks)
    var_compare_ax.set_xticklabels(belt_tick_labels)
    # Variance ticks in (fraction of belt)^2, placed in radians^2.
    y_ticks = np.array(['0.005', '0.010', '0.015', '0.020'])
    var_compare_ax.set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    var_compare_ax.set_yticklabels(y_ticks)
    var_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    var_compare_ax.set_ylabel(r'$\Delta$ position variance')

    #
    # Enrichment
    #

    with open(simulations_path, 'rb') as f:
        m = pickle.load(f)
    Df_enrich = emp.calc_enrichment(m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_enrichment(enrichment_ax,
                        Df_enrich,
                        Df_color,
                        title='',
                        rad=False)
    enrichment_ax.set_xlabel("Iteration ('session' #)")

    #
    # Final Enrichment
    #

    Df_no_swap_final_dist = emp.calc_final_distributions(
        m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_final_distributions(final_enrichment_ax, [Df_no_swap_final_dist],
                                 [Df_color],
                                 title='',
                                 rad=False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
예제 #18
0
def _hide_legend(ax):
    """Hide the legend of ``ax`` if it has one.

    ``Axes.get_legend()`` returns None when no legend exists, so calling
    ``.set_visible`` on it unconditionally can raise AttributeError.
    """
    legend = ax.get_legend()
    if legend is not None:
        legend.set_visible(False)


def _plot_cdf_and_bar(cdf_ax, bar_ax, expt_grps, pc_filters, metric_fn,
                      value_range, cdf_label, bar_label, cdf_groupby,
                      bar_groupby):
    """Plot one place-cell metric as a CDF plus an inset grouped-bar summary.

    Parameters
    ----------
    cdf_ax, bar_ax : matplotlib Axes
        Axes for the CDF and for the small grouped-bar inset.
    expt_grps : list
        Experiment groups to compare (one curve / bar per group).
    pc_filters : list
        Per-group ROI filters passed as ``roi_filters`` to ``plot_metric``.
    metric_fn : callable
        Metric passed through to ``plotting.plot_metric``.
    value_range : tuple
        Used as the CDF x-limits and as the bar y-limits and y-ticks.
    cdf_label, bar_label : str
        ``activity_label`` for the CDF and bar plots respectively.
    cdf_groupby, bar_groupby
        ``groupby`` arguments for the CDF and bar plots respectively.

    Uses the module-level ``colors`` and ``linestyles`` that the rest of
    this script relies on.
    """
    plotting.plot_metric(
        cdf_ax, expt_grps, metric_fn=metric_fn,
        groupby=cdf_groupby, plotby=None, plot_method='cdf',
        roi_filters=pc_filters, activity_kwargs=None,
        activity_label=cdf_label, colors=colors,
        rotate_labels=False, linestyles=linestyles)
    cdf_ax.set_xlim(value_range)
    plotting.plot_metric(
        bar_ax, expt_grps, metric_fn=metric_fn,
        groupby=bar_groupby, plotby=None, plot_method='grouped_bar',
        roi_filters=pc_filters, activity_kwargs=None,
        activity_label=bar_label, colors=colors)
    bar_ax.set_ylim(value_range)
    bar_ax.set_yticks(value_range)


def main():
    """Build and save the place-cell property comparison figure.

    Loads the 'GOL' experiment groups, takes the WT and Df place-cell sets,
    and plots:

    * lifetime place coding by session number (sessions < 15),
    * CDF + bar of the fraction of sessions each ROI is a place cell,
    * CDF + inset bar of transient sensitivity, transient specificity,
      place field width and single-cell sparsity.

    The figure is written with ``misc.save_figure(fig, filename,
    save_dir=save_dir)``. ``roi_filters``, ``colors``, ``linestyles``,
    ``filename`` and ``save_dir`` are not defined here and are presumably
    module-level settings of this script (TODO confirm).
    """
    all_grps = df.loadExptGrps('GOL')

    expt_grps = [all_grps['WT_place_set'], all_grps['Df_place_set']]
    # Legend labels are derived from the groups themselves. These names were
    # previously used below without ever being assigned in this function.
    WT_label, Df_label = [expt_grp.label() for expt_grp in expt_grps]

    # Restrict every metric to place cells within each group's ROI filter.
    pc_filters = [expt_grp.pcs_filter(roi_filter=roi_filter) for
                  expt_grp, roi_filter in zip(expt_grps, roi_filters)]

    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(
        3, 3, left=0.1, right=0.9, top=0.9, bottom=0.3, wspace=0.5, hspace=0.4)
    # Each CDF axis gets a small hand-positioned inset axis for the bar plot.
    sensitivity_cdf_ax = fig.add_subplot(gs1[1, 0])
    sensitivity_bar_ax = fig.add_axes([0.24, 0.535, 0.05, 0.07])

    specificity_cdf_ax = fig.add_subplot(gs1[1, 1])
    specificity_bar_ax = fig.add_axes([0.43, 0.60, 0.05, 0.07])

    width_cdf_ax = fig.add_subplot(gs1[2, 0])
    width_bar_ax = fig.add_axes([0.24, 0.315, 0.05, 0.07])

    sparsity_cdf_ax = fig.add_subplot(gs1[2, 1])
    sparsity_bar_ax = fig.add_axes([0.54, 0.315, 0.05, 0.07])

    is_ever_pc_fraction_ax = fig.add_subplot(gs1[0, 0])

    pc_fraction_ax = fig.add_subplot(gs1[0, 1])
    pc_fraction_bar_ax = fig.add_axes([0.54, 0.755, 0.05, 0.07])

    cdf_axs = [
        sensitivity_cdf_ax, specificity_cdf_ax, width_cdf_ax, sparsity_cdf_ax,
        pc_fraction_ax]
    bar_axs = [
        sensitivity_bar_ax, specificity_bar_ax, width_bar_ax, sparsity_bar_ax,
        pc_fraction_bar_ax]

    # Fresh groupby lists are built per call in case plot_metric mutates them.
    _plot_cdf_and_bar(
        sensitivity_cdf_ax, sensitivity_bar_ax, expt_grps, pc_filters,
        metric_fn=place.sensitivity, value_range=(0, 1),
        cdf_label='Transient sensitivity',
        bar_label='Transient sensitivity',
        cdf_groupby=[['roi_id', 'expt']],
        bar_groupby=[['roi_id', 'expt'],
                     ['roi_id', 'uniqueLocationKey', 'mouseID'], ['mouseID']])

    plotting.stackedText(
        sensitivity_cdf_ax, [WT_label, Df_label], colors=colors, loc=2,
        size=10)

    _plot_cdf_and_bar(
        specificity_cdf_ax, specificity_bar_ax, expt_grps, pc_filters,
        metric_fn=place.specificity, value_range=(0, 1),
        cdf_label='Transient specificity',
        bar_label='Transient specificity',
        cdf_groupby=[['roi_id', 'expt']],
        bar_groupby=[['roi_id', 'expt'],
                     ['roi_id', 'uniqueLocationKey', 'mouseID'], ['mouseID']])

    _plot_cdf_and_bar(
        width_cdf_ax, width_bar_ax, expt_grps, pc_filters,
        metric_fn=place.place_field_width, value_range=(0, 100),
        cdf_label='Place field width (cm)',
        bar_label='Place field width',
        cdf_groupby=None,
        bar_groupby=[['roi_id', 'uniqueLocationKey', 'mouseID'], ['mouseID']])

    _plot_cdf_and_bar(
        sparsity_cdf_ax, sparsity_bar_ax, expt_grps, pc_filters,
        metric_fn=place.sparsity, value_range=(0, 1),
        cdf_label='Single-cell sparsity',
        bar_label='Single-cell sparsity',
        cdf_groupby=[['roi_id', 'expt']],
        bar_groupby=[['roi_id', 'expt'],
                     ['roi_id', 'uniqueLocationKey', 'mouseID'], ['mouseID']])

    # Fraction of sessions each ROI is a place cell. The place-cell filter is
    # used for plotting, while the group's full ROI filter is passed through
    # activity_kwargs as 'include_roi_filter'.
    fraction_ses_pc_range = (0, 0.5)
    plotting.plot_metric(
        pc_fraction_ax, expt_grps, metric_fn=lab.ExperimentGroup.filtered_rois,
        groupby=(('mouseID', 'uniqueLocationKey', 'roi_id'),), plotby=None,
        colorby=None, plot_method='cdf', roi_filters=pc_filters,
        activity_kwargs=[
            {'include_roi_filter': roi_filter} for roi_filter in roi_filters],
        colors=colors, activity_label='Fraction of sessions a place cell',
        rotate_labels=False, linestyles=linestyles)
    pc_fraction_ax.tick_params(length=3, pad=2)
    _hide_legend(pc_fraction_ax)
    pc_fraction_ax.set_title('')
    pc_fraction_ax.set_xticks([0, .2, .4, .6, .8])
    pc_fraction_ax.set_xlim(0, .8)
    pc_fraction_ax.spines['left'].set_linewidth(1)
    pc_fraction_ax.spines['bottom'].set_linewidth(1)

    plotting.plot_metric(
        pc_fraction_bar_ax, expt_grps,
        metric_fn=lab.ExperimentGroup.filtered_rois,
        groupby=[['roi_id', 'uniqueLocationKey', 'mouseID'], ['mouseID']],
        plotby=None, plot_method='grouped_bar',
        roi_filters=pc_filters, activity_kwargs=[
            {'include_roi_filter': roi_filter} for roi_filter in roi_filters],
        activity_label='Fraction of sessions a place cell', colors=colors)
    pc_fraction_bar_ax.set_ylim(fraction_ses_pc_range)
    pc_fraction_bar_ax.set_yticks(fraction_ses_pc_range)

    # Lifetime place coding, restricted to sessions numbered below 15.
    place.is_ever_place_cell(
        expt_grps, roi_filters=roi_filters, ax=is_ever_pc_fraction_ax,
        colors=colors, filter_fn=lambda df: df['session_number'] < 15,
        filter_columns=['session_number'],
        groupby=[['mouseID', 'session_number']])
    _hide_legend(is_ever_pc_fraction_ax)
    is_ever_pc_fraction_ax.tick_params(length=3, pad=2)
    is_ever_pc_fraction_ax.set_xlabel('Session number')
    is_ever_pc_fraction_ax.set_title('Lifetime place coding')
    # Label only odd session numbers to avoid crowding the axis.
    is_ever_pc_fraction_ax.set_xticklabels([
        '1', '', '3', '', '5', '', '7', '', '9', '', '11', '', '13', '', '15'])

    for ax in cdf_axs + bar_axs:
        ax.tick_params(length=3, pad=2)
    for ax in cdf_axs:
        ax.set_title('')
    # Insets are unlabelled: strip titles, labels, x ticks and legends.
    for ax in bar_axs:
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_title('')
        ax.set_xticks([])
        ax.tick_params(bottom=False, labelbottom=False, length=3, pad=2)
        _hide_legend(ax)

    sns.despine(fig)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')