def main():
    """Draw and save the place-field stability / cue-remapping figure.

    Panels, in the order they are produced:

    * CDF of the selected stability metric for the 'RF' experiment groups,
      with a small per-mouse swarm inset.
    * Box-and-line comparison of the same metric between the 'RF' and
      'GOL' tasks.
    * Swarm/bar plot of the metric across the condition A -> B transition
      in the 'GOL' groups.
    * Per-mouse cue-to-position ratio, plus one pie chart each for WT, Df
      and the pooled shuffle.

    The function takes no arguments. It reads the following names from the
    enclosing module scope: ``filename``, ``save_dir``, ``roi_filters``,
    ``WT_filter``, ``Df_filter``, ``colors``, ``WT_color``, ``Df_color``,
    ``linestyles`` and ``markers``.

    The figure is written with ``misc.save_figure`` and every open
    matplotlib figure is closed before returning.
    """
    hidden_grps = df.loadExptGrps('GOL')

    WT_expt_grp_hidden = hidden_grps['WT_place_set']
    Df_expt_grp_hidden = hidden_grps['Df_place_set']
    expt_grps_hidden = [WT_expt_grp_hidden, Df_expt_grp_hidden]

    acute_grps = df.loadExptGrps('RF')

    WT_expt_grp_acute = acute_grps['WT_place_set'].unpair()
    Df_expt_grp_acute = acute_grps['Df_place_set'].unpair()
    expt_grps_acute = [WT_expt_grp_acute, Df_expt_grp_acute]

    WT_label = WT_expt_grp_hidden.label()
    Df_label = Df_expt_grp_hidden.label()
    labels = [WT_label, Df_label]

    #
    # Figure layout (letter-sized page, positions in figure fractions)
    #
    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.1, right=0.20)
    across_ctx_ax = fig.add_subplot(gs1[0, 0])

    gs2 = plt.GridSpec(3, 1, top=0.9, bottom=0.7, left=0.25, right=0.35)
    wt_pie_ax = fig.add_subplot(gs2[0, 0])
    df_pie_ax = fig.add_subplot(gs2[1, 0])
    shuffle_pie_ax = fig.add_subplot(gs2[2, 0])
    pie_axs = (wt_pie_ax, df_pie_ax, shuffle_pie_ax)

    gs3 = plt.GridSpec(1, 1, top=0.9, bottom=0.7, left=0.4, right=0.5)
    cue_cell_bar_ax = fig.add_subplot(gs3[0, 0])

    gs5 = plt.GridSpec(1, 1, top=0.5, bottom=0.3, left=0.1, right=0.3)
    acute_stability_ax = fig.add_subplot(gs5[0, 0])

    # Inset drawn on top of the lower-right corner of the CDF panel.
    acute_stability_inset_ax = fig.add_axes([0.23, 0.32, 0.05, 0.08])

    gs6 = plt.GridSpec(1, 1, top=0.5, bottom=0.3, left=0.4, right=0.5)
    task_compare_ax = fig.add_subplot(gs6[0, 0])

    #
    # RF Compare
    #
    params = {}
    params['filename'] = filename

    # Alternative stability-metric parameter sets. Only the one passed to
    # params.update() below is used; the others are kept so the metric can
    # be swapped by editing that single line. Note that only
    # params_cent_shift_all and params_pf_corr define the axis-limit keys
    # ('stability_cdf_range', 'stability_inset_ylim', ...) read further down.
    params_cent_shift_pc = {}
    params_cent_shift_pc['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_pc['stability_kwargs'] = {
        'activity_filter': 'pc_both',
        'circ_var_pcs': False,
        'units': 'norm',
        'shuffle': True
    }
    params_cent_shift_pc['stability_label'] = \
        'Centroid shift (fraction of belt)'

    params_cent_shift_all = {}
    params_cent_shift_all['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_all['stability_kwargs'] = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'norm',
        'shuffle': True
    }
    params_cent_shift_all['stability_label'] = \
        'Centroid shift (fraction of belt)'
    params_cent_shift_all['stability_inset_ylim'] = (0.15, 0.30)
    params_cent_shift_all['stability_cdf_range'] = (0.15, 0.35)
    params_cent_shift_all['stability_cdf_ticks'] = \
        (0.15, 0.20, 0.25, 0.30, 0.35)
    params_cent_shift_all['stability_compare_ylim'] = (0.15, 0.27)
    params_cent_shift_all['stability_compare_yticks'] = (0.15, 0.20, 0.25)
    params_cent_shift_all['ctx_compare_ylim'] = (0.10, 0.30)
    params_cent_shift_all['ctx_compare_yticks'] = \
        (0.10, 0.15, 0.20, 0.25, 0.30)

    params_cent_shift_cm = {}
    params_cent_shift_cm['stability_fn'] = place.activity_centroid_shift
    params_cent_shift_cm['stability_kwargs'] = {
        'activity_filter': 'active_both',
        'circ_var_pcs': False,
        'units': 'cm',
        'shuffle': True
    }
    params_cent_shift_cm['stability_label'] = 'Centroid shift (cm)'

    params_pop_vect_corr = {}
    params_pop_vect_corr['stability_fn'] = place.population_vector_correlation
    params_pop_vect_corr['stability_kwargs'] = {
        'method': 'corr',
        'activity_filter': 'pc_both',
        'min_pf_density': 0.05,
        'circ_var_pcs': False
    }
    params_pop_vect_corr['stability_label'] = 'Population vector correlation'

    params_pf_corr = {}
    params_pf_corr['stability_fn'] = place.place_field_correlation
    params_pf_corr['stability_kwargs'] = {'activity_filter': 'pc_either'}
    params_pf_corr['stability_label'] = 'Place field correlation'
    params_pf_corr['stability_inset_ylim'] = (0, 0.50)
    params_pf_corr['stability_cdf_range'] = (0.15, 0.55)
    params_pf_corr['stability_cdf_ticks'] = (0.15, 0.25, 0.35, 0.45, 0.55)
    params_pf_corr['stability_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['stability_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)
    params_pf_corr['hidden_ctx_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['hidden_ctx_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)
    params_pf_corr['ctx_compare_ylim'] = (0.22, 0.40)
    params_pf_corr['ctx_compare_yticks'] = (0.25, 0.30, 0.35, 0.40)

    # Select the metric used for every stability panel in this figure.
    params.update(params_cent_shift_all)

    day_paired_grps_acute = [
        grp.pair('consecutive groups', groupby=['day_in_df'])
        for grp in expt_grps_acute
    ]
    paired_grps_acute = day_paired_grps_acute
    paired_grps_hidden = [
        grp.pair('consecutive groups', groupby=['X_condition', 'X_day'])
        for grp in expt_grps_hidden
    ]

    # Keep only experiment pairs labelled 'SameAll'.
    filter_fn = lambda df: (df['expt_pair_label'] == 'SameAll')
    filter_columns = ['expt_pair_label']

    # The return value is kept: its per-group dataframes feed the
    # RF-vs-GOL comparison below.
    acute_stability = plotting.plot_metric(
        acute_stability_ax,
        paired_grps_acute,
        metric_fn=params['stability_fn'],
        groupby=[['expt_pair_label', 'second_expt']],
        plotby=None,
        plot_method='cdf',
        plot_abs=True,
        roi_filters=roi_filters,
        activity_kwargs=params['stability_kwargs'],
        plot_shuffle=True,
        shuffle_plotby=False,
        pool_shuffle=True,
        activity_label=params['stability_label'],
        colors=colors,
        rotate_labels=False,
        filter_fn=filter_fn,
        filter_columns=filter_columns,
        return_full_dataframes=False,
        linestyles=linestyles)
    acute_stability_ax.set_xlabel(params['stability_label'])
    acute_stability_ax.set_title('')
    sns.despine(ax=acute_stability_ax)
    acute_stability_ax.set_xlim(params['stability_cdf_range'])
    acute_stability_ax.set_xticks(params['stability_cdf_ticks'])
    acute_stability_ax.legend(loc='upper left', fontsize=6)

    plotting.plot_metric(acute_stability_inset_ax,
                         paired_grps_acute,
                         metric_fn=params['stability_fn'],
                         groupby=[['second_expt'], ['second_mouseID']],
                         plotby=None,
                         plot_method='swarm',
                         plot_abs=True,
                         roi_filters=roi_filters,
                         activity_kwargs=params['stability_kwargs'],
                         plot_shuffle=True,
                         shuffle_plotby=False,
                         pool_shuffle=True,
                         activity_label=params['stability_label'],
                         colors=colors,
                         rotate_labels=False,
                         filter_fn=filter_fn,
                         filter_columns=filter_columns,
                         linewidth=0.2,
                         edgecolor='gray',
                         plot_shuffle_as_hline=True)
    acute_stability_inset_ax.get_legend().set_visible(False)
    sns.despine(ax=acute_stability_inset_ax)
    acute_stability_inset_ax.set_title('')
    acute_stability_inset_ax.set_ylabel('')
    acute_stability_inset_ax.set_xlabel('')
    acute_stability_inset_ax.tick_params(bottom=False, labelbottom=False)
    acute_stability_inset_ax.set_ylim(params['stability_inset_ylim'])
    # Ticks only at the two ends of the inset's y-range.
    acute_stability_inset_ax.set_yticks(params['stability_inset_ylim'])

    # The GOL metric is plotted on a throwaway figure: only the returned
    # dataframes are needed, so the figure is closed straight away.
    tmp_fig = plt.figure()
    tmp_ax = tmp_fig.add_subplot(111)
    hidden_stability = plotting.plot_metric(
        tmp_ax,
        paired_grps_hidden,
        metric_fn=params['stability_fn'],
        groupby=[['expt_pair_label', 'second_expt']],
        plotby=('expt_pair_label', ),
        plot_method='line',
        plot_abs=True,
        roi_filters=roi_filters,
        activity_kwargs=params['stability_kwargs'],
        plot_shuffle=True,
        shuffle_plotby=False,
        pool_shuffle=True,
        activity_label=params['stability_label'],
        colors=colors,
        rotate_labels=False,
        filter_fn=filter_fn,
        filter_columns=filter_columns,
        return_full_dataframes=False)
    plt.close(tmp_fig)

    wt_acute = acute_stability[WT_label]['dataframe']
    wt_acute_shuffle = acute_stability[WT_label]['shuffle']
    df_acute = acute_stability[Df_label]['dataframe']
    df_acute_shuffle = acute_stability[Df_label]['shuffle']

    wt_hidden = hidden_stability[WT_label]['dataframe']
    wt_hidden_shuffle = hidden_stability[WT_label]['shuffle']
    df_hidden = hidden_stability[Df_label]['dataframe']
    df_hidden_shuffle = hidden_stability[Df_label]['shuffle']

    # Tag every row with its task so both tasks can share one dataframe.
    for dataframe in (wt_acute, wt_acute_shuffle, df_acute, df_acute_shuffle):
        dataframe['task'] = 'RF'

    for dataframe in (wt_hidden, wt_hidden_shuffle, df_hidden,
                      df_hidden_shuffle):
        dataframe['task'] = 'GOL'

    # pd.concat replaces DataFrame.append, which newer pandas releases no
    # longer provide; it is also what the cue-remapping section uses.
    WT_data = pd.concat([wt_acute, wt_hidden], ignore_index=True)
    Df_data = pd.concat([df_acute, df_hidden], ignore_index=True)

    WT_shuffle = pd.concat([wt_acute_shuffle, wt_hidden_shuffle],
                           ignore_index=True)
    Df_shuffle = pd.concat([df_acute_shuffle, df_hidden_shuffle],
                           ignore_index=True)

    filter_columns = ('expt_pair_label', )
    filter_fn = lambda df: (df['expt_pair_label'] == 'SameAll')

    # Explicit sort key so 'RF' is drawn before 'GOL'.
    order_dict = {'RF': 0, 'GOL': 1}
    WT_data['order'] = WT_data['task'].map(order_dict)
    Df_data['order'] = Df_data['task'].map(order_dict)
    line_kwargs = {'markersize': 4}
    # NOTE(review): whis='range' is forwarded untouched; recent matplotlib
    # boxplot versions expect whis=(0, 100) instead -- confirm what
    # plotting.plot_dataframe hands on before changing it.
    plotting.plot_dataframe(task_compare_ax, [WT_data, Df_data],
                            [WT_shuffle, Df_shuffle],
                            labels=labels,
                            activity_label='',
                            groupby=[['task', 'second_mouseID']],
                            plotby=('task', ),
                            plot_method='box_and_line',
                            colors=colors,
                            filter_fn=filter_fn,
                            filter_columns=filter_columns,
                            plot_shuffle=True,
                            shuffle_plotby=False,
                            pool_shuffle=True,
                            orderby='order',
                            notch=False,
                            plot_shuffle_as_hline=True,
                            markers=markers,
                            linestyles=linestyles,
                            line_kwargs=line_kwargs,
                            flierprops={
                                'markersize': 3,
                                'marker': 'o'
                            },
                            whis='range')
    task_compare_ax.set_title('')
    sns.despine(ax=task_compare_ax)
    task_compare_ax.set_ylim(params['stability_compare_ylim'])
    task_compare_ax.set_yticks(params['stability_compare_yticks'])
    task_compare_ax.set_xlabel('')
    task_compare_ax.set_ylabel(params['stability_label'])
    task_compare_ax.legend(loc='upper right', fontsize=6)

    #
    # Stability across transition
    #
    groupby = [['second_expt'], ['second_mouse']]
    # Only pairs whose first experiment is condition 'A' and second is 'B'.
    filter_fn = lambda df: (df['X_first_condition'] == 'A') \
        & (df['X_second_condition'] == 'B')
    filter_columns = ('X_first_condition', 'X_second_condition')
    plotting.plot_metric(across_ctx_ax,
                         paired_grps_hidden,
                         metric_fn=params['stability_fn'],
                         groupby=groupby,
                         plotby=None,
                         plot_method='swarm',
                         activity_kwargs=params['stability_kwargs'],
                         plot_shuffle=True,
                         shuffle_plotby=False,
                         pool_shuffle=True,
                         colors=colors,
                         activity_label=params['stability_label'],
                         rotate_labels=False,
                         filter_fn=filter_fn,
                         filter_columns=filter_columns,
                         plot_shuffle_as_hline=True,
                         return_full_dataframes=False,
                         plot_bar=True,
                         roi_filters=roi_filters)
    sns.despine(ax=across_ctx_ax)
    across_ctx_ax.set_ylim(0.0, 0.3)
    across_ctx_ax.set_yticks([0.0, 0.1, 0.2, 0.3])
    across_ctx_ax.set_xticklabels([])
    across_ctx_ax.set_xlabel('')
    across_ctx_ax.set_title('')
    across_ctx_ax.get_legend().set_visible(False)
    plotting.stackedText(across_ctx_ax, labels, colors=colors, loc=2, size=10)

    #
    # Cue remapping
    #

    # 5% of a full turn, in radians; passed on as `near_threshold`.
    THRESHOLD = 0.05 * 2 * np.pi
    # Rows with cueness above 1 - CUENESS_THRESHOLD are counted in cue_n,
    # rows below CUENESS_THRESHOLD in place_n, everything else in neither_n.
    CUENESS_THRESHOLD = 0.33

    def first_cue_position(row):
        """Return the midpoint of the row's cue as a unit complex number.

        The cue is looked up on the belt of ``row['first_expt']`` using
        normalized positions, and the midpoint of its start/stop is mapped
        onto the unit circle.
        """
        expt = row['first_expt']
        cue = row['cue']
        cues = expt.belt().cues(normalized=True)
        # Boolean-mask selection; .loc replaces the removed .ix indexer.
        first_cue = cues.loc[cues['cue'] == cue]
        pos = (first_cue['start'] + first_cue['stop']) / 2
        theta = pos * 2 * np.pi
        # Builtin complex(); np.complex was only an alias of it and has
        # been removed from NumPy.
        return complex(np.cos(theta), np.sin(theta))

    def dotproduct(v1, v2):
        """Return the dot product of two equal-length sequences."""
        return sum((a * b) for a, b in zip(v1, v2))

    def length(v):
        """Return the Euclidean norm of `v`."""
        return math.sqrt(dotproduct(v, v))

    def angle(v1, v2):
        """Return the unsigned angle between `v1` and `v2`, in radians."""
        # Rounded to 3 decimals before acos -- presumably so float error
        # cannot push the cosine outside [-1, 1].
        return math.acos(
            np.round(dotproduct(v1, v2) / (length(v1) * length(v2)), 3))

    def distance_to_first_cue(row):
        """Return the angle between the second centroid and the first cue."""
        centroid = row['second_centroid']
        pos = row['first_cue_position']
        return angle((pos.real, pos.imag), (centroid.real, centroid.imag))

    def pair_without_condition_c(expt_grp):
        """Pair a copy of `expt_grp` after removing 'C' conditions.

        Works on a copy, drops rows whose 'X_condition' contains 'C', pairs
        consecutive groups by condition and day, then pairs the result
        again by condition alone.
        """
        grp_copy = copy(expt_grp)
        grp_copy.filterby(lambda frame: ~frame['X_condition'].str.contains('C'),
                          ['X_condition'])
        by_day = grp_copy.pair('consecutive groups',
                               groupby=['X_condition', 'X_day'])
        return by_day.pair('consecutive groups', groupby=['X_condition'])

    WT_paired = pair_without_condition_c(WT_expt_grp_hidden)
    Df_paired = pair_without_condition_c(Df_expt_grp_hidden)

    WT_df, WT_shuffle_df = place.cue_cell_remapping(
        WT_paired,
        roi_filter=WT_filter,
        near_threshold=THRESHOLD,
        activity_filter='active_both',
        circ_var_pcs=False,
        shuffle=True)
    Df_df, Df_shuffle_df = place.cue_cell_remapping(
        Df_paired,
        roi_filter=Df_filter,
        near_threshold=THRESHOLD,
        activity_filter='active_both',
        circ_var_pcs=False,
        shuffle=True)

    # Shuffles from both genotypes are pooled into one reference group.
    shuffle_df = pd.concat([WT_shuffle_df, Df_shuffle_df], ignore_index=True)

    # One entry per group, in the order WT, Df, shuffle. Each entry of
    # cueness_fraction is wrapped in an extra list because that is the
    # nesting handed to plotting.swarm_plot below.
    cueness_fraction = []
    cue_n, place_n, neither_n = [], [], []

    for grp_df in (WT_df, Df_df, shuffle_df):

        grp_df['first_cue_position'] = grp_df.apply(first_cue_position, axis=1)

        grp_df['second_distance_to_first_cue_position'] = grp_df.apply(
            distance_to_first_cue, axis=1)

        # cueness = d_cue / (value + d_cue), where d_cue is the angular
        # distance computed above and 'value' is the column returned by
        # place.cue_cell_remapping (presumably the centroid shift --
        # TODO confirm).
        grp_df['cueness'] = grp_df['second_distance_to_first_cue_position'] / \
            (grp_df['value'] + grp_df['second_distance_to_first_cue_position'])

        plotting.prepare_dataframe(grp_df, ['first_mouse'])
        cueness_fraction.append([[]])
        cue_n.append([])
        place_n.append([])
        neither_n.append([])

        for mouse, mouse_df in grp_df.groupby('first_mouse'):
            n_cue = (mouse_df['cueness'] > (1 - CUENESS_THRESHOLD)).sum()
            n_place = (mouse_df['cueness'] < CUENESS_THRESHOLD).sum()
            cue_n[-1].append(n_cue)
            place_n[-1].append(n_place)
            neither_n[-1].append(mouse_df.shape[0] - n_cue - n_place)
            # NOTE(review): a mouse with n_place == 0 gives a zero
            # denominator here; left unchanged so the plotted values do
            # not change.
            cueness_fraction[-1][0].append(n_cue / float(n_place))

    cue_labels = labels + ['shuffle']

    # Only WT and Df are drawn as points; the shuffle mean is the dashed line.
    plotting.swarm_plot(cue_cell_bar_ax,
                        cueness_fraction[:2],
                        condition_labels=labels,
                        colors=colors,
                        plot_bar=True)
    cue_cell_bar_ax.axhline(np.mean(cueness_fraction[-1][0]),
                            ls='--',
                            color='k')
    sns.despine(ax=cue_cell_bar_ax)
    cue_cell_bar_ax.set_ylim(0, 1.5)
    cue_cell_bar_ax.set_yticks([0, 0.5, 1.0, 1.5])
    cue_cell_bar_ax.set_xticklabels([])
    cue_cell_bar_ax.set_xlabel('')
    cue_cell_bar_ax.set_ylabel('Cue-to-position ratio')
    cue_cell_bar_ax.get_legend().set_visible(False)
    plotting.stackedText(cue_cell_bar_ax,
                         labels,
                         colors=colors,
                         loc=2,
                         size=10)

    # Three shades per group, taken from a 7-step light palette.
    WT_colors = sns.light_palette(WT_color, 7)[:-6:-2]
    Df_colors = sns.light_palette(Df_color, 7)[:-6:-2]
    shuffle_colors = sns.light_palette('k', 7)[:-6:-2]
    pie_colors = (WT_colors, Df_colors, shuffle_colors)
    pie_labels = ['cue', 'position', 'neither']

    # Temporarily override a global rcParam for the pies; the finally
    # clause restores it even if drawing fails.
    orig_size = mpl.rcParams.get('xtick.labelsize')
    mpl.rcParams['xtick.labelsize'] = 5
    try:
        for grp_ax, grp_label, grp_cue_n, grp_place_n, grp_neither_n, p_cs \
                in zip(pie_axs, cue_labels, cue_n, place_n, neither_n,
                       pie_colors):
            grp_ax.pie([sum(grp_cue_n),
                        sum(grp_place_n),
                        sum(grp_neither_n)],
                       autopct='%1.0f%%',
                       shadow=False,
                       frame=False,
                       labels=pie_labels,
                       colors=p_cs,
                       textprops={'fontsize': 5})
            grp_ax.set_title(grp_label)
            plotting.square_axis(grp_ax)
    finally:
        mpl.rcParams['xtick.labelsize'] = orig_size

    misc.save_figure(fig, params['filename'], save_dir=save_dir)

    plt.close('all')
def main():
    """Simulate the enrichment model under parameter swaps and plot it.

    A model is built from the pickled parameters at ``WT_params_path``
    and a flattened copy of it is made. Both are then run repeatedly from
    the same initial mask and positions, either unchanged ("no swap") or
    after interpolating one model toward the other with a single keyword
    (``on``, ``recur``, ``shift_b`` or ``shift_k``) set to 1.

    A 5x2 grid is created, of which four rows are used, one per condition
    (no swap, recur, shift_k, shift_b); the fifth row is hidden. Left
    column: ``emp.plot_enrichment``; right column:
    ``emp.plot_final_distributions``.

    The function takes no arguments. It reads ``WT_params_path``,
    ``WT_color``, ``flat_color``, ``colors``, ``WT_label``,
    ``flat_label``, ``filename`` and ``save_dir`` from the enclosing
    module scope, saves the figure with ``misc.save_figure`` and closes
    all open figures.
    """
    fig, axs = plt.subplots(
        5, 2, figsize=(8.5, 11), gridspec_kw={
            'wspace': 0.4, 'hspace': 0.35, 'right': 0.75, 'top': 0.9,
            'bottom': 0.1})

    #
    # Run the model
    #

    n_cells = 1000
    n_runs = 100
    tol = 1e-4

    # Pickles are byte streams: open in binary mode (text mode fails on
    # Python 3) and use a context manager so the handle is always closed.
    with open(WT_params_path, 'rb') as params_file:
        WT_params = pkl.load(params_file)

    model_cls = em.EnrichmentModel2

    WT_model = model_cls(**WT_params)
    flat_model = WT_model.copy()
    flat_model.flatten()

    WT_model.initialize(n_cells=n_cells, flat_tol=tol)
    flat_model.initialize_like(WT_model)
    # Every simulation below is started from this same state.
    initial_mask = WT_model.mask
    initial_positions = WT_model.positions

    def run_model(model1, model2=None, initial_positions=None,
                  initial_mask=None, n_runs=100, **interp_kwargs):
        """Run a copy of `model1` `n_runs` times and collect its history.

        If `model2` is given, the copy is first interpolated toward it
        with `interp_kwargs`. Each run re-initializes the copy from the
        given mask/positions and calls ``run(8)``.

        Returns a ``(positions, masks)`` pair of lists with one entry per
        run, taken from the model's ``_positions`` and ``_masks``.
        """
        masks, positions = [], []

        model = model1.copy()

        if model2 is not None:
            model.interpolate(model2, **interp_kwargs)

        for _ in range(n_runs):
            model.initialize(
                initial_mask=initial_mask, initial_positions=initial_positions)

            model.run(8)

            masks.append(model._masks)
            positions.append(model._positions)

        return positions, masks

    # Keyword arguments shared by every run_model call.
    run_kwargs = {
        'initial_positions': initial_positions,
        'initial_mask': initial_mask,
        'n_runs': n_runs,
    }

    WT_no_swap_pos, WT_no_swap_masks = run_model(WT_model, **run_kwargs)
    flat_no_swap_pos, flat_no_swap_masks = run_model(flat_model, **run_kwargs)

    # NOTE(review): the `on` swap results are never plotted. The runs are
    # kept because removing them could change the later simulations if the
    # model draws from shared random state -- confirm before deleting.
    WT_swap_on_pos, WT_swap_on_masks = run_model(
        WT_model, flat_model, on=1, **run_kwargs)
    flat_swap_on_pos, flat_swap_on_masks = run_model(
        flat_model, WT_model, on=1, **run_kwargs)

    WT_swap_recur_pos, WT_swap_recur_masks = run_model(
        WT_model, flat_model, recur=1, **run_kwargs)
    flat_swap_recur_pos, flat_swap_recur_masks = run_model(
        flat_model, WT_model, recur=1, **run_kwargs)

    WT_swap_shift_b_pos, WT_swap_shift_b_masks = run_model(
        WT_model, flat_model, shift_b=1, **run_kwargs)
    flat_swap_shift_b_pos, flat_swap_shift_b_masks = run_model(
        flat_model, WT_model, shift_b=1, **run_kwargs)

    WT_swap_shift_k_pos, WT_swap_shift_k_masks = run_model(
        WT_model, flat_model, shift_k=1, **run_kwargs)
    flat_swap_shift_k_pos, flat_swap_shift_k_masks = run_model(
        flat_model, WT_model, shift_k=1, **run_kwargs)

    #
    # Distance to reward
    #

    WT_no_swap_enrich = emp.calc_enrichment(
        WT_no_swap_pos, WT_no_swap_masks)
    flat_no_swap_enrich = emp.calc_enrichment(
        flat_no_swap_pos, flat_no_swap_masks)

    WT_swap_recur_enrich = emp.calc_enrichment(
        WT_swap_recur_pos, WT_swap_recur_masks)
    flat_swap_recur_enrich = emp.calc_enrichment(
        flat_swap_recur_pos, flat_swap_recur_masks)

    WT_swap_shift_b_enrich = emp.calc_enrichment(
        WT_swap_shift_b_pos, WT_swap_shift_b_masks)
    flat_swap_shift_b_enrich = emp.calc_enrichment(
        flat_swap_shift_b_pos, flat_swap_shift_b_masks)

    WT_swap_shift_k_enrich = emp.calc_enrichment(
        WT_swap_shift_k_pos, WT_swap_shift_k_masks)
    flat_swap_shift_k_enrich = emp.calc_enrichment(
        flat_swap_shift_k_pos, flat_swap_shift_k_masks)

    # Left column: WT and flat drawn on the same axis, one row per
    # condition. The row label goes on the right-hand axis of each row.
    emp.plot_enrichment(
        axs[0, 0], WT_no_swap_enrich, WT_color, '', rad=False)
    emp.plot_enrichment(
        axs[0, 0], flat_no_swap_enrich, flat_color, '', rad=False)
    plotting.right_label(axs[0, 1], 'No swap')
    axs[0, 0].set_ylabel('')

    # Row 1 is the only left-column axis whose y-label is not cleared.
    emp.plot_enrichment(
        axs[1, 0], WT_swap_recur_enrich, WT_color, '', rad=False)
    emp.plot_enrichment(
        axs[1, 0], flat_swap_recur_enrich, flat_color, '', rad=False)
    plotting.right_label(axs[1, 1], 'Swap recurrence')

    emp.plot_enrichment(
        axs[2, 0], WT_swap_shift_k_enrich, WT_color, '', rad=False)
    emp.plot_enrichment(
        axs[2, 0], flat_swap_shift_k_enrich, flat_color, '', rad=False)
    plotting.right_label(axs[2, 1], 'Swap shift variance')
    axs[2, 0].set_ylabel('')

    emp.plot_enrichment(
        axs[3, 0], WT_swap_shift_b_enrich, WT_color, '', rad=False)
    emp.plot_enrichment(
        axs[3, 0], flat_swap_shift_b_enrich, flat_color, '', rad=False)
    plotting.right_label(axs[3, 1], 'Swap shift offset')
    axs[3, 0].set_ylabel('')

    plotting.stackedText(
        axs[0, 0], [WT_label, flat_label], colors=colors, loc=2, size=10)

    #
    # Final distribution
    #

    WT_no_swap_final_dist = emp.calc_final_distributions(
        WT_no_swap_pos, WT_no_swap_masks)
    flat_no_swap_final_dist = emp.calc_final_distributions(
        flat_no_swap_pos, flat_no_swap_masks)

    WT_swap_recur_final_dist = emp.calc_final_distributions(
        WT_swap_recur_pos, WT_swap_recur_masks)
    flat_swap_recur_final_dist = emp.calc_final_distributions(
        flat_swap_recur_pos, flat_swap_recur_masks)

    WT_swap_shift_b_final_dist = emp.calc_final_distributions(
        WT_swap_shift_b_pos, WT_swap_shift_b_masks)
    flat_swap_shift_b_final_dist = emp.calc_final_distributions(
        flat_swap_shift_b_pos, flat_swap_shift_b_masks)

    WT_swap_shift_k_final_dist = emp.calc_final_distributions(
        WT_swap_shift_k_pos, WT_swap_shift_k_masks)
    flat_swap_shift_k_final_dist = emp.calc_final_distributions(
        flat_swap_shift_k_pos, flat_swap_shift_k_masks)

    # Right column, same row order as the left column.
    emp.plot_final_distributions(
        axs[0, 1], [WT_no_swap_final_dist, flat_no_swap_final_dist], colors,
        labels=None, title='', rad=False)
    emp.plot_final_distributions(
        axs[1, 1], [WT_swap_recur_final_dist, flat_swap_recur_final_dist],
        colors, labels=None, title='', rad=False)
    emp.plot_final_distributions(
        axs[2, 1], [WT_swap_shift_k_final_dist, flat_swap_shift_k_final_dist],
        colors, labels=None, title='', rad=False)
    emp.plot_final_distributions(
        axs[3, 1], [WT_swap_shift_b_final_dist, flat_swap_shift_b_final_dist],
        colors, labels=None, title='', rad=False)

    # Only the last used row (row 3) keeps its x-labels.
    for ax in axs[:3, :].flat:
        ax.set_xlabel('')
    # The fifth row exists only for spacing and is hidden.
    for ax in axs[4]:
        ax.set_visible(False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Plot velocity, lap rate and lick rate for the WT and Df 'GOL' groups.

    A 4x3 grid is created but only the top-left 3x2 block is shown. Each
    row is one behaviour metric; the left column aggregates per mouse and
    the right column additionally splits by 'session_in_day'. Only rows
    with condition 'A' are included.

    Takes no arguments; reads ``colors``, ``WT_label``, ``Df_label``,
    ``filename`` and ``save_dir`` from the enclosing module scope, saves
    the figure with ``misc.save_figure`` and closes all open figures.
    """
    all_grps = df.loadExptGrps('GOL')

    WT_expt_grp = all_grps['WT_place_set']
    Df_expt_grp = all_grps['Df_place_set']
    expt_grps = [WT_expt_grp, Df_expt_grp]

    grid_spacing = {'wspace': 0.4, 'hspace': 0.3}
    fig, axs = plt.subplots(4, 3, figsize=(8.5, 11), gridspec_kw=grid_spacing)

    def in_condition_a(frame):
        """Row mask selecting condition 'A'."""
        return frame['condition'] == 'A'

    def swarm(ax, metric_fn, groupby, plotby, activity_kwargs, label,
              **extra):
        """Draw one swarm+bar panel with the settings common to all six."""
        plotting.plot_metric(ax,
                             expt_grps,
                             metric_fn=metric_fn,
                             groupby=groupby,
                             plotby=plotby,
                             plot_method='swarm',
                             activity_kwargs=activity_kwargs,
                             filter_fn=in_condition_a,
                             filter_columns=['condition'],
                             activity_label=label,
                             colors=colors,
                             plot_bar=True,
                             edgecolor='k',
                             linewidth=0.5,
                             **extra)

    #
    # Velocity
    #
    velocity_fn = lab.ExperimentGroup.velocity_dataframe
    velocity_label = 'Velocity (cm/s)'
    swarm(axs[0, 0], velocity_fn, [['expt'], ['mouseID']], None, None,
          velocity_label,
          return_full_dataframes=False)
    swarm(axs[0, 1], velocity_fn,
          [['expt'], ['session_in_day', 'mouseID']], ['session_in_day'],
          None, velocity_label,
          rotate_labels=False)

    #
    # Lap Rate
    #
    laps_fn = lab.ExperimentGroup.number_of_laps
    laps_label = 'Lap rate (1/min)'
    swarm(axs[1, 0], laps_fn, [['expt'], ['mouseID']], None,
          {'rate': True}, laps_label,
          agg_fn=[np.sum, np.mean],
          return_full_dataframes=False)
    swarm(axs[1, 1], laps_fn,
          [['expt'], ['session_in_day', 'mouseID']], ['session_in_day'],
          {'rate': True}, laps_label,
          agg_fn=[np.sum, np.mean],
          rotate_labels=False)

    #
    # Lick Rate
    #
    behavior_fn = lab.ExperimentGroup.behavior_dataframe
    lick_label = 'Lick rate (Hz)'
    swarm(axs[2, 0], behavior_fn, [['trial'], ['mouseID']], None,
          {'key': 'licking', 'rate': True}, lick_label,
          agg_fn=[np.sum, np.mean],
          return_full_dataframes=False)
    swarm(axs[2, 1], behavior_fn,
          [['trial'], ['session_in_day', 'mouseID']], ['session_in_day'],
          {'key': 'licking', 'rate': True}, lick_label,
          roi_filters=None,
          agg_fn=[np.sum, np.mean],
          rotate_labels=False)

    # Hide the unused bottom row and right-hand column.
    for unused_ax in list(axs[3, :]) + list(axs[:, 2]):
        unused_ax.set_visible(False)

    sns.despine(fig=fig)

    for ax in axs.flat:
        ax.set_title('')

    # Left column: no x-axis label, ticks or tick labels.
    for left_ax in axs[:, 0]:
        left_ax.set_xlabel('')
        left_ax.tick_params(bottom=False, labelbottom=False, length=3, pad=2)

    # Right column: sessions are labelled 1-3 and the y-label is dropped.
    session_ticks = ['1', '2', '3']
    for right_ax in axs[:, 1]:
        right_ax.set_xlabel('')
        right_ax.set_ylabel('')
        right_ax.set_xticklabels(session_ticks)
        right_ax.tick_params(length=3, pad=2)

    for shown_ax in axs[:3, :2].flat:
        shown_ax.get_legend().set_visible(False)

    plotting.stackedText(axs[0, 0], [WT_label, Df_label],
                         colors=colors,
                         loc=2,
                         size=10)

    axs[2, 1].set_xlabel('Session in day')

    # Shared y-range per metric row (applied to all three columns).
    for row, upper in ((0, 13), (1, 2.0), (2, 3)):
        for ax in axs[row, :]:
            ax.set_ylim(0, upper)
            if row == 1:
                ax.set_yticks([0, 0.5, 1.0, 1.5, 2.0])

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Draw and save the enrichment-model simulation figure.

    Panels drawn, in order:

    * a schematic of the place-field shift probability density,
    * enrichment by iteration for the WT and Df no-swap simulations,
    * the final distributions of those two simulations,
    * a grouped bar comparison across the no-swap and swap conditions,
    * the WT final distributions for every swap condition.

    The figure is written with ``misc.save_figure`` and every open figure
    is closed afterwards.

    Relies on names defined outside this function: ``simulations_path``,
    ``filename``, ``save_dir``, ``colors``, ``labels``, ``WT_color``,
    ``flat_color``, ``WT_label`` and ``flat_label``.
    """

    fig = plt.figure(figsize=(8.5, 11))

    gs = plt.GridSpec(
        3, 2, left=0.1, bottom=0.4, right=0.5, top=0.9, hspace=0.4, wspace=0.4)
    shift_schematic_ax = fig.add_subplot(gs[0, 1])
    model_enrich_by_time_ax = fig.add_subplot(gs[1, 0])
    model_final_enrich_ax = fig.add_subplot(gs[1, 1])
    model_swap_compare_enrich_ax = fig.add_subplot(gs[2, 0])
    model_WT_swap_final_enrich_ax = fig.add_subplot(gs[2, 1])

    #
    # Shift schematic
    #

    mu = 1
    k = 2.
    pos1 = 1.5

    def vms(x, mu=1, k=2.):
        # Von Mises-shaped density centred on mu with concentration k
        # (assumes i0 is the order-0 modified Bessel function -- confirm
        # against the file's imports).
        return np.exp(k * np.cos(x - mu)) / (2 * np.pi * i0(k))

    xr = np.linspace(-np.pi, np.pi, 1000)

    # Evaluate the curve in one vectorized call, using the same mu/k that
    # the dotted marker line below uses.
    shift_schematic_ax.plot(xr, vms(xr, mu=mu, k=k), color='k')
    shift_schematic_ax.axvline(pos1, color='r', ls='--')
    # Dotted vertical line from the x-axis up to the density at x = mu.
    shift_schematic_ax.plot(
        [mu, mu], [0, vms(mu, mu=mu, k=k)], ls=':', color='0.3')
    sns.despine(ax=shift_schematic_ax, top=True, right=True)
    shift_schematic_ax.tick_params(labelleft=True, left=True, direction='out')
    # Ticks sit at radian positions but are labelled from -0.50 to 0.50,
    # matching the "fraction of belt" x-label.
    shift_schematic_ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    shift_schematic_ax.set_xticklabels(['-0.50', '-0.25', '0', '0.25', '0.50'])
    shift_schematic_ax.set_xlim(-np.pi, np.pi)
    shift_schematic_ax.set_yticks([0, 0.2, 0.4, 0.6])
    shift_schematic_ax.set_ylim(0, 0.6)
    shift_schematic_ax.set_xlabel('Distance from reward (fraction of belt)')
    shift_schematic_ax.set_ylabel(
        'New place field centroid\nprobability density')

    #
    # Distance to reward
    #

    # Open the pickle in binary mode and close the handle deterministically.
    # NOTE: unpickling can execute arbitrary code; only load trusted
    # simulation files here.
    with open(simulations_path, 'rb') as simulations_file:
        m = pickle.load(simulations_file)

    WT_enrich = emp.calc_enrichment(m['WT_no_swap_pos'], m['WT_no_swap_masks'])
    flat_enrich = emp.calc_enrichment(m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_enrichment(
        model_enrich_by_time_ax, WT_enrich, WT_color, title='', rad=False)
    emp.plot_enrichment(
        model_enrich_by_time_ax, flat_enrich, flat_color, title='', rad=False)

    model_enrich_by_time_ax.set_xlabel("Iteration ('session' #)")
    plotting.stackedText(
        model_enrich_by_time_ax, [WT_label, flat_label], colors=colors, loc=2,
        size=10)

    #
    # Calc final distributions
    #

    # The "flat" variables are computed from the 'Df_*' entries of the
    # simulation dictionary.
    WT_no_swap_final_dist = emp.calc_final_distributions(
        m['WT_no_swap_pos'], m['WT_no_swap_masks'])
    flat_no_swap_final_dist = emp.calc_final_distributions(
        m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    WT_swap_recur_final_dist = emp.calc_final_distributions(
        m['WT_swap_recur_pos'], m['WT_swap_recur_masks'])
    flat_swap_recur_final_dist = emp.calc_final_distributions(
        m['Df_swap_recur_pos'], m['Df_swap_recur_masks'])

    WT_swap_shift_b_final_dist = emp.calc_final_distributions(
        m['WT_swap_shift_b_pos'], m['WT_swap_shift_b_masks'])
    flat_swap_shift_b_final_dist = emp.calc_final_distributions(
        m['Df_swap_shift_b_pos'], m['Df_swap_shift_b_masks'])

    WT_swap_shift_k_final_dist = emp.calc_final_distributions(
        m['WT_swap_shift_k_pos'], m['WT_swap_shift_k_masks'])
    flat_swap_shift_k_final_dist = emp.calc_final_distributions(
        m['Df_swap_shift_k_pos'], m['Df_swap_shift_k_masks'])

    #
    # Final distribution
    #

    emp.plot_final_distributions(
        model_final_enrich_ax,
        [WT_no_swap_final_dist, flat_no_swap_final_dist],
        colors, labels=labels, title='', rad=False)

    plotting.stackedText(
        model_final_enrich_ax, [WT_label, flat_label], colors=colors, loc=2,
        size=10)
    #
    # Compare enrichment
    #

    # Bar order: no swap, swapped P_recur, swapped shift variance (k),
    # swapped shift offset (b) -- matching the tick labels passed below.
    WT_bars = [
        WT_no_swap_final_dist, WT_swap_recur_final_dist,
        WT_swap_shift_k_final_dist, WT_swap_shift_b_final_dist]
    flat_bars = [
        flat_no_swap_final_dist, flat_swap_recur_final_dist,
        flat_swap_shift_k_final_dist, flat_swap_shift_b_final_dist]

    # Convert each distribution to pi/2 minus its absolute values
    # (presumably pi/2 is the chance-level mean absolute distance on the
    # circle -- TODO confirm).
    WT_bars = [np.pi / 2 - np.abs(bar) for bar in WT_bars]
    flat_bars = [np.pi / 2 - np.abs(bar) for bar in flat_bars]

    plotting.grouped_bar(
        model_swap_compare_enrich_ax, [WT_bars, flat_bars],
        [WT_label, flat_label],
        ['No\nswap', 'Swap\n' + r'$P_{recur}$',
         'Swap\nvariance', 'Swap\noffset'], [WT_color, flat_color],
        error_bars=None)

    sns.despine(ax=model_swap_compare_enrich_ax)
    model_swap_compare_enrich_ax.tick_params(length=3, pad=2, direction='out')
    model_swap_compare_enrich_ax.set_ylabel('Final enrichment (rad)')
    model_swap_compare_enrich_ax.get_legend().set_visible(False)
    model_swap_compare_enrich_ax.set_ylim(0, 0.08 * 2 * np.pi)
    # Tick positions are the label values scaled by 2*pi, so the printed
    # tick labels are fractions of 2*pi rather than radians.
    # NOTE(review): the y-label says "(rad)" while the tick labels are the
    # unscaled fractions -- confirm this is intended.
    y_ticks = np.array(['0', '0.02', '0.04', '0.06', '0.08'])
    model_swap_compare_enrich_ax.set_yticks(
        y_ticks.astype('float') * 2 * np.pi)
    model_swap_compare_enrich_ax.set_yticklabels(y_ticks)
    plotting.stackedText(
        model_swap_compare_enrich_ax, [WT_label, flat_label], colors=colors,
        loc=2, size=10)

    #
    # WT swap final distributions
    #

    # Colors are handed over as an iterator over the current seaborn palette.
    hist_colors = iter(sns.color_palette())

    emp.plot_final_distributions(
        model_WT_swap_final_enrich_ax,
        [WT_no_swap_final_dist,
         WT_swap_recur_final_dist, WT_swap_shift_k_final_dist,
         WT_swap_shift_b_final_dist], colors=hist_colors,
        labels=['No swap', r'Swap $P_{recur}$',
                'Swap shift variance', 'Swap shift offset'], rad=False)
    model_WT_swap_final_enrich_ax.legend(
        loc='lower right', fontsize=5, frameon=False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
# Example no. 5
# 0
def main():
    """Draw and save the enrichment-model figure for three simulations.

    Loads the A, B and C simulation pickles from
    ``<df.data_path>/enrichment_model`` and, for each one, plots the WT and
    Df no-swap enrichment by iteration (top row).  The final distributions
    of all three simulations are then plotted together per genotype (bottom
    row).  The three simulations are labelled Condition I, II and III.

    Relies on names defined outside this function: ``WT_color``,
    ``Df_color``, ``filename``, ``save_dir`` and ``save_figure``.
    """

    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(2,
                       2,
                       left=0.1,
                       right=0.7,
                       top=0.9,
                       bottom=0.5,
                       hspace=0.4,
                       wspace=0.4)
    WT_enrich_ax = fig.add_subplot(gs1[0, 0])
    Df_enrich_ax = fig.add_subplot(gs1[0, 1])
    WT_final_dist_ax = fig.add_subplot(gs1[1, 0])
    Df_final_dist_ax = fig.add_subplot(gs1[1, 1])

    # Simulation files, in the order they are labelled Condition I/II/III.
    simulation_files = (
        'WT_Df_enrichment_model_simulation_A.pkl',
        'WT_Df_enrichment_model_simulation_B.pkl',
        'WT_Df_enrichment_model_simulation_C.pkl',
    )

    # Open each pickle in binary mode and close it as soon as it is loaded.
    # NOTE: unpickling can execute arbitrary code; only load trusted
    # simulation files here.
    simulations = []
    for simulation_file in simulation_files:
        simulation_path = os.path.join(
            df.data_path, 'enrichment_model', simulation_file)
        with open(simulation_path, 'rb') as f:
            simulations.append(pickle.load(f))

    # Entries 2, 4 and 6 of a 7-step light palette: one shade per simulation.
    WT_colors = sns.light_palette(WT_color, 7)[2::2]
    Df_colors = sns.light_palette(Df_color, 7)[2::2]

    condition_labels = [
        r'Condition $\mathrm{I}$', r'Condition $\mathrm{II}$',
        r'Condition $\mathrm{III}$'
    ]

    WT_final_dists, Df_final_dists = [], []

    for m, WT_c, Df_c in zip(simulations, WT_colors, Df_colors):

        WT_enrich = emp.calc_enrichment(m['WT_no_swap_pos'],
                                        m['WT_no_swap_masks'])
        Df_enrich = emp.calc_enrichment(m['Df_no_swap_pos'],
                                        m['Df_no_swap_masks'])

        # Collect the final distributions so all simulations can be drawn
        # together after the loop.
        WT_final_dists.append(
            emp.calc_final_distributions(m['WT_no_swap_pos'],
                                         m['WT_no_swap_masks']))
        Df_final_dists.append(
            emp.calc_final_distributions(m['Df_no_swap_pos'],
                                         m['Df_no_swap_masks']))

        emp.plot_enrichment(WT_enrich_ax, WT_enrich, WT_c, title='', rad=False)
        emp.plot_enrichment(Df_enrich_ax, Df_enrich, Df_c, title='', rad=False)

    WT_enrich_ax.set_xlabel("Iteration ('session' #)")
    Df_enrich_ax.set_xlabel("Iteration ('session' #)")
    plotting.stackedText(WT_enrich_ax,
                         condition_labels,
                         colors=WT_colors,
                         loc=2,
                         size=8)
    plotting.stackedText(Df_enrich_ax,
                         condition_labels,
                         colors=Df_colors,
                         loc=2,
                         size=8)

    emp.plot_final_distributions(WT_final_dist_ax,
                                 WT_final_dists,
                                 WT_colors,
                                 labels=condition_labels,
                                 title='',
                                 rad=False)
    emp.plot_final_distributions(Df_final_dist_ax,
                                 Df_final_dists,
                                 Df_colors,
                                 labels=condition_labels,
                                 title='',
                                 rad=False)

    WT_final_dist_ax.set_xlabel('Distance from reward\n(fraction of belt)')
    Df_final_dist_ax.set_xlabel('Distance from reward\n(fraction of belt)')
    plotting.stackedText(WT_final_dist_ax,
                         condition_labels,
                         colors=WT_colors,
                         loc=2,
                         size=8)
    plotting.stackedText(Df_final_dist_ax,
                         condition_labels,
                         colors=Df_colors,
                         loc=2,
                         size=8)
    WT_final_dist_ax.set_yticks([0, 0.1, 0.2, 0.3])
    Df_final_dist_ax.set_yticks([0, 0.1, 0.2, 0.3])

    save_figure(fig, filename, save_dir=save_dir)
# Example no. 6
# 0
def main():
    """Draw and save the place-cell property figure for the WT and Df groups.

    For transient sensitivity, transient specificity, place field width and
    single-cell sparsity, a CDF is plotted with a small grouped-bar inset.
    The fraction of sessions in which an ROI is a place cell gets the same
    CDF-plus-inset treatment, and the "lifetime place coding" panel is drawn
    by ``place.is_ever_place_cell``.  The figure is written with
    ``misc.save_figure`` and every open figure is closed afterwards.

    Relies on names defined outside this function: ``roi_filters``,
    ``colors``, ``linestyles``, ``WT_label``, ``Df_label``, ``filename``
    and ``save_dir``.
    """
    grps = df.loadExptGrps('GOL')
    expt_grps = [grps[key] for key in ('WT_place_set', 'Df_place_set')]

    # One place-cell filter per experiment group, paired with its ROI filter.
    pc_filters = []
    for grp, grp_roi_filter in zip(expt_grps, roi_filters):
        pc_filters.append(grp.pcs_filter(roi_filter=grp_roi_filter))

    fig = plt.figure(figsize=(8.5, 11))

    grid = plt.GridSpec(
        3, 3, left=0.1, right=0.9, top=0.9, bottom=0.3, wspace=0.5, hspace=0.4)

    def add_inset(left, bottom):
        # Every bar inset has the same size; only its corner differs.
        return fig.add_axes([left, bottom, 0.05, 0.07])

    # Axes are created in a fixed order: each inset right after its CDF axes.
    sensitivity_cdf_ax = fig.add_subplot(grid[1, 0])
    sensitivity_bar_ax = add_inset(0.24, 0.535)

    specificity_cdf_ax = fig.add_subplot(grid[1, 1])
    specificity_bar_ax = add_inset(0.43, 0.60)

    width_cdf_ax = fig.add_subplot(grid[2, 0])
    width_bar_ax = add_inset(0.24, 0.315)

    sparsity_cdf_ax = fig.add_subplot(grid[2, 1])
    sparsity_bar_ax = add_inset(0.54, 0.315)

    is_ever_pc_fraction_ax = fig.add_subplot(grid[0, 0])

    pc_fraction_ax = fig.add_subplot(grid[0, 1])
    pc_fraction_bar_ax = add_inset(0.54, 0.755)

    cdf_axs = [
        sensitivity_cdf_ax, specificity_cdf_ax, width_cdf_ax, sparsity_cdf_ax,
        pc_fraction_ax]
    bar_axs = [
        sensitivity_bar_ax, specificity_bar_ax, width_bar_ax, sparsity_bar_ax,
        pc_fraction_bar_ax]

    # One entry per CDF + inset-bar pair, processed in list order.  Each
    # entry carries its own groupby lists so nothing is shared between calls.
    metric_panels = [
        {
            'cdf_ax': sensitivity_cdf_ax,
            'bar_ax': sensitivity_bar_ax,
            'metric_fn': place.sensitivity,
            'cdf_groupby': [['roi_id', 'expt']],
            'bar_groupby': [['roi_id', 'expt'],
                            ['roi_id', 'uniqueLocationKey', 'mouseID'],
                            ['mouseID']],
            'cdf_label': 'Transient sensitivity',
            'bar_label': 'Transient sensitivity',
            'value_range': (0, 1),
            'label_groups': True,
        },
        {
            'cdf_ax': specificity_cdf_ax,
            'bar_ax': specificity_bar_ax,
            'metric_fn': place.specificity,
            'cdf_groupby': [['roi_id', 'expt']],
            'bar_groupby': [['roi_id', 'expt'],
                            ['roi_id', 'uniqueLocationKey', 'mouseID'],
                            ['mouseID']],
            'cdf_label': 'Transient specificity',
            'bar_label': 'Transient specificity',
            'value_range': (0, 1),
            'label_groups': False,
        },
        {
            'cdf_ax': width_cdf_ax,
            'bar_ax': width_bar_ax,
            'metric_fn': place.place_field_width,
            'cdf_groupby': None,
            'bar_groupby': [['roi_id', 'uniqueLocationKey', 'mouseID'],
                            ['mouseID']],
            'cdf_label': 'Place field width (cm)',
            'bar_label': 'Place field width',
            'value_range': (0, 100),
            'label_groups': False,
        },
        {
            'cdf_ax': sparsity_cdf_ax,
            'bar_ax': sparsity_bar_ax,
            'metric_fn': place.sparsity,
            'cdf_groupby': [['roi_id', 'expt']],
            'bar_groupby': [['roi_id', 'expt'],
                            ['roi_id', 'uniqueLocationKey', 'mouseID'],
                            ['mouseID']],
            'cdf_label': 'Single-cell sparsity',
            'bar_label': 'Single-cell sparsity',
            'value_range': (0, 1),
            'label_groups': False,
        },
    ]

    for panel in metric_panels:
        cdf_ax = panel['cdf_ax']
        bar_ax = panel['bar_ax']
        value_range = panel['value_range']

        plotting.plot_metric(
            cdf_ax, expt_grps, metric_fn=panel['metric_fn'],
            groupby=panel['cdf_groupby'], plotby=None, plot_method='cdf',
            roi_filters=pc_filters, activity_kwargs=None,
            activity_label=panel['cdf_label'], colors=colors,
            rotate_labels=False, linestyles=linestyles)
        cdf_ax.set_xlim(value_range)

        plotting.plot_metric(
            bar_ax, expt_grps, metric_fn=panel['metric_fn'],
            groupby=panel['bar_groupby'], plotby=None,
            plot_method='grouped_bar', roi_filters=pc_filters,
            activity_kwargs=None, activity_label=panel['bar_label'],
            colors=colors)
        # The inset shows only the two end-of-range ticks.
        bar_ax.set_ylim(value_range)
        bar_ax.set_yticks(value_range)

        if panel['label_groups']:
            # Only the first panel carries the WT/Df text labels.
            plotting.stackedText(
                cdf_ax, [WT_label, Df_label], colors=colors, loc=2,
                size=10)

    def roi_filter_kwargs():
        # Fresh per-group activity kwargs for each plot_metric call.
        return [{'include_roi_filter': roi_filter}
                for roi_filter in roi_filters]

    #
    # Fraction of sessions in which an ROI is a place cell
    #
    fraction_ses_pc_range = (0, 0.5)
    plotting.plot_metric(
        pc_fraction_ax, expt_grps,
        metric_fn=lab.ExperimentGroup.filtered_rois,
        groupby=(('mouseID', 'uniqueLocationKey', 'roi_id'),), plotby=None,
        colorby=None, plot_method='cdf', roi_filters=pc_filters,
        activity_kwargs=roi_filter_kwargs(), colors=colors,
        activity_label='Fraction of sessions a place cell',
        rotate_labels=False, linestyles=linestyles)
    pc_fraction_ax.tick_params(length=3, pad=2)
    pc_fraction_ax.get_legend().set_visible(False)
    pc_fraction_ax.set_title('')
    pc_fraction_ax.set_xticks([0, .2, .4, .6, .8])
    pc_fraction_ax.set_xlim(0, .8)
    for side in ('left', 'bottom'):
        pc_fraction_ax.spines[side].set_linewidth(1)

    plotting.plot_metric(
        pc_fraction_bar_ax, expt_grps,
        metric_fn=lab.ExperimentGroup.filtered_rois,
        groupby=[['roi_id', 'uniqueLocationKey', 'mouseID'], ['mouseID']],
        plotby=None, plot_method='grouped_bar', roi_filters=pc_filters,
        activity_kwargs=roi_filter_kwargs(),
        activity_label='Fraction of sessions a place cell', colors=colors)
    pc_fraction_bar_ax.set_ylim(fraction_ses_pc_range)
    pc_fraction_bar_ax.set_yticks(fraction_ses_pc_range)

    #
    # Lifetime place coding
    #
    place.is_ever_place_cell(
        expt_grps, roi_filters=roi_filters, ax=is_ever_pc_fraction_ax,
        colors=colors, filter_fn=lambda df: df['session_number'] < 15,
        filter_columns=['session_number'],
        groupby=[['mouseID', 'session_number']])
    is_ever_pc_fraction_ax.get_legend().set_visible(False)
    is_ever_pc_fraction_ax.tick_params(length=3, pad=2)
    is_ever_pc_fraction_ax.set_xlabel('Session number')
    is_ever_pc_fraction_ax.set_title('Lifetime place coding')
    # Label odd session numbers 1..15 and leave the even ones blank.
    session_tick_labels = [
        str(session) if session % 2 else '' for session in range(1, 16)]
    is_ever_pc_fraction_ax.set_xticklabels(session_tick_labels)

    #
    # Shared cosmetic clean-up
    #
    for axis in cdf_axs + bar_axs:
        axis.tick_params(length=3, pad=2)
    for axis in cdf_axs:
        axis.set_title('')
    for axis in bar_axs:
        # Insets are stripped of labels, title, x ticks and legend.
        axis.set_xlabel('')
        axis.set_ylabel('')
        axis.set_title('')
        axis.set_xticks([])
        axis.tick_params(bottom=False, labelbottom=False, length=3, pad=2)
        axis.get_legend().set_visible(False)

    sns.despine(fig)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')