def _recur_by_position(self, positions):
        """Evaluate the fitted place-cell recurrence probability.

        The cyclic spline basis is rebuilt from
        ``self.params['position_recurrence']``: ``n_knots`` evenly spaced
        knots on [-pi, pi]. The stored coefficients ``theta`` are then
        mapped through ``splines.prob``.

        Args:
            positions: locations at which to evaluate the fit. They are
                presumably radians in [-pi, pi], like the knots (TODO
                confirm with callers).

        Returns:
            The result of ``splines.prob(theta, design_matrix)``.
        """
        recurrence_cfg = self.params['position_recurrence']
        knot_grid = np.linspace(-np.pi, np.pi, recurrence_cfg['n_knots'])
        basis = splines.CyclicSpline(knot_grid)
        design = basis.design_matrix(positions)
        return splines.prob(recurrence_cfg['theta'], design)

    def _shift_mean_var(self, positions):
        """Evaluate the 'all_pairs' position-shift fit at ``positions``.

        A cyclic spline is built on the stored ``knots`` and its design
        matrix is evaluated at ``positions``. Both fitted coefficient
        vectors are then applied to that matrix.

        Returns:
            (bs, ks): ``bs = np.dot(N, theta_b)`` is the fitted mean shift.
            ``ks = splines.get_k(theta_k, N)`` is the fitted k. Elsewhere
            in this file ``1 / k`` is plotted as the shift variance.
        """
        stability_fit = self.params['position_stability']['all_pairs']
        basis = splines.CyclicSpline(stability_fit['knots'])
        design = basis.design_matrix(positions)
        mean_shift = np.dot(design, stability_fit['theta_b'])
        concentration = splines.get_k(stability_fit['theta_k'], design)
        return mean_shift, concentration


def main():
    """Compare two-session place-field stability between WT and Df mice.

    A 3x3 figure is created and only its top row is used:

    * ``axs[0, 0]`` and ``axs[0, 1]``: per-cell change in activity centroid
      between the 'first' and 'third' sessions, plotted against the initial
      distance from reward, for WT and Df respectively. Cells that have no
      'second' value in the previously-imaged data set are overlaid with
      'x' markers.
    * ``axs[0, 2]``: grouped bars, per genotype, of a bootstrap summary of
      the fitted shift k for four fit variants.

    The figure is written with ``save_figure``. Paths, colors, markers,
    labels, ``filename`` and ``save_dir`` are module-level names defined
    outside this function.
    """
    # Pickles must be read in binary mode (text mode fails on Python 3).
    # Each handle is closed as soon as the parameters are loaded.
    with open(WT_params_path, 'rb') as params_file:
        WT_params = pkl.load(params_file)
    with open(Df_params_path, 'rb') as params_file:
        Df_params = pkl.load(params_file)

    # Only the processed data is used; the raw data is discarded.
    _, WT_data = emd.load_data('wt', root=data_root_path)
    _, Df_data = emd.load_data('df', root=data_root_path)

    fig, axs = plt.subplots(3,
                            3,
                            figsize=(11, 8.5),
                            gridspec_kw={'wspace': 0.4})

    # Only the top row of panels is drawn; hide the rest.
    for ax in axs[1:, :].flat:
        ax.set_visible(False)

    #
    # 2 session stability
    #

    _plot_two_session_stability(axs[0, 0],
                                WT_data,
                                color=WT_color,
                                marker=WT_marker,
                                npc_color='m',
                                title=WT_label)
    _plot_two_session_stability(axs[0, 1],
                                Df_data,
                                color=Df_color,
                                marker=Df_marker,
                                npc_color='c',
                                title=Df_label)

    #
    # Place field stability k
    #

    x_vals = np.linspace(-np.pi, np.pi, 1000)
    WT_N = _all_pairs_design_matrix(WT_params, x_vals)
    Df_N = _all_pairs_design_matrix(Df_params, x_vals)

    # One long-format DataFrame per fit variant, WT rows before Df rows.
    # 'order' controls the genotype order in the grouped bar plot.
    order_dict = {WT_label: 0, Df_label: 1}
    k_dfs = []
    for fit_name in ('all_pairs', 'skip_one', 'skip_one_npc', 'two_iter'):
        k_df = pd.concat([
            _bootstrap_k_summary_df(WT_params, fit_name, WT_N, WT_label),
            _bootstrap_k_summary_df(Df_params, fit_name, Df_N, Df_label),
        ])
        k_df['order'] = k_df['genotype'].apply(order_dict.get)
        k_dfs.append(k_df)

    plotting.plot_dataframe(
        axs[0, 2], k_dfs,
        labels=[
            '1 ses elapsed', '2 ses elapsed (all)', '2 ses elapsed (nPC)',
            '2 iterations (model)'
        ],
        activity_label='Mean shift variance',
        colors=sns.color_palette('deep'),
        plotby=['genotype'],
        orderby='order',
        plot_method='grouped_bar',
        plot_shuffle=False,
        shuffle_plotby=False,
        pool_shuffle=False,
        agg_fn=np.mean,
        markers=None,
        label_groupby=True,
        z_score=False,
        normalize=False,
        error_bars='std')
    axs[0, 2].set_xlabel('')
    axs[0, 2].set_ylim(0, 1.5)
    # Values are in rad^2; label ticks in (fraction of belt)^2.
    y_ticks = np.array(['0', '0.01', '0.02', '0.03'])
    axs[0, 2].set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    axs[0, 2].set_yticklabels(y_ticks)
    axs[0, 2].tick_params(top=False, bottom=False)
    axs[0, 2].set_title('')

    save_figure(fig, filename, save_dir=save_dir)


def _wrap_to_pi(shifts):
    """Wrap angular differences (radians) into [-pi, pi) in place.

    Only a single +/- 2*pi correction is applied, so inputs must lie in
    [-3*pi, 3*pi). The difference of two angles in [-pi, pi] always does.
    Returns ``shifts`` for convenience.
    """
    shifts[shifts < -np.pi] += 2 * np.pi
    shifts[shifts >= np.pi] -= 2 * np.pi
    return shifts


def _two_session_shifts(data, prev_imaged):
    """Return ``(shift_data, shifts)`` for the first-to-third session shift.

    Rows missing either the 'first' or the 'third' value are dropped.
    ``shifts`` is ``third - first`` wrapped into [-pi, pi).
    """
    shift_data = emd.tripled_activity_centroid_distance_to_reward(
        data, prev_imaged=prev_imaged)
    shift_data = shift_data.dropna(subset=['first', 'third'])
    shifts = _wrap_to_pi(shift_data['third'] - shift_data['first'])
    return shift_data, shifts


def _plot_two_session_stability(ax, data, color, marker, npc_color, title):
    """Draw one genotype's two-session shift scatter on ``ax``.

    First, cells from the ``prev_imaged=True`` data whose 'second' value is
    NaN are drawn with 'x' markers in ``npc_color``. The bar-plot legend
    calls this group "nPC", presumably cells that were not place cells in
    the middle session. Then every cell from the ``prev_imaged=False`` data
    is drawn in ``color``/``marker``. Axes are in radians, labelled as
    fractions of the belt.
    """
    shift_data, shifts = _two_session_shifts(data, prev_imaged=False)
    prev_data, prev_shifts = _two_session_shifts(data, prev_imaged=True)
    npc = np.isnan(prev_data['second'])

    # x/y are passed by keyword: seaborn >= 0.12 rejects them positionally.
    sns.regplot(x=prev_data['first'][npc],
                y=prev_shifts[npc],
                ax=ax,
                color=npc_color,
                fit_reg=False,
                scatter_kws={'s': 7},
                marker='x')
    sns.regplot(x=shift_data['first'],
                y=shifts,
                ax=ax,
                color=color,
                fit_reg=False,
                scatter_kws={'s': 3},
                marker=marker)

    belt_ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    belt_labels = ['-0.50', '-0.25', '0', '0.25', '0.50']
    ax.axvline(ls='--', color='0.4', lw=0.5)
    ax.axhline(ls='--', color='0.4', lw=0.5)
    # Reference line: the shift that would move a cell onto the reward.
    ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
    ax.tick_params(length=3, pad=1, top=False)
    ax.set_xlabel('Initial distance from reward\n(fraction of belt)')
    ax.set_ylabel(r'Two-session $\Delta$ position' + '\n(fraction of belt)')
    ax.set_xlim(-np.pi, np.pi)
    ax.set_xticks(belt_ticks)
    ax.set_xticklabels(belt_labels)
    ax.set_ylim(-np.pi, np.pi)
    ax.set_yticks(belt_ticks)
    ax.set_yticklabels(belt_labels)
    ax.set_title(title)


def _all_pairs_design_matrix(params, x_vals):
    """Evaluate the cyclic-spline basis of the 'all_pairs' fit at ``x_vals``."""
    knots = params['position_stability']['all_pairs']['knots']
    return splines.CyclicSpline(knots).design_matrix(x_vals)


def _bootstrap_k_summary_df(params, fit_name, N, label):
    """Return one row per bootstrap of ``1 / mean(k)`` for a stability fit.

    ``k = splines.get_k(theta_k, N)`` is computed for each bootstrap
    ``theta_k`` stored under ``params['position_stability'][fit_name]``.
    It is averaged with ``.mean()`` and then inverted.

    NOTE(review): this is the reciprocal of the mean k, not the mean of the
    per-position variance ``1 / k`` that the bar label 'Mean shift
    variance' suggests. Kept unchanged so the figure is identical; confirm
    which statistic is intended.
    """
    boots_theta_k = params['position_stability'][fit_name]['boots_theta_k']
    values = [
        1. / splines.get_k(theta_k, N).mean() for theta_k in boots_theta_k
    ]
    return pd.DataFrame({'value': values, 'genotype': label})


def main():
    """Build the WT model-overview figure and save it.

    Panels:

    * Two example fields of view (mouse 'jz135', two sessions on
      2015-10-12) with spatial-tuning overlays, ROI outlines and a 50 um
      scale bar. A belt-position colorbar between them marks the reward
      zone.
    * Place-cell recurrence probability vs. initial distance from reward:
      spline fit, 5th-95th percentile bootstrap band, and jittered data.
    * Paired-session centroid shift vs. initial distance from reward,
      overlaid with the 'all_pairs' mean-shift fit.
    * The 'all_pairs' mean-shift fit and variance (1 / k) fit, each with a
      5th-95th percentile bootstrap band.

    Relies on module-level names defined outside this function (the ``df``
    path module, ``params_path``, ``WT_color``, ``WT_marker``,
    ``filename``, ``save_dir``).
    """
    # The raw data is not needed for this figure.
    _, data = emd.load_data('wt',
                            session_filter='C',
                            root=os.path.join(df.data_path,
                                              'enrichment_model'))
    expts = lab.ExperimentSet(os.path.join(df.metadata_path,
                                           'expt_metadata.xml'),
                              behaviorDataPath=os.path.join(
                                  df.data_path, 'behavior'),
                              dataPath=os.path.join(df.data_path, 'imaging'))

    # Pickles must be read in binary mode (text mode fails on Python 3).
    # The handle is closed as soon as the parameters are loaded.
    with open(params_path, 'rb') as params_file:
        params = pickle.load(params_file)

    fig = plt.figure(figsize=(8.5, 11))
    gs1 = plt.GridSpec(1,
                       2,
                       left=0.1,
                       bottom=0.65,
                       right=0.9,
                       top=0.9,
                       wspace=0.05)
    fov1_ax = fig.add_subplot(gs1[0, 0])
    fov2_ax = fig.add_subplot(gs1[0, 1])
    # Narrow colorbar axes placed between the two FOV panels.
    cmap_ax = fig.add_axes([0.49, 0.65, 0.02, 0.25])
    gs2 = plt.GridSpec(2,
                       2,
                       left=0.1,
                       bottom=0.3,
                       right=0.5,
                       top=0.6,
                       wspace=0.5,
                       hspace=0.5)
    recur_ax = fig.add_subplot(gs2[0, 0])
    shift_ax = fig.add_subplot(gs2[0, 1])
    shift_compare_ax = fig.add_subplot(gs2[1, 0])
    var_compare_ax = fig.add_subplot(gs2[1, 1])

    belt_ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    belt_labels = ['-0.50', '-0.25', '0', '0.25', '0.50']

    #
    # Tuning maps
    #

    e1 = expts.grabExpt('jz135', '2015-10-12-14h33m47s')
    e2 = expts.grabExpt('jz135', '2015-10-12-15h34m38s')

    cmap = mpl.colors.ListedColormap(sns.color_palette("husl", 256))

    for ax, expt in ((fov1_ax, e1), (fov2_ax, e2)):
        place.plot_spatial_tuning_overlay(ax,
                                          lab.classes.pcExperimentGroup(
                                              [expt], imaging_label='soma'),
                                          labels_visible=False,
                                          alpha=0.9,
                                          lw=0.1,
                                          cmap=cmap)
        plot_ROI_outlines(ax,
                          expt,
                          channel='Ch2',
                          label='soma',
                          roi_filter=None,
                          ls='-',
                          color='k',
                          lw=0.1)
        # Add a 50-um scale bar
        plotting.add_scalebar(
            ax=ax,
            matchx=False,
            matchy=False,
            sizey=0,
            sizex=50 / expt.imagingParameters()['micronsPerPixel']['XAxis'],
            bar_color='w',
            bar_thickness=3)

    fov1_ax.set_title('Session 1')
    fov2_ax.set_title('Session 2')

    # Vertical colorbar: a 256 x 2 image of the same colormap.
    gradient = np.linspace(0, 1, 256)
    gradient = np.vstack((gradient, gradient)).T
    cmap_ax.imshow(gradient, aspect='auto', cmap=cmap)
    sns.despine(ax=cmap_ax, top=True, left=True, right=True, bottom=True)
    cmap_ax.tick_params(left=False,
                        labelleft=False,
                        bottom=False,
                        labelbottom=False)
    cmap_ax.set_ylabel('belt position')

    # Reward position and window width as fractions of the track, averaged
    # over the two example sessions.
    reward_poss, windows = [], []
    for expt in [e1, e2]:
        reward_poss.append(expt.rewardPositions(units='normalized')[0])
        track_length = expt[0].behaviorData()['trackLength']
        window = float(expt.get('operantSpatialWindow'))
        windows.append(window / track_length)
    reward_pos = np.mean(reward_poss)
    window = np.mean(windows)

    # The belt is circular, so a reward window that runs past the end of the
    # belt wraps around to its start. Without the wrap, a value > 1 would
    # fall outside the (axes-coordinate) colorbar and be clipped from view.
    reward_end = (reward_pos + window) % 1.0

    # Add reward zone
    cmap_ax.plot([0, 1], [reward_pos, reward_pos],
                 transform=cmap_ax.transAxes,
                 color='k',
                 ls=':')
    cmap_ax.plot([0, 1], [reward_end, reward_end],
                 transform=cmap_ax.transAxes,
                 color='k',
                 ls=':')
    cmap_ax.set_ylim(0, 256)

    #
    # Recurrence by position
    #

    recur_x_vals = np.linspace(-np.pi, np.pi, 1000)

    recur_data = emd.recurrence_by_position(data, method='cv')
    recur_knots = np.linspace(-np.pi, np.pi,
                              params['position_recurrence']['n_knots'])
    recur_splines = splines.CyclicSpline(recur_knots)
    recur_n = recur_splines.design_matrix(recur_x_vals)

    recur_fit = splines.prob(params['position_recurrence']['theta'], recur_n)

    # 5th-95th percentile band across bootstrap fits.
    recur_boots_fits = [
        splines.prob(boot, recur_n)
        for boot in params['position_recurrence']['boots_theta']
    ]
    recur_ci_up_fit = np.percentile(recur_boots_fits, 95, axis=0)
    recur_ci_low_fit = np.percentile(recur_boots_fits, 5, axis=0)

    recur_ax.plot(recur_x_vals, recur_fit, color=WT_color)
    recur_ax.fill_between(recur_x_vals,
                          recur_ci_low_fit,
                          recur_ci_up_fit,
                          facecolor=WT_color,
                          alpha=0.5)
    # x/y are passed by keyword: seaborn >= 0.12 rejects them positionally.
    sns.regplot(x=recur_data[:, 0],
                y=recur_data[:, 1],
                ax=recur_ax,
                color=WT_color,
                y_jitter=0.2,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=WT_marker)
    recur_ax.axvline(ls='--', color='0.4', lw=0.5)
    recur_ax.set_xlim(-np.pi, np.pi)
    recur_ax.set_xticks(belt_ticks)
    recur_ax.set_xticklabels(belt_labels)
    recur_ax.set_ylim(-0.3, 1.3)
    recur_ax.set_yticks([0, 0.5, 1])
    recur_ax.tick_params(length=3, pad=1, top=False)
    recur_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    recur_ax.set_ylabel('Place cell recurrence probability')
    recur_ax.set_title('')
    # Right-hand axis labels the two binary outcomes of the jittered data.
    recur_ax_2 = recur_ax.twinx()
    recur_ax_2.tick_params(length=3, pad=1, top=False)
    recur_ax_2.set_ylim(-0.3, 1.3)
    recur_ax_2.set_yticks([0, 1])
    recur_ax_2.set_yticklabels(['non-recur', 'recur'])

    #
    # Place field stability
    #

    # The 'all_pairs' basis and fits are computed once and reused by the
    # stability-comparison panels below.
    shift_x_vals = np.linspace(-np.pi, np.pi, 1000)

    all_pairs = params['position_stability']['all_pairs']
    shift_spline = splines.CyclicSpline(all_pairs['knots'])
    shift_n = shift_spline.design_matrix(shift_x_vals)
    shift_b_fit = np.dot(shift_n, all_pairs['theta_b'])
    shift_k_fit = splines.get_k(all_pairs['theta_k'], shift_n)
    shift_fit_var = 1. / shift_k_fit

    shift_data = emd.paired_activity_centroid_distance_to_reward(data)
    shift_data = shift_data.dropna()
    shifts = shift_data['second'] - shift_data['first']
    # Wrap the angular differences into [-pi, pi).
    shifts[shifts < -np.pi] += 2 * np.pi
    shifts[shifts >= np.pi] -= 2 * np.pi

    shift_ax.plot(shift_x_vals, shift_b_fit, color=WT_color)
    # NOTE(review): the band is the mean shift +/- the fitted variance
    # (1 / k, rad^2), which mixes units with the shift itself (rad). A
    # standard deviation (sqrt(1 / k)) may have been intended. Left as is.
    shift_ax.fill_between(shift_x_vals,
                          shift_b_fit - shift_fit_var,
                          shift_b_fit + shift_fit_var,
                          facecolor=WT_color,
                          alpha=0.5)
    sns.regplot(x=shift_data['first'],
                y=shifts,
                ax=shift_ax,
                color=WT_color,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=WT_marker)

    shift_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_ax.axhline(ls='--', color='0.4', lw=0.5)
    shift_ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
    shift_ax.tick_params(length=3, pad=1, top=False)
    shift_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    shift_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')
    shift_ax.set_xlim(-np.pi, np.pi)
    shift_ax.set_xticks(belt_ticks)
    shift_ax.set_xticklabels(belt_labels)
    shift_ax.set_ylim(-np.pi, np.pi)
    shift_ax.set_yticks(belt_ticks)
    shift_ax.set_yticklabels(belt_labels)
    shift_ax.set_title('')

    #
    # Stability by distance to reward
    #

    shift_boots_b_fit = [
        np.dot(shift_n, boot) for boot in all_pairs['boots_theta_b']
    ]
    shift_b_ci_up_fit = np.percentile(shift_boots_b_fit, 95, axis=0)
    shift_b_ci_low_fit = np.percentile(shift_boots_b_fit, 5, axis=0)

    shift_compare_ax.plot(shift_x_vals, shift_b_fit, color=WT_color)
    shift_compare_ax.fill_between(shift_x_vals,
                                  shift_b_ci_low_fit,
                                  shift_b_ci_up_fit,
                                  facecolor=WT_color,
                                  alpha=0.5)

    shift_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.axhline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.tick_params(length=3, pad=1, top=False)
    shift_compare_ax.set_xlim(-np.pi, np.pi)
    shift_compare_ax.set_xticks(belt_ticks)
    shift_compare_ax.set_xticklabels(belt_labels)
    shift_compare_ax.set_ylim(-0.10 * 2 * np.pi, 0.10 * 2 * np.pi)
    # Data are in radians; ticks are labelled as fractions of the belt.
    y_ticks = np.array(['-0.10', '-0.05', '0', '0.05', '0.10'])
    shift_compare_ax.set_yticks(y_ticks.astype('float') * 2 * np.pi)
    shift_compare_ax.set_yticklabels(y_ticks)
    shift_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    shift_compare_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')

    shift_boots_k_fit = [
        splines.get_k(boot, shift_n) for boot in all_pairs['boots_theta_k']
    ]
    shift_k_ci_up_fit = np.percentile(shift_boots_k_fit, 95, axis=0)
    shift_k_ci_low_fit = np.percentile(shift_boots_k_fit, 5, axis=0)

    # Variance is plotted as 1 / k; inverting swaps which percentile bound
    # is on top, which fill_between does not care about.
    var_compare_ax.plot(shift_x_vals, 1. / shift_k_fit, color=WT_color)
    var_compare_ax.fill_between(shift_x_vals,
                                1. / shift_k_ci_low_fit,
                                1. / shift_k_ci_up_fit,
                                facecolor=WT_color,
                                alpha=0.5)

    var_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    var_compare_ax.tick_params(length=3, pad=1, top=False)
    var_compare_ax.set_xlim(-np.pi, np.pi)
    var_compare_ax.set_xticks(belt_ticks)
    var_compare_ax.set_xticklabels(belt_labels)
    # Values are in rad^2; ticks are labelled in (fraction of belt)^2.
    y_ticks = np.array(['0.005', '0.010', '0.015', '0.020'])
    var_compare_ax.set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    var_compare_ax.set_yticklabels(y_ticks)
    var_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    var_compare_ax.set_ylabel(r'$\Delta$ position variance')

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
Ejemplo n.º 5
0
def main():
    """Build the Df model figure and save it.

    Panels (3x2 grid):

    * Place-cell recurrence probability vs. initial distance from reward:
      spline fit, 5th-95th percentile bootstrap band, and jittered data.
    * Paired-session centroid shift vs. initial distance from reward,
      overlaid with the 'all_pairs' mean-shift fit.
    * The 'all_pairs' mean-shift fit and variance (1 / k) fit, each with a
      5th-95th percentile bootstrap band.
    * Simulated enrichment over iterations and the final simulated
      distribution, from the 'Df_no_swap_*' entries of the simulations
      pickle.

    Relies on module-level names defined outside this function (the ``df``
    path module, ``params_path``, ``simulations_path``, ``Df_color``,
    ``Df_marker``, ``filename``, ``save_dir``).
    """
    # The raw data is not needed for this figure.
    _, Df_data = emd.load_data('df',
                               root=os.path.join(
                                   df.data_path, 'enrichment_model'))

    # Pickles must be read in binary mode (text mode fails on Python 3).
    # The handle is closed as soon as the parameters are loaded.
    with open(params_path, 'rb') as params_file:
        Df_params = pickle.load(params_file)

    fig = plt.figure(figsize=(8.5, 11))
    gs = plt.GridSpec(3,
                      2,
                      left=0.1,
                      bottom=0.5,
                      right=0.5,
                      top=0.9,
                      wspace=0.3,
                      hspace=0.3)
    Df_recur_ax = fig.add_subplot(gs[0, 0])
    Df_shift_ax = fig.add_subplot(gs[0, 1])
    shift_compare_ax = fig.add_subplot(gs[1, 0])
    var_compare_ax = fig.add_subplot(gs[1, 1])
    enrichment_ax = fig.add_subplot(gs[2, 0])
    final_enrichment_ax = fig.add_subplot(gs[2, 1])

    belt_ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    belt_labels = ['-0.50', '-0.25', '0', '0.25', '0.50']

    #
    # Recurrence by position
    #

    recur_x_vals = np.linspace(-np.pi, np.pi, 1000)

    Df_recur_data = emd.recurrence_by_position(Df_data, method='cv')
    Df_recur_knots = np.linspace(-np.pi, np.pi,
                                 Df_params['position_recurrence']['n_knots'])
    Df_recur_spline = splines.CyclicSpline(Df_recur_knots)
    Df_recur_N = Df_recur_spline.design_matrix(recur_x_vals)

    Df_recur_fit = splines.prob(Df_params['position_recurrence']['theta'],
                                Df_recur_N)

    # 5th-95th percentile band across bootstrap fits.
    Df_recur_boots_fits = [
        splines.prob(boot, Df_recur_N)
        for boot in Df_params['position_recurrence']['boots_theta']
    ]
    Df_recur_ci_up_fit = np.percentile(Df_recur_boots_fits, 95, axis=0)
    Df_recur_ci_low_fit = np.percentile(Df_recur_boots_fits, 5, axis=0)

    Df_recur_ax.plot(recur_x_vals, Df_recur_fit, color=Df_color)
    Df_recur_ax.fill_between(recur_x_vals,
                             Df_recur_ci_low_fit,
                             Df_recur_ci_up_fit,
                             facecolor=Df_color,
                             alpha=0.5)
    # x/y are passed by keyword: seaborn >= 0.12 rejects them positionally.
    sns.regplot(x=Df_recur_data[:, 0],
                y=Df_recur_data[:, 1],
                ax=Df_recur_ax,
                color=Df_color,
                y_jitter=0.2,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=Df_marker)
    # Reward-location marker, matching the other distance-from-reward panels.
    Df_recur_ax.axvline(ls='--', color='0.4', lw=0.5)
    Df_recur_ax.set_xlim(-np.pi, np.pi)
    Df_recur_ax.set_xticks(belt_ticks)
    Df_recur_ax.set_xticklabels(belt_labels)
    Df_recur_ax.set_ylim(-0.3, 1.3)
    Df_recur_ax.set_yticks([0, 0.5, 1])
    Df_recur_ax.tick_params(length=3, pad=1, top=False)
    Df_recur_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    Df_recur_ax.set_ylabel('Place cell recurrence probability')
    Df_recur_ax.set_title('')
    # Right-hand axis labels the two binary outcomes of the jittered data.
    Df_recur_ax_2 = Df_recur_ax.twinx()
    Df_recur_ax_2.set_ylim(-0.3, 1.3)
    Df_recur_ax_2.set_yticks([0, 1])
    Df_recur_ax_2.set_yticklabels(['non-recur', 'recur'])
    Df_recur_ax_2.tick_params(length=3, pad=1, top=False)

    #
    # Place field stability
    #

    # The 'all_pairs' basis and fits are computed once and reused by the
    # stability-comparison panels below.
    shift_x_vals = np.linspace(-np.pi, np.pi, 1000)

    Df_all_pairs = Df_params['position_stability']['all_pairs']
    Df_shift_spline = splines.CyclicSpline(Df_all_pairs['knots'])
    Df_shift_N = Df_shift_spline.design_matrix(shift_x_vals)
    Df_shift_b_fit = np.dot(Df_shift_N, Df_all_pairs['theta_b'])
    Df_shift_k_fit = splines.get_k(Df_all_pairs['theta_k'], Df_shift_N)
    Df_shift_fit_var = 1. / Df_shift_k_fit

    Df_shift_data = emd.paired_activity_centroid_distance_to_reward(Df_data)
    Df_shift_data = Df_shift_data.dropna()
    Df_shifts = Df_shift_data['second'] - Df_shift_data['first']
    # Wrap the angular differences into [-pi, pi).
    Df_shifts[Df_shifts < -np.pi] += 2 * np.pi
    Df_shifts[Df_shifts >= np.pi] -= 2 * np.pi

    Df_shift_ax.plot(shift_x_vals, Df_shift_b_fit, color=Df_color)
    # NOTE(review): the band is the mean shift +/- the fitted variance
    # (1 / k, rad^2), which mixes units with the shift itself (rad). A
    # standard deviation (sqrt(1 / k)) may have been intended. Left as is.
    Df_shift_ax.fill_between(shift_x_vals,
                             Df_shift_b_fit - Df_shift_fit_var,
                             Df_shift_b_fit + Df_shift_fit_var,
                             facecolor=Df_color,
                             alpha=0.5)
    sns.regplot(x=Df_shift_data['first'],
                y=Df_shifts,
                ax=Df_shift_ax,
                color=Df_color,
                fit_reg=False,
                scatter_kws={'s': 1},
                marker=Df_marker)

    Df_shift_ax.axvline(ls='--', color='0.4', lw=0.5)
    Df_shift_ax.axhline(ls='--', color='0.4', lw=0.5)
    Df_shift_ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
    Df_shift_ax.tick_params(length=3, pad=1, top=False)
    Df_shift_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    Df_shift_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')
    Df_shift_ax.set_xlim(-np.pi, np.pi)
    Df_shift_ax.set_xticks(belt_ticks)
    Df_shift_ax.set_xticklabels(belt_labels)
    Df_shift_ax.set_ylim(-np.pi, np.pi)
    Df_shift_ax.set_yticks(belt_ticks)
    Df_shift_ax.set_yticklabels(belt_labels)
    Df_shift_ax.set_title('')

    #
    # Stability by distance to reward
    #

    Df_shift_boots_b_fit = [
        np.dot(Df_shift_N, boot) for boot in Df_all_pairs['boots_theta_b']
    ]
    Df_shift_b_ci_up_fit = np.percentile(Df_shift_boots_b_fit, 95, axis=0)
    Df_shift_b_ci_low_fit = np.percentile(Df_shift_boots_b_fit, 5, axis=0)

    shift_compare_ax.plot(shift_x_vals, Df_shift_b_fit, color=Df_color)
    shift_compare_ax.fill_between(shift_x_vals,
                                  Df_shift_b_ci_low_fit,
                                  Df_shift_b_ci_up_fit,
                                  facecolor=Df_color,
                                  alpha=0.5)

    shift_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.axhline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.tick_params(length=3, pad=1, top=False)
    shift_compare_ax.set_xlim(-np.pi, np.pi)
    shift_compare_ax.set_xticks(belt_ticks)
    shift_compare_ax.set_xticklabels(belt_labels)
    shift_compare_ax.set_ylim(-0.10 * 2 * np.pi, 0.10 * 2 * np.pi)
    # Data are in radians; ticks are labelled as fractions of the belt.
    y_ticks = np.array(['-0.10', '-0.05', '0', '0.05', '0.10'])
    shift_compare_ax.set_yticks(y_ticks.astype('float') * 2 * np.pi)
    shift_compare_ax.set_yticklabels(y_ticks)
    shift_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    shift_compare_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')

    Df_shift_boots_k_fit = [
        splines.get_k(boot, Df_shift_N)
        for boot in Df_all_pairs['boots_theta_k']
    ]
    Df_shift_k_ci_up_fit = np.percentile(Df_shift_boots_k_fit, 95, axis=0)
    Df_shift_k_ci_low_fit = np.percentile(Df_shift_boots_k_fit, 5, axis=0)

    # Variance is plotted as 1 / k; inverting swaps which percentile bound
    # is on top, which fill_between does not care about.
    var_compare_ax.plot(shift_x_vals, 1. / Df_shift_k_fit, color=Df_color)
    var_compare_ax.fill_between(shift_x_vals,
                                1. / Df_shift_k_ci_low_fit,
                                1. / Df_shift_k_ci_up_fit,
                                facecolor=Df_color,
                                alpha=0.5)

    var_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    var_compare_ax.tick_params(length=3, pad=1, top=False)
    var_compare_ax.set_xlim(-np.pi, np.pi)
    var_compare_ax.set_xticks(belt_ticks)
    var_compare_ax.set_xticklabels(belt_labels)
    # Values are in rad^2; ticks are labelled in (fraction of belt)^2.
    y_ticks = np.array(['0.005', '0.010', '0.015', '0.020'])
    var_compare_ax.set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    var_compare_ax.set_yticklabels(y_ticks)
    var_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    var_compare_ax.set_ylabel(r'$\Delta$ position variance')

    #
    # Enrichment
    #

    with open(simulations_path, 'rb') as simulations_file:
        m = pickle.load(simulations_file)
    Df_enrich = emp.calc_enrichment(m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_enrichment(enrichment_ax,
                        Df_enrich,
                        Df_color,
                        title='',
                        rad=False)
    enrichment_ax.set_xlabel("Iteration ('session' #)")

    #
    # Final Enrichment
    #

    Df_no_swap_final_dist = emp.calc_final_distributions(
        m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_final_distributions(final_enrichment_ax, [Df_no_swap_final_dist],
                                 [Df_color],
                                 title='',
                                 rad=False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')