def plot_rotation_examples(field_data, type='grid', example_cells=None):
    """Save polar head-direction histograms for a few example fields of one cell type.

    For every example position this saves the combined-field histogram and the
    rotated histogram through ``plot_and_save_polar_histogram`` (all combined
    plots first, then all rotated ones).  It then saves one extra polar plot of
    the average of the example rotated histograms to
    ``analysis_path + 'rotated_hd_histograms_combined_example_hist.png'``.

    Args:
        field_data: DataFrame of fields.  Reads the 'cell type',
            'accepted_field', 'hd_hist_from_all_fields' and
            'hd_hist_from_all_fields_rotated' columns.
        type: value of the 'cell type' column to select.  (Shadows the builtin
            ``type``; the name is kept for keyword-argument callers.)
        example_cells: positional (``iloc``) indices into the accepted fields of
            ``type``.  Defaults to (20, 25, 35) for 'grid' and (0, 1, 2)
            otherwise, the hand-picked examples used originally.
    """
    if example_cells is None:
        example_cells = (20, 25, 35) if type == 'grid' else (0, 1, 2)

    # One mask shared by both histogram columns.
    selected = field_data.accepted_field & (field_data['cell type'] == type)
    combined_field_histograms = field_data.hd_hist_from_all_fields[selected]
    rotated = field_data.hd_hist_from_all_fields_rotated[selected]

    for cell in example_cells:
        plot_and_save_polar_histogram(combined_field_histograms.iloc[cell], type + '_cell_' + str(cell))
    for cell in example_cells:
        plot_and_save_polar_histogram(rotated.iloc[cell], type + '_cell_' + str(cell) + '_rotated')

    # Combine the rotated examples into one polar plot of their average.
    average_histogram = np.average([rotated.iloc[cell] for cell in example_cells], axis=0)
    hd_polar_fig = plt.figure()
    ax = hd_polar_fig.add_subplot(1, 1, 1, polar=True)
    ax = plot_utility.style_polar_plot(ax)
    # One angle per bin, evenly spaced over the full circle (360 bins -> 1 degree).
    theta = np.linspace(0, 2 * np.pi, len(average_histogram) + 1)[:-1]
    ax.plot(theta, average_histogram, color='red', linewidth=10)
    # NOTE(review): this file name does not include ``type``, so calls for
    # different cell types overwrite each other's combined plot.
    hd_polar_fig.savefig(analysis_path + 'rotated_hd_histograms_' + 'combined_example_hist' + '.png', dpi=300, bbox_inches="tight")
    plt.close(hd_polar_fig)
def plot_rotated_histograms_for_cell_type(field_data, cell_type='grid', animal='mouse'):
    """Overlay the rotated HD histograms of all cells of one type and their average.

    For every ``unique_cell_id`` (in order of first appearance), the rotated
    histogram of that cell's first accepted field of ``cell_type`` is collected.
    NaN-free histograms are drawn in translucent gray on one polar plot and their
    average in red.  The figure is saved to
    ``analysis_path + animal + '_rotated_hd_histograms_' + cell_type + '.png'``.

    Args:
        field_data: DataFrame of fields.  Reads 'cell type', 'accepted_field',
            'unique_cell_id' and 'hd_hist_from_all_fields_rotated'.
        cell_type: value of the 'cell type' column to select.
        animal: prefix used in the output file name.
    """
    print('analyze ' + cell_type + ' cells')
    # These masks do not depend on the cell, so build them once.
    cell_type_filter = field_data['cell type'] == cell_type
    accepted = field_data['accepted_field'] == True  # noqa: E712 (elementwise pandas comparison)
    eligible_fields = field_data[cell_type_filter & accepted]

    histograms = []
    for cell in field_data.unique_cell_id.unique():
        fields_of_cell = eligible_fields[eligible_fields.unique_cell_id == cell]
        if not fields_of_cell.empty:
            # Only the first accepted field of each cell is used.
            histograms.append(fields_of_cell.hd_hist_from_all_fields_rotated.iloc[0])
    print('Number of ' + cell_type + ' cells: ' + str(len(histograms)))

    histograms_to_plot = [histogram for histogram in histograms if not np.isnan(histogram).any()]

    # A single polar axes on a fresh figure.  Overlaying a polar subplot on a
    # cartesian one leaves the cartesian frame visible on matplotlib >= 3.6.
    hd_polar_fig = plt.figure()
    ax = hd_polar_fig.add_subplot(1, 1, 1, polar=True)
    ax = plot_utility.style_polar_plot(ax)
    for histogram in histograms_to_plot:
        theta = np.linspace(0, 2 * np.pi, len(histogram) + 1)[:-1]
        # alpha must lie in [0, 1]; the original 70 is rejected by matplotlib.
        ax.plot(theta, histogram, color='gray', linewidth=2, alpha=0.7)
    # Combine them into one average curve (skipped when nothing is plottable,
    # where np.average of an empty list cannot be drawn).
    if histograms_to_plot:
        average_histogram = np.average(histograms_to_plot, axis=0)
        theta = np.linspace(0, 2 * np.pi, len(average_histogram) + 1)[:-1]
        ax.plot(theta, average_histogram, color='red', linewidth=10)
    hd_polar_fig.savefig(analysis_path + animal + '_rotated_hd_histograms_' + cell_type + '.png', dpi=300, bbox_inches="tight")
    plt.close(hd_polar_fig)
def plot_bar_chart_for_cells_percentile_error_bar(spatial_firing, path, animal, shuffle_type='occupancy'):
    """Save one polar plot per cell comparing its HD rate histogram to shuffled data.

    Each plot shows:
    - the shuffled mean in gray, with a band from ``mean - error_bar_5`` to
      ``mean + error_bar_95``;
    - the observed histogram (``hd_histogram_real_data_hz``) in navy;
    - red stars at bins whose ``p_values_corrected_bars_bh`` is below 0.05.

    Files are written to ``analysis_path + animal + '_' + shuffle_type + '/'``.

    Args:
        spatial_firing: DataFrame with one row per cell.
        path: unused; output goes under the module-level ``analysis_path``.
        animal: prefix of the output directory name.
        shuffle_type: suffix of the output directory name.
    """
    # 20 tick labels, one every 18 degrees, with text only at the cardinal angles.
    x_labels = ["0", "", "", "", "", "90", "", "", "", "", "180", "", "", "", "", "270", "", "", "", ""]
    # 2*pi coincides with 0 on a polar axis, so drop the closing point; this
    # keeps the tick count equal to the label count (newer matplotlib raises
    # on a mismatch).
    tick_positions = np.linspace(0, 2 * np.pi, len(x_labels) + 1)[:-1]

    for counter, (_, cell) in enumerate(spatial_firing.iterrows()):
        # Repeat the first bin at the end so each polar curve closes on itself.
        mean = np.append(cell['shuffled_means'], cell['shuffled_means'][0])
        percentile_95 = np.append(cell['error_bar_95'], cell['error_bar_95'][0])
        percentile_5 = np.append(cell['error_bar_5'], cell['error_bar_5'][0])
        observed_data = np.append(cell.hd_histogram_real_data_hz, cell.hd_histogram_real_data_hz[0])
        max_rate = np.round(cell.hd_histogram_real_data_hz.max(), 2)
        x_pos = np.linspace(0, 2 * np.pi, cell['shuffled_histograms_hz'].shape[1] + 1)

        significant_bins_to_mark = x_pos[np.where(cell.p_values_corrected_bars_bh < 0.05)[0]]
        y_value_markers = [max_rate + 0.5] * len(significant_bins_to_mark)

        # Fresh figure per cell, so no caller figure is drawn into or closed.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, polar=True)
        ax = plot_utility.style_polar_plot(ax)
        plt.xticks(tick_positions, x_labels)
        ax.fill_between(x_pos, mean - percentile_5, percentile_95 + mean, color='grey', alpha=0.4)
        ax.plot(x_pos, mean, color='grey', linewidth=5, alpha=0.7)
        ax.plot(x_pos, observed_data, color='navy', linewidth=5)
        plt.title('\n' + str(max_rate) + ' Hz', fontsize=20, y=1.08)
        if len(significant_bins_to_mark) > 0:
            ax.scatter(significant_bins_to_mark, y_value_markers, c='red', marker='*', zorder=3, s=100)
        fig.subplots_adjust(top=0.85)
        fig.savefig(analysis_path + animal + '_' + shuffle_type + '/' + str(counter) + str(cell['session_id']) + str(cell['cluster_id']) + '_percentile_polar_' + str(cell.percentile_value) + '.png')
        plt.close(fig)
def plot_and_save_polar_histogram(histogram, name, color='gray', line_width=10):
    """Save a single polar HD histogram.

    Output goes to ``analysis_path + 'rotated_hd_histograms_' + name + '.png'``.
    Bin ``i`` is drawn at angle ``2*pi*i/len(histogram)``, so a 360-bin
    histogram is plotted at 1-degree steps.

    Args:
        histogram: 1-D sequence of bin values.
        name: suffix of the output file name.
        color: line colour.
        line_width: line width in points.
    """
    # Use a dedicated figure. The previous plt.cla()/plt.subplot pair cleared
    # and drew on whatever figure was current, leaving a stray cartesian axes.
    figure = plt.figure()
    ax = figure.add_subplot(1, 1, 1, polar=True)
    ax = plot_utility.style_polar_plot(ax)
    theta = np.linspace(0, 2 * np.pi, len(histogram) + 1)[:-1]
    ax.plot(theta, histogram, color=color, linewidth=line_width)
    figure.savefig(analysis_path + 'rotated_hd_histograms_' + name + '.png', dpi=300, bbox_inches="tight")
    plt.close(figure)
# Example #5
# 0
def plot_polar_hd_hist(hist_1, hist_2, cluster, save_path, color1='lime', color2='navy', title=''):
    """Overlay two HD histograms on one polar plot and save it.

    The file is written to ``save_path + '_hd_polar_' + str(cluster + 1) + '.png'``.
    Both histograms are plotted unscaled against the same angles (one per bin
    of ``hist_1``, evenly spaced over the circle), so they must have equal length.

    Args:
        hist_1: first histogram, drawn in ``color1``.
        hist_2: second histogram, drawn in ``color2``.
        cluster: zero-based cluster index; the file name uses ``cluster + 1``.
        save_path: path prefix of the output file.
        color1, color2: line colours.
        title: plot title.
    """
    hd_polar_fig = plt.figure()
    hd_polar_fig.set_size_inches(5, 5, forward=True)
    # Create the polar axes directly. The former cartesian add_subplot was
    # immediately overlaid by a polar one, which stays visible on matplotlib >= 3.6.
    ax = hd_polar_fig.add_subplot(1, 1, 1, polar=True)
    ax = plot_utility.style_polar_plot(ax)
    theta = np.linspace(0, 2 * np.pi, len(hist_1) + 1)[:-1]
    ax.plot(theta, hist_1, color=color1, linewidth=2)
    ax.plot(theta, hist_2, color=color2, linewidth=2)
    ax.set_title(title)
    hd_polar_fig.tight_layout()
    hd_polar_fig.savefig(save_path + '_hd_polar_' + str(cluster + 1) + '.png', dpi=300, bbox_inches="tight")
    plt.close(hd_polar_fig)
def save_field_polar_plot(save_path, hd_hist_session, hd_hist_cluster, cluster, spatial_firing, colors, field_id, name):
    """Save a polar plot of one field's HD histogram against the session's.

    The session histogram is rescaled so its peak equals the field histogram's
    peak and drawn in black.  The field histogram is drawn in
    ``colors[field_id]``.  The title reports the spikes in the field and the
    time spent in it: sampling points divided by 30, presumably a 30 Hz sampling
    rate (TODO confirm).

    The file is written to ``save_path + '/' + <session_id> + '_cluster_' +
    str(cluster + 1) + name + str(field_id + 1) + '.png'``.
    """
    angles = np.linspace(0, 2*np.pi, 361)[:-1]
    figure = plt.figure()
    figure.set_size_inches(5, 5, forward=True)
    polar_ax = plot_utility.style_polar_plot(figure.add_subplot(1, 1, 1, polar=True))

    peak_ratio = max(hd_hist_cluster) / max(hd_hist_session)
    polar_ax.plot(angles, hd_hist_session * peak_ratio, color='black', linewidth=2, alpha=0.9)
    polar_ax.plot(angles, hd_hist_cluster, color=colors[field_id], linewidth=2)
    plt.tight_layout()

    spike_count = spatial_firing.number_of_spikes_in_fields[cluster][field_id]
    sampling_points = spatial_firing.time_spent_in_fields_sampling_points[cluster][field_id]
    seconds_in_field = round(sampling_points/30, 2)
    plt.title(str(spike_count) + ' spikes' + ' in ' + str(seconds_in_field) + ' seconds', y=1.08, fontsize=12)

    output_file = save_path + '/' + spatial_firing.session_id[cluster] + '_cluster_' + str(cluster + 1) + name + str(field_id + 1) + '.png'
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close()
def plot_polar_head_direction_histogram(spike_hist, hd_hist, id, save_path):
    """Save a polar plot of a cell's spike HD histogram over the session HD histogram.

    The spike histogram is drawn in red.  The session histogram is rescaled to
    the spike histogram's peak and drawn in black.  The title shows the peak of
    the spike histogram rounded to 2 decimals, followed by 'Hz'.  The file is
    written to ``save_path + '/' + id + '_hd_polar_.png'``.

    Args:
        spike_hist: 1-D histogram of spike head directions (one value per bin;
            360 bins expected, matching the fixed angle grid).
        hd_hist: 1-D session head-direction histogram, same length.
        id: identifier used in the file name (shadows the builtin; name kept
            for keyword callers).
        save_path: output directory.
    """
    print('I will make the polar HD plots now.')

    # Create the polar axes directly. The former cartesian add_subplot was
    # overlaid by plt.subplot(polar=True) and stays visible on matplotlib >= 3.6.
    hd_polar_fig = plt.figure()
    ax = hd_polar_fig.add_subplot(1, 1, 1, polar=True)
    ax = plot_utility.style_polar_plot(ax)
    hd_hist_cluster = spike_hist
    theta = np.linspace(0, 2 * np.pi, 361)[:-1]
    ax.plot(theta, hd_hist_cluster, color='red', linewidth=2)
    # Rescale the session histogram so both curves share the same peak.
    ax.plot(theta,
            hd_hist * (max(hd_hist_cluster) / max(hd_hist)),
            color='black',
            linewidth=2)
    max_firing_rate = np.max(hd_hist_cluster)
    ax.set_title(str(round(max_firing_rate, 2)) + 'Hz', y=1.08)
    hd_polar_fig.savefig(save_path + '/' + id + '_hd_polar_' + '.png', dpi=300)
    plt.close(hd_polar_fig)
# Example #8
# 0
def plot_polar_head_direction_histogram(hd_hist, spatial_firing, prm):
    """Save one polar HD plot per cluster under ``<output>/Figures/head_direction_plots_polar``.

    For every row of ``spatial_firing``:
    - the cluster's ``hd_spike_histogram`` is drawn in red;
    - ``hd_hist``, rescaled to the same peak, is drawn in black;
    - the title shows ``max_firing_rate_hd`` and ``hd_score``.

    Each file is named ``<session_id>_hd_polar_<cluster_id>.png``.

    Rows are read positionally, so each plot uses the data of the row that
    carries its ``cluster_id``.  The previous version looked rows up by the
    index label ``cluster_id - 1``, which fails or picks the wrong row whenever
    the index does not follow that convention.

    Args:
        hd_hist: 1-D session head-direction histogram (360 bins expected).
        spatial_firing: DataFrame with 'cluster_id', 'session_id',
            'hd_spike_histogram', 'max_firing_rate_hd' and 'hd_score' columns.
        prm: parameters object; ``prm.get_output_path()`` gives the output root.
    """
    print('I will make the polar HD plots now.')
    save_path = prm.get_output_path() + '/Figures/head_direction_plots_polar'
    os.makedirs(save_path, exist_ok=True)
    theta = np.linspace(0, 2 * np.pi, 361)[:-1]

    rows = zip(spatial_firing.cluster_id.values,
               spatial_firing.session_id.values,
               spatial_firing.hd_spike_histogram.values,
               spatial_firing.max_firing_rate_hd.values,
               spatial_firing.hd_score.values)
    for cluster_id, session_id, hd_hist_cluster, max_firing_rate_hd, hd_score in rows:
        hd_polar_fig = plt.figure()
        hd_polar_fig.set_size_inches(5, 5, forward=True)
        ax = hd_polar_fig.add_subplot(1, 1, 1, polar=True)
        ax = plot_utility.style_polar_plot(ax)
        ax.plot(theta, hd_hist_cluster, color='red', linewidth=2)
        # Rescale the session histogram so both curves share the same peak.
        ax.plot(theta, hd_hist * (max(hd_hist_cluster) / max(hd_hist)), color='black', linewidth=2)
        hd_polar_fig.tight_layout()
        ax.set_title('Head direction \n max fr: ' + str(round(max_firing_rate_hd, 2)) + ' Hz' + ', hd score: ' + str(round(hd_score, 2)) + '\n', y=1.08, fontsize=24)
        hd_polar_fig.savefig(save_path + '/' + session_id + '_hd_polar_' + str(cluster_id) + '.png', dpi=300, bbox_inches="tight")
        plt.close(hd_polar_fig)
# Example #9
# 0
def plot_bar_chart_for_cells_percentile_error_bar_polar(
        spatial_firing, path, animal, shuffle_type='occupancy'):
    """Save one polar plot per cell of the shuffled HD distribution.

    Each plot shows:
    - the shuffled mean in gray, with a band from ``mean - error_bar_5`` to
      ``mean + error_bar_95``;
    - one further histogram in black;
    - a radial limit of ``max_rate + 1.5``, where ``max_rate`` is the peak of
      ``hd_histogram_real_data_hz``.

    Files are written directly under the module-level ``analysis_path``.

    Args:
        spatial_firing: DataFrame with one row per cell.
        path, animal, shuffle_type: unused by this implementation.
    """
    # 20 tick labels, one every 18 degrees, with text only at the cardinal angles.
    x_labels = [
        "0", "", "", "", "", "90", "", "", "", "", "180", "", "", "", "",
        "270", "", "", "", ""
    ]
    # 2*pi coincides with 0, so drop it; tick count must equal label count.
    tick_positions = np.linspace(0, 2 * np.pi, len(x_labels) + 1)[:-1]

    for counter, (_, cell) in enumerate(spatial_firing.iterrows()):
        # Repeat the first bin at the end so each polar curve closes on itself.
        mean = np.append(cell['shuffled_means'], cell['shuffled_means'][0])
        percentile_95 = np.append(cell['error_bar_95'],
                                  cell['error_bar_95'][0])
        percentile_5 = np.append(cell['error_bar_5'], cell['error_bar_5'][0])
        shuffled_histograms_hz = cell['shuffled_histograms_hz']
        max_rate = np.round(cell.hd_histogram_real_data_hz.max(), 2)
        x_pos = np.linspace(0, 2 * np.pi, shuffled_histograms_hz.shape[1] + 1)

        # Fresh figure per cell instead of plt.cla()/plt.subplot, which left a
        # stray cartesian axes under the polar one.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, polar=True)
        ax = plot_utility.style_polar_plot(ax)
        plt.xticks(tick_positions, x_labels)
        ax.fill_between(x_pos,
                        mean - percentile_5,
                        percentile_95 + mean,
                        color='grey',
                        alpha=0.4)
        ax.plot(x_pos, mean, color='grey', linewidth=5, alpha=0.7)
        # NOTE(review): despite its name this is the FIRST SHUFFLE, not the
        # real data (compare hd_histogram_real_data_hz used for max_rate).
        # Confirm whether plotting a shuffle example is intended.
        observed_data = np.append(shuffled_histograms_hz[0],
                                  shuffled_histograms_hz[0][0])
        ax.plot(x_pos, observed_data, color='black', linewidth=5, alpha=0.9)
        plt.ylim(0, max_rate + 1.5)
        plt.title('\n' + str(max_rate) + ' Hz', fontsize=20, y=1.08)
        fig.subplots_adjust(top=0.85)
        fig.savefig(analysis_path + str(counter) + str(cell['session_id']) +
                    str(cell['cluster_id']) + '_percentile_polar_' +
                    str(cell.percentile_value) + '_polar.png')
        plt.close(fig)
def plot_bar_chart_for_cells_percentile_error_bar_polar(
        spatial_firing,
        sampling_rate_video,
        path,
        colors=None,
        number_of_bins=20,
        smooth=False,
        add_stars=True):
    """Save one polar plot per field comparing its HD firing rate to shuffled data.

    For each row of ``spatial_firing``, the observed rate histogram is
    ``histogram(hd_in_field_spikes) * sampling_rate_video / time_spent_in_bins``.
    It is drawn over the shuffled mean, with a band from
    ``mean - error_bar_5`` to ``mean + error_bar_95``.  Bins with
    ``p_values_corrected_bars_bh < 0.05`` are starred when ``add_stars``.
    Files are written to ``path + <counter><session_id><cluster_id><field_id>polar.png``.

    Args:
        spatial_firing: DataFrame with one row per field.
        sampling_rate_video: multiplier converting counts per bin to a rate
            (presumably video frames per second; confirm).
        path: output path prefix.
        colors: optional sequence indexed by the row's index label.  When None,
            a built-in palette is indexed by ``field_id``.
        number_of_bins: number of histogram bins for the observed data.
        smooth: if True, use
            ``PostSorting.open_field_head_direction.get_hd_histogram`` instead
            of ``np.histogram``.
        add_stars: whether to mark significant bins.
    """
    # Built-in palette: green, orange, cyan, pink, purple, grey, dark red.
    # Kept local so ``colors`` is never rebound.  Previously it was rebound on
    # the first row, so later rows indexed the palette by row label, not field_id.
    default_colors = [[0, 1, 0], [1, 0.6, 0.3], [0, 1, 1], [1, 0, 1],
                      [0.7, 0.3, 1], [0.6, 0.5, 0.4], [0.6, 0, 0]]
    # 20 tick labels every 18 degrees; 2*pi coincides with 0 so it is dropped,
    # keeping the tick count equal to the label count.
    x_labels = [
        "0", "", "", "", "", "90", "", "", "", "", "180", "", "", "", "",
        "270", "", "", "", ""
    ]
    tick_positions = np.linspace(0, 2 * np.pi, len(x_labels) + 1)[:-1]

    for counter, (index, cell) in enumerate(spatial_firing.iterrows()):
        if colors is None:
            # Cycle the palette if a cell has more fields than colours.
            observed_data_color = default_colors[int(cell.field_id) % len(default_colors)]
        else:
            observed_data_color = colors[index]

        # Repeat the first bin at the end so each polar curve closes on itself.
        mean = np.append(cell['shuffled_means'], cell['shuffled_means'][0])
        percentile_95 = np.append(cell['error_bar_95'],
                                  cell['error_bar_95'][0])
        percentile_5 = np.append(cell['error_bar_5'], cell['error_bar_5'][0])
        field_spikes_hd = cell['hd_in_field_spikes']
        time_spent_in_bins = cell['time_spent_in_bins']
        # NOTE(review): without ``range=``, np.histogram spans only the min..max
        # of these spikes, not the full circle, so the bins may not line up
        # with ``time_spent_in_bins``.  Confirm the units (degrees?) and
        # consider range=(0, 360).
        real_data_hz = np.histogram(field_spikes_hd, bins=number_of_bins)[0]
        if smooth:
            # NOTE(review): the bin count of this histogram must match
            # time_spent_in_bins and the shuffle arrays; confirm.
            real_data_hz = PostSorting.open_field_head_direction.get_hd_histogram(
                field_spikes_hd)
        real_data_hz = real_data_hz * sampling_rate_video / time_spent_in_bins

        max_rate = np.round(real_data_hz[~np.isnan(real_data_hz)].max(), 2)
        # linspace requires an integer sample count; the extra point closes the
        # circle at 2*pi.  The former "+ 1.5" raises TypeError in numpy.
        x_pos = np.linspace(0, 2 * np.pi, real_data_hz.shape[0] + 1)

        significant_bins_to_mark = x_pos[np.where(
            cell.p_values_corrected_bars_bh < 0.05)[0]]
        y_value_markers = [max_rate + 0.5] * len(significant_bins_to_mark)

        # Fresh figure per field instead of plt.cla()/plt.subplot, which left
        # a stray cartesian axes under the polar one.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, polar=True)
        ax = plot_utility.style_polar_plot(ax)
        plt.xticks(tick_positions, x_labels)
        ax.fill_between(x_pos,
                        mean - percentile_5,
                        percentile_95 + mean,
                        color='grey',
                        alpha=0.4)
        ax.plot(x_pos, mean, color='grey', linewidth=5, alpha=0.7)
        observed_data = np.append(real_data_hz, real_data_hz[0])
        ax.plot(x_pos, observed_data, color=observed_data_color, linewidth=5)
        plt.title('\n' + str(max_rate) + ' Hz', fontsize=20, y=1.08)
        if add_stars and len(significant_bins_to_mark) > 0:
            ax.scatter(significant_bins_to_mark,
                       y_value_markers,
                       c='red',
                       marker='*',
                       zorder=3,
                       s=100)
        fig.subplots_adjust(top=0.85)
        fig.savefig(path + str(counter) + str(cell['session_id']) +
                    str(cell['cluster_id']) + str(cell['field_id']) +
                    'polar.png')
        # Close only this figure.  The former trailing plt.cla() re-created a
        # figure after every close, leaking one.
        plt.close(fig)