def main():
    """Plot parameters and enrichment for three modified enrichment models.

    Loads the WT parameter dict from ``params_path`` and uses it to build a
    reference ``em.EnrichmentModel2`` plus three ``emt`` variants
    (``EnrichmentModel2_recur``, ``EnrichmentModel2_offset`` and
    ``EnrichmentModel2_var``).  Each variant is re-initialized from the same
    WT mask/positions and run ``n_runs`` times; the collected ``_masks`` and
    ``_positions`` are passed to ``emp.calc_enrichment`` and each model is
    drawn into its own column of a 4x3 figure with ``show_parameters``.
    The figure is written with ``save_figure``.
    """
    # Pickles must be read in binary mode, and the context manager closes
    # the handle that the previous bare ``open()`` call leaked.
    # NOTE(review): pickle.load runs arbitrary code from the file;
    # presumably this is a project-generated file -- confirm.
    with open(params_path, 'rb') as params_file:
        WT_params = pkl.load(params_file)

    WT_model = em.EnrichmentModel2(**WT_params)
    recur_model = emt.EnrichmentModel2_recur(
        kappa=1, span=0.8, mean_recur=0.4, **WT_params)
    offset_model = emt.EnrichmentModel2_offset(alpha=0.25, **WT_params)
    var_model = emt.EnrichmentModel2_var(
        kappa=1, alpha=10, mean_k=3, **WT_params)

    # Column order of the final figure.
    models = (recur_model, offset_model, var_model)
    titles = ('Stable recurrence', 'Shift towards reward', 'Stable position')

    WT_model.initialize(n_cells=1000, flat_tol=1e-6)
    for model in models:
        model.initialize_like(WT_model)
    initial_mask = WT_model.mask
    initial_positions = WT_model.positions

    n_runs = 100

    color = sns.xkcd_rgb['forest green']

    # One list of per-run results for each model, in ``models`` order.
    all_masks = [[] for _ in models]
    all_positions = [[] for _ in models]

    for _ in range(n_runs):
        # Keep the original call order (initialize all, then run all) in
        # case the models draw from shared random state.
        for model in models:
            model.initialize(initial_mask=initial_mask,
                             initial_positions=initial_positions)
        for model in models:
            model.run(8)

        for model, run_masks, run_positions in zip(
                models, all_masks, all_positions):
            run_masks.append(model._masks)
            run_positions.append(model._positions)

    enrichments = [
        emp.calc_enrichment(run_positions, run_masks)
        for run_positions, run_masks in zip(all_positions, all_masks)]

    fig, axs = plt.subplots(
        4, 3, figsize=(8.5, 11),
        gridspec_kw={'bottom': 0.3, 'wspace': 0.4})

    for col, (model, enrich) in enumerate(zip(models, enrichments)):
        show_parameters(axs[:, col], model, enrich, color=color)

    # Only the first column keeps its y-labels.
    for ax in axs[:, 1:].flat:
        ax.set_ylabel('')
    # x-labels are cleared on the top two rows and on the whole first and
    # last columns, leaving them only on the lower two middle-column axes.
    for ax in axs[:2, :].flat:
        ax.set_xlabel('')
    for ax in axs[:, [0, 2]].flat:
        ax.set_xlabel('')

    for ax, title in zip(axs[0], titles):
        ax.set_title(title)

    save_figure(fig, filename, save_dir=save_dir)
def main():
    """Simulate WT and flattened models with swapped parameters and plot them.

    Loads the WT parameter dict from ``WT_params_path``, builds an
    ``em.EnrichmentModel2`` and a copy on which ``flatten()`` is called, and
    simulates both from the same initial mask/positions: once unmodified
    ("no swap") and once per interpolation keyword (``on``, ``recur``,
    ``shift_b``, ``shift_k``) after ``interpolate`` has been called with the
    other model.  For four of those conditions the left column shows
    ``emp.calc_enrichment`` curves and the right column
    ``emp.calc_final_distributions`` results.  The figure is written with
    ``misc.save_figure`` and all figures are closed afterwards.
    """
    fig, axs = plt.subplots(
        5, 2, figsize=(8.5, 11), gridspec_kw={
            'wspace': 0.4, 'hspace': 0.35, 'right': 0.75, 'top': 0.9,
            'bottom': 0.1})

    #
    # Run the model
    #

    n_cells = 1000
    n_runs = 100
    tol = 1e-4

    # Pickles must be read in binary mode, and the context manager closes
    # the handle that the previous bare ``open()`` call leaked.
    # NOTE(review): pickle.load runs arbitrary code from the file;
    # presumably this is a project-generated file -- confirm.
    with open(WT_params_path, 'rb') as params_file:
        WT_params = pkl.load(params_file)

    WT_model = em.EnrichmentModel2(**WT_params)
    flat_model = WT_model.copy()
    flat_model.flatten()

    WT_model.initialize(n_cells=n_cells, flat_tol=tol)
    flat_model.initialize_like(WT_model)
    initial_mask = WT_model.mask
    initial_positions = WT_model.positions

    def run_model(model1, model2=None, initial_positions=None,
                  initial_mask=None, n_runs=100, **interp_kwargs):
        """Run a copy of ``model1`` ``n_runs`` times.

        If ``model2`` is given, the copy is first passed through
        ``interpolate(model2, **interp_kwargs)``.  Returns
        ``(positions, masks)``: the ``_positions`` and ``_masks``
        attributes collected after each run.
        """
        model = model1.copy()
        if model2 is not None:
            model.interpolate(model2, **interp_kwargs)

        masks, positions = [], []
        for _ in range(n_runs):
            model.initialize(
                initial_mask=initial_mask, initial_positions=initial_positions)
            model.run(8)
            masks.append(model._masks)
            positions.append(model._positions)

        return positions, masks

    def run_pair(**interp_kwargs):
        """Simulate the WT model, then the flat model, for one condition.

        With no keywords neither model is interpolated; otherwise each is
        interpolated with the other using the given keywords.  Returns
        ``[(WT_positions, WT_masks), (flat_positions, flat_masks)]``.
        """
        swap = bool(interp_kwargs)
        results = []
        for model, other in ((WT_model, flat_model), (flat_model, WT_model)):
            results.append(run_model(
                model, other if swap else None,
                initial_positions=initial_positions,
                initial_mask=initial_mask, n_runs=n_runs, **interp_kwargs))
        return results

    # Simulation order matches the original script.
    # NOTE(review): the 'on' results are never plotted below; the run is
    # kept in case the models draw from shared random state, where skipping
    # it would change the later conditions -- confirm and drop if safe.
    simulations = {}
    for name, interp_kwargs in (('no_swap', {}),
                                ('on', {'on': 1}),
                                ('recur', {'recur': 1}),
                                ('shift_b', {'shift_b': 1}),
                                ('shift_k', {'shift_k': 1})):
        simulations[name] = run_pair(**interp_kwargs)

    # (simulation key, right-hand row label, clear the left y-label?)
    rows = (
        ('no_swap', 'No swap', True),
        ('recur', 'Swap recurrence', False),
        ('shift_k', 'Swap shift variance', True),
        ('shift_b', 'Swap shift offset', True),
    )

    #
    # Distance to reward
    #

    for row, (name, row_label, clear_ylabel) in enumerate(rows):
        enrich_ax = axs[row, 0]
        for (sim_positions, sim_masks), line_color in zip(
                simulations[name], (WT_color, flat_color)):
            emp.plot_enrichment(
                enrich_ax, emp.calc_enrichment(sim_positions, sim_masks),
                line_color, '', rad=False)
        plotting.right_label(axs[row, 1], row_label)
        if clear_ylabel:
            enrich_ax.set_ylabel('')

    plotting.stackedText(
        axs[0, 0], [WT_label, flat_label], colors=colors, loc=2, size=10)

    #
    # Final distribution
    #

    for row, (name, _, _) in enumerate(rows):
        final_dists = [
            emp.calc_final_distributions(sim_positions, sim_masks)
            for sim_positions, sim_masks in simulations[name]]
        emp.plot_final_distributions(
            axs[row, 1], final_dists, colors, labels=None, title='',
            rad=False)

    # Only the last plotted row keeps its x-labels; the fifth row is a
    # hidden spacer.
    for ax in axs[:3, :].flat:
        ax.set_xlabel('')
    for ax in axs[4]:
        ax.set_visible(False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
def main():
    """Draw the model summary figure from pre-computed simulations.

    Panels, on a 3x2 grid (the top-left cell is left empty):

    * a schematic density curve with a dashed marker at ``pos1``,
    * enrichment curves for the WT and 'Df' no-swap simulations,
    * final distributions for the same two simulations,
    * a grouped bar plot comparing WT and 'Df' across the no-swap,
      swap-recurrence, swap-variance and swap-offset simulations,
    * final distributions of the four WT simulations.

    Simulation results are read from the pickle at ``simulations_path``.
    The figure is written with ``misc.save_figure`` and all figures are
    closed afterwards.
    """
    fig = plt.figure(figsize=(8.5, 11))

    gs = plt.GridSpec(
        3, 2, left=0.1, bottom=0.4, right=0.5, top=0.9, hspace=0.4, wspace=0.4)
    shift_schematic_ax = fig.add_subplot(gs[0, 1])
    model_enrich_by_time_ax = fig.add_subplot(gs[1, 0])
    model_final_enrich_ax = fig.add_subplot(gs[1, 1])
    model_swap_compare_enrich_ax = fig.add_subplot(gs[2, 0])
    model_WT_swap_final_enrich_ax = fig.add_subplot(gs[2, 1])

    #
    # Shift schematic
    #

    mu = 1
    k = 2.
    pos1 = 1.5

    def vms(x, mu=1, k=2.):
        """Return exp(k * cos(x - mu)) / (2 * pi * i0(k)).

        This is the von Mises density with mean ``mu`` and concentration
        ``k``, assuming ``i0`` is the order-0 modified Bessel function
        (e.g. ``scipy.special.i0``) -- TODO confirm against the imports.
        """
        return np.exp(k * np.cos(x - mu)) / (2 * np.pi * i0(k))

    xr = np.linspace(-np.pi, np.pi, 1000)

    # Evaluate the curve in one vectorised call rather than point by point.
    shift_schematic_ax.plot(xr, vms(xr, mu=mu, k=k), color='k')
    shift_schematic_ax.axvline(pos1, color='r', ls='--')
    # Dotted vertical line from the x-axis up to the curve at x = mu.
    shift_schematic_ax.plot(
        [mu, mu], [0, vms(mu, mu=mu, k=k)], ls=':', color='0.3')
    sns.despine(ax=shift_schematic_ax, top=True, right=True)
    shift_schematic_ax.tick_params(labelleft=True, left=True, direction='out')
    shift_schematic_ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    shift_schematic_ax.set_xticklabels(['-0.50', '-0.25', '0', '0.25', '0.50'])
    shift_schematic_ax.set_xlim(-np.pi, np.pi)
    shift_schematic_ax.set_yticks([0, 0.2, 0.4, 0.6])
    shift_schematic_ax.set_ylim(0, 0.6)
    shift_schematic_ax.set_xlabel('Distance from reward (fraction of belt)')
    shift_schematic_ax.set_ylabel(
        'New place field centroid\nprobability density')

    #
    # Distance to reward
    #

    # Pickles must be read in binary mode, and the context manager closes
    # the handle that the previous bare ``open()`` call leaked.
    # NOTE(review): pickle.load runs arbitrary code from the file;
    # presumably this is a project-generated file -- confirm.
    with open(simulations_path, 'rb') as simulations_file:
        m = pickle.load(simulations_file)

    WT_enrich = emp.calc_enrichment(m['WT_no_swap_pos'], m['WT_no_swap_masks'])
    flat_enrich = emp.calc_enrichment(m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_enrichment(
        model_enrich_by_time_ax, WT_enrich, WT_color, title='', rad=False)
    emp.plot_enrichment(
        model_enrich_by_time_ax, flat_enrich, flat_color, title='', rad=False)

    model_enrich_by_time_ax.set_xlabel("Iteration ('session' #)")
    plotting.stackedText(
        model_enrich_by_time_ax, [WT_label, flat_label], colors=colors, loc=2,
        size=10)

    #
    # Calc final distributions
    #

    def final_dist(prefix):
        """Final distribution for the simulation stored under ``prefix``."""
        return emp.calc_final_distributions(
            m[prefix + '_pos'], m[prefix + '_masks'])

    # Order matches the bar groups and histogram labels used below.
    conditions = ('no_swap', 'swap_recur', 'swap_shift_k', 'swap_shift_b')
    WT_final_dists = [final_dist('WT_' + cond) for cond in conditions]
    flat_final_dists = [final_dist('Df_' + cond) for cond in conditions]

    #
    # Final distribution
    #

    emp.plot_final_distributions(
        model_final_enrich_ax,
        [WT_final_dists[0], flat_final_dists[0]],
        colors, labels=labels, title='', rad=False)

    plotting.stackedText(
        model_final_enrich_ax, [WT_label, flat_label], colors=colors, loc=2,
        size=10)

    #
    # Compare enrichment
    #

    WT_bars = [np.pi / 2 - np.abs(dist) for dist in WT_final_dists]
    flat_bars = [np.pi / 2 - np.abs(dist) for dist in flat_final_dists]

    plotting.grouped_bar(
        model_swap_compare_enrich_ax, [WT_bars, flat_bars],
        [WT_label, flat_label],
        ['No\nswap', 'Swap\n' + r'$P_{recur}$',
         'Swap\nvariance', 'Swap\noffset'], [WT_color, flat_color],
        error_bars=None)

    sns.despine(ax=model_swap_compare_enrich_ax)
    model_swap_compare_enrich_ax.tick_params(length=3, pad=2, direction='out')
    model_swap_compare_enrich_ax.set_ylabel('Final enrichment (rad)')
    model_swap_compare_enrich_ax.get_legend().set_visible(False)
    model_swap_compare_enrich_ax.set_ylim(0, 0.08 * 2 * np.pi)
    # Ticks are placed at label * 2*pi but labelled with the bare value.
    y_ticks = np.array(['0', '0.02', '0.04', '0.06', '0.08'])
    model_swap_compare_enrich_ax.set_yticks(
        y_ticks.astype('float') * 2 * np.pi)
    model_swap_compare_enrich_ax.set_yticklabels(y_ticks)
    plotting.stackedText(
        model_swap_compare_enrich_ax, [WT_label, flat_label], colors=colors,
        loc=2, size=10)

    #
    # WT swap final distributions
    #

    hist_colors = iter(sns.color_palette())

    emp.plot_final_distributions(
        model_WT_swap_final_enrich_ax, WT_final_dists, colors=hist_colors,
        labels=['No swap', r'Swap $P_{recur}$',
                'Swap shift variance', 'Swap shift offset'], rad=False)
    model_WT_swap_final_enrich_ax.legend(
        loc='lower right', fontsize=5, frameon=False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
# Example no. 4
# 0
def main():
    """Draw the Df recurrence, stability and enrichment figure.

    Panels, on a 3x2 grid:

    * recurrence against initial position: spline fit, 5th-95th percentile
      band of the bootstrap fits, and the data scattered with y-jitter,
    * position shift against initial position: fitted offset with a
      +/- ``1 / k`` band and the wrapped pairwise shifts,
    * the fitted offset with its bootstrap percentile band,
    * ``1 / k`` with its bootstrap percentile band,
    * enrichment and final distribution of the 'Df_no_swap' simulation.

    Fit parameters are read from the pickle at ``params_path`` and
    simulations from the pickle at ``simulations_path``.  The figure is
    written with ``misc.save_figure`` and all figures are closed afterwards.
    """
    # Only the second returned value is used here.
    _, Df_data = emd.load_data(
        'df', root=os.path.join(df.data_path, 'enrichment_model'))

    # Pickles must be read in binary mode, and the context manager closes
    # the handle that the previous bare ``open()`` call leaked.
    # NOTE(review): pickle.load runs arbitrary code from the file;
    # presumably these are project-generated files -- confirm.
    with open(params_path, 'rb') as params_file:
        Df_params = pickle.load(params_file)

    fig = plt.figure(figsize=(8.5, 11))
    gs = plt.GridSpec(
        3, 2, left=0.1, bottom=0.5, right=0.5, top=0.9, wspace=0.3,
        hspace=0.3)
    Df_recur_ax = fig.add_subplot(gs[0, 0])
    Df_shift_ax = fig.add_subplot(gs[0, 1])
    shift_compare_ax = fig.add_subplot(gs[1, 0])
    var_compare_ax = fig.add_subplot(gs[1, 1])
    enrichment_ax = fig.add_subplot(gs[2, 0])
    final_enrichment_ax = fig.add_subplot(gs[2, 1])

    # Shared tick layout: positions span [-pi, pi] and are labelled as a
    # fraction of the belt.
    belt_ticks = [-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]
    belt_tick_labels = ['-0.50', '-0.25', '0', '0.25', '0.50']

    #
    # Recurrence by position
    #

    recur_x_vals = np.linspace(-np.pi, np.pi, 1000)
    recur_params = Df_params['position_recurrence']

    Df_recur_data = emd.recurrence_by_position(Df_data, method='cv')
    Df_recur_knots = np.linspace(-np.pi, np.pi, recur_params['n_knots'])
    Df_recur_spline = splines.CyclicSpline(Df_recur_knots)
    Df_recur_N = Df_recur_spline.design_matrix(recur_x_vals)

    Df_recur_fit = splines.prob(recur_params['theta'], Df_recur_N)

    Df_recur_boots_fits = [
        splines.prob(boot, Df_recur_N)
        for boot in recur_params['boots_theta']]
    Df_recur_ci_up_fit = np.percentile(Df_recur_boots_fits, 95, axis=0)
    Df_recur_ci_low_fit = np.percentile(Df_recur_boots_fits, 5, axis=0)

    Df_recur_ax.plot(recur_x_vals, Df_recur_fit, color=Df_color)
    Df_recur_ax.fill_between(
        recur_x_vals, Df_recur_ci_low_fit, Df_recur_ci_up_fit,
        facecolor=Df_color, alpha=0.5)
    # x/y are passed by keyword: seaborn >= 0.12 no longer accepts them
    # positionally, and older releases accept the keywords as well.
    sns.regplot(
        x=Df_recur_data[:, 0], y=Df_recur_data[:, 1], ax=Df_recur_ax,
        color=Df_color, y_jitter=0.2, fit_reg=False, scatter_kws={'s': 1},
        marker=Df_marker)
    Df_recur_ax.set_xlim(-np.pi, np.pi)
    Df_recur_ax.set_xticks(belt_ticks)
    Df_recur_ax.set_xticklabels(belt_tick_labels)
    Df_recur_ax.set_ylim(-0.3, 1.3)
    Df_recur_ax.set_yticks([0, 0.5, 1])
    Df_recur_ax.tick_params(length=3, pad=1, top=False)
    Df_recur_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    Df_recur_ax.set_ylabel('Place cell recurrence probability')
    Df_recur_ax.set_title('')
    # Twin axis labels the two jittered point clouds at y = 0 and y = 1.
    Df_recur_ax_2 = Df_recur_ax.twinx()
    Df_recur_ax_2.set_ylim(-0.3, 1.3)
    Df_recur_ax_2.set_yticks([0, 1])
    Df_recur_ax_2.set_yticklabels(['non-recur', 'recur'])
    Df_recur_ax_2.tick_params(length=3, pad=1, top=False)

    #
    # Place field stability
    #

    # These fits are shared by the next three panels.  They were
    # previously recomputed verbatim for the comparison panels.
    shift_x_vals = np.linspace(-np.pi, np.pi, 1000)
    stability_params = Df_params['position_stability']['all_pairs']

    Df_shift_knots = stability_params['knots']
    Df_shift_spline = splines.CyclicSpline(Df_shift_knots)
    Df_shift_N = Df_shift_spline.design_matrix(shift_x_vals)
    Df_shift_theta_b = stability_params['theta_b']
    Df_shift_b_fit = np.dot(Df_shift_N, Df_shift_theta_b)
    Df_shift_theta_k = stability_params['theta_k']
    Df_shift_k_fit = splines.get_k(Df_shift_theta_k, Df_shift_N)
    Df_shift_fit_var = 1. / Df_shift_k_fit

    Df_shift_data = emd.paired_activity_centroid_distance_to_reward(Df_data)
    Df_shift_data = Df_shift_data.dropna()
    # Wrap the pairwise differences back into [-pi, pi).
    Df_shifts = Df_shift_data['second'] - Df_shift_data['first']
    Df_shifts[Df_shifts < -np.pi] += 2 * np.pi
    Df_shifts[Df_shifts >= np.pi] -= 2 * np.pi

    Df_shift_ax.plot(shift_x_vals, Df_shift_b_fit, color=Df_color)
    Df_shift_ax.fill_between(
        shift_x_vals, Df_shift_b_fit - Df_shift_fit_var,
        Df_shift_b_fit + Df_shift_fit_var, facecolor=Df_color, alpha=0.5)
    sns.regplot(
        x=Df_shift_data['first'], y=Df_shifts, ax=Df_shift_ax,
        color=Df_color, fit_reg=False, scatter_kws={'s': 1},
        marker=Df_marker)

    Df_shift_ax.axvline(ls='--', color='0.4', lw=0.5)
    Df_shift_ax.axhline(ls='--', color='0.4', lw=0.5)
    # Anti-diagonal reference line (shift == -initial position).
    Df_shift_ax.plot([-np.pi, np.pi], [np.pi, -np.pi], color='g', ls=':', lw=2)
    Df_shift_ax.tick_params(length=3, pad=1, top=False)
    Df_shift_ax.set_xlabel('Initial distance from reward (fraction of belt)')
    Df_shift_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')
    Df_shift_ax.set_xlim(-np.pi, np.pi)
    Df_shift_ax.set_xticks(belt_ticks)
    Df_shift_ax.set_xticklabels(belt_tick_labels)
    Df_shift_ax.set_ylim(-np.pi, np.pi)
    Df_shift_ax.set_yticks(belt_ticks)
    Df_shift_ax.set_yticklabels(belt_tick_labels)
    Df_shift_ax.set_title('')

    #
    # Stability by distance to reward
    #

    Df_shift_boots_b_fit = [
        np.dot(Df_shift_N, boot)
        for boot in stability_params['boots_theta_b']]
    Df_shift_b_ci_up_fit = np.percentile(Df_shift_boots_b_fit, 95, axis=0)
    Df_shift_b_ci_low_fit = np.percentile(Df_shift_boots_b_fit, 5, axis=0)

    shift_compare_ax.plot(shift_x_vals, Df_shift_b_fit, color=Df_color)
    shift_compare_ax.fill_between(
        shift_x_vals, Df_shift_b_ci_low_fit, Df_shift_b_ci_up_fit,
        facecolor=Df_color, alpha=0.5)

    shift_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.axhline(ls='--', color='0.4', lw=0.5)
    shift_compare_ax.tick_params(length=3, pad=1, top=False)
    shift_compare_ax.set_xlim(-np.pi, np.pi)
    shift_compare_ax.set_xticks(belt_ticks)
    shift_compare_ax.set_xticklabels(belt_tick_labels)
    shift_compare_ax.set_ylim(-0.10 * 2 * np.pi, 0.10 * 2 * np.pi)
    # Ticks are placed at label * 2*pi but labelled with the bare value.
    y_ticks = np.array(['-0.10', '-0.05', '0', '0.05', '0.10'])
    shift_compare_ax.set_yticks(y_ticks.astype('float') * 2 * np.pi)
    shift_compare_ax.set_yticklabels(y_ticks)
    shift_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    shift_compare_ax.set_ylabel(r'$\Delta$ position (fraction of belt)')

    Df_shift_boots_k_fit = [
        splines.get_k(boot, Df_shift_N)
        for boot in stability_params['boots_theta_k']]
    Df_shift_k_ci_up_fit = np.percentile(Df_shift_boots_k_fit, 95, axis=0)
    Df_shift_k_ci_low_fit = np.percentile(Df_shift_boots_k_fit, 5, axis=0)

    var_compare_ax.plot(shift_x_vals, 1. / Df_shift_k_fit, color=Df_color)
    var_compare_ax.fill_between(
        shift_x_vals, 1. / Df_shift_k_ci_low_fit, 1. / Df_shift_k_ci_up_fit,
        facecolor=Df_color, alpha=0.5)

    var_compare_ax.axvline(ls='--', color='0.4', lw=0.5)
    var_compare_ax.tick_params(length=3, pad=1, top=False)
    var_compare_ax.set_xlim(-np.pi, np.pi)
    var_compare_ax.set_xticks(belt_ticks)
    var_compare_ax.set_xticklabels(belt_tick_labels)
    # Ticks are placed at label * (2*pi)**2 but labelled with the bare value.
    y_ticks = np.array(['0.005', '0.010', '0.015', '0.020'])
    var_compare_ax.set_yticks(y_ticks.astype('float') * (2 * np.pi)**2)
    var_compare_ax.set_yticklabels(y_ticks)
    var_compare_ax.set_xlabel(
        'Initial distance from reward (fraction of belt)')
    var_compare_ax.set_ylabel(r'$\Delta$ position variance')

    #
    # Enrichment
    #

    with open(simulations_path, 'rb') as simulations_file:
        m = pickle.load(simulations_file)
    Df_enrich = emp.calc_enrichment(m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_enrichment(
        enrichment_ax, Df_enrich, Df_color, title='', rad=False)
    enrichment_ax.set_xlabel("Iteration ('session' #)")

    #
    # Final Enrichment
    #

    Df_no_swap_final_dist = emp.calc_final_distributions(
        m['Df_no_swap_pos'], m['Df_no_swap_masks'])

    emp.plot_final_distributions(
        final_enrichment_ax, [Df_no_swap_final_dist], [Df_color], title='',
        rad=False)

    misc.save_figure(fig, filename, save_dir=save_dir)

    plt.close('all')
# Example no. 5
# 0
def main():
    """Plot WT and Df no-swap simulations for three stored conditions.

    Loads the three simulation pickles (suffixes ``A``, ``B`` and ``C``)
    from ``<df.data_path>/enrichment_model`` and, for each, draws the WT
    and Df enrichment curves (top row) and collects their final
    distributions, which are then plotted together (bottom row).  Each
    condition gets its own shade from a light palette of the genotype
    colour.  The figure is written with ``save_figure``.
    """
    fig = plt.figure(figsize=(8.5, 11))

    gs1 = plt.GridSpec(
        2, 2, left=0.1, right=0.7, top=0.9, bottom=0.5, hspace=0.4,
        wspace=0.4)
    WT_enrich_ax = fig.add_subplot(gs1[0, 0])
    Df_enrich_ax = fig.add_subplot(gs1[0, 1])
    WT_final_dist_ax = fig.add_subplot(gs1[1, 0])
    Df_final_dist_ax = fig.add_subplot(gs1[1, 1])

    # Pickles must be read in binary mode, and the context manager closes
    # each handle; the previous bare ``open()`` calls leaked all three.
    # NOTE(review): pickle.load runs arbitrary code from the file;
    # presumably these are project-generated files -- confirm.
    simulations = []
    for suffix in ('A', 'B', 'C'):
        simulations_path = os.path.join(
            df.data_path, 'enrichment_model',
            'WT_Df_enrichment_model_simulation_{}.pkl'.format(suffix))
        with open(simulations_path, 'rb') as simulations_file:
            simulations.append(pickle.load(simulations_file))

    # Every other shade of a 7-step light palette, starting at the third:
    # one colour per condition.
    WT_colors = sns.light_palette(WT_color, 7)[2::2]
    Df_colors = sns.light_palette(Df_color, 7)[2::2]

    condition_labels = [
        r'Condition $\mathrm{I}$', r'Condition $\mathrm{II}$',
        r'Condition $\mathrm{III}$'
    ]

    WT_final_dists, Df_final_dists = [], []

    for m, WT_c, Df_c in zip(simulations, WT_colors, Df_colors):
        WT_enrich = emp.calc_enrichment(
            m['WT_no_swap_pos'], m['WT_no_swap_masks'])
        Df_enrich = emp.calc_enrichment(
            m['Df_no_swap_pos'], m['Df_no_swap_masks'])

        WT_final_dists.append(emp.calc_final_distributions(
            m['WT_no_swap_pos'], m['WT_no_swap_masks']))
        Df_final_dists.append(emp.calc_final_distributions(
            m['Df_no_swap_pos'], m['Df_no_swap_masks']))

        emp.plot_enrichment(WT_enrich_ax, WT_enrich, WT_c, title='', rad=False)
        emp.plot_enrichment(Df_enrich_ax, Df_enrich, Df_c, title='', rad=False)

    # Top row: label and annotate the enrichment panels.
    for ax, palette in ((WT_enrich_ax, WT_colors), (Df_enrich_ax, Df_colors)):
        ax.set_xlabel("Iteration ('session' #)")
    for ax, palette in ((WT_enrich_ax, WT_colors), (Df_enrich_ax, Df_colors)):
        plotting.stackedText(
            ax, condition_labels, colors=palette, loc=2, size=8)

    # Bottom row: final distributions, one curve per condition.
    final_panels = (
        (WT_final_dist_ax, WT_final_dists, WT_colors),
        (Df_final_dist_ax, Df_final_dists, Df_colors),
    )
    for ax, final_dists, palette in final_panels:
        emp.plot_final_distributions(
            ax, final_dists, palette, labels=condition_labels, title='',
            rad=False)
    for ax, _, palette in final_panels:
        ax.set_xlabel('Distance from reward\n(fraction of belt)')
    for ax, _, palette in final_panels:
        plotting.stackedText(
            ax, condition_labels, colors=palette, loc=2, size=8)
    for ax, _, palette in final_panels:
        ax.set_yticks([0, 0.1, 0.2, 0.3])

    save_figure(fig, filename, save_dir=save_dir)