def plot_all_waterfalls(df, savepath, scale='blank_subtracted_NLL'):
    """Save one full waterfall figure for every (area, Cre line) with sessions.

    Each combination returned by ``dataset_params()`` is looked up with
    ``get_sessions``; combinations without sessions are skipped.  For the
    rest, responses are pooled, direction-centered, restricted to cells with
    ``p_all < SIG_THRESH``, flattened across contrasts, re-ordered and handed
    to ``plot_full_waterfall``.

    Args:
        df: session table passed through to ``get_sessions`` and
            ``get_cell_order_direction_sorted``.
        savepath: path prefix passed to the helpers and used for the figure.
        scale: response scale forwarded to ``pool_sessions`` and used in the
            output file name by ``plot_full_waterfall``.
    """
    areas, cres = dataset_params()
    for area in areas:
        for cre in cres:
            sessions = get_sessions(df, area, cre)
            if len(sessions) == 0:
                continue

            # Row order comes from a separate helper (sort by direction using
            # event magnitude); the values displayed use ``scale``.
            cell_order = get_cell_order_direction_sorted(
                df, area, cre, savepath)

            # Values to display (response significance).
            pooled_resp, _blank, pvals = pool_sessions(
                sessions, area + '_' + cre, savepath, scale=scale)
            centered = center_direction_zero(pooled_resp)
            significant = centered[pvals < SIG_THRESH]

            waterfall = concatenate_contrasts(significant)
            waterfall = move_all_negative_to_bottom(waterfall[cell_order])

            figure_name = shorthand(area) + '_' + shorthand(cre) + '_full'
            plot_full_waterfall(waterfall, cre, figure_name, scale, savepath)
def plot_pooled_mat(pooled_mat, area, cre, pass_str, savepath):
    """Save a heatmap of a pooled response matrix as an SVG file.

    The matrix is transposed for display, so its first axis is drawn along x
    (labelled as direction) and its second axis along y (labelled as
    contrast).  Colors span a symmetric range of +/-80 and invalid (masked or
    NaN) entries are drawn light grey.

    Args:
        pooled_mat: 2-D array to display; presumably directions x contrasts
            in the order of ``grating_params()`` -- confirm with the caller.
        area: area name; only ``shorthand(area)`` is used, in the file name.
        cre: Cre line name, used for the title text/color and the file name.
        pass_str: tag inserted into the output file name.
        savepath: path prefix the file name is appended to.
    """
    max_resp = 80.0
    cre_colors = get_cre_colors()
    x_tick_labels = ['-135', '-90', '-45', '0', '45', '90', '135', '180']
    directions, contrasts = grating_params()

    plt.figure(figsize=(4.2, 4))
    ax = plt.subplot(111)

    # Pass the colormap *object* to imshow.  The original configured the
    # "bad" color on an object obtained from get_cmap but then gave imshow the
    # colormap name, which only works while get_cmap hands back the globally
    # registered instance; on releases where it returns a copy the grey NaN
    # color was silently dropped.  plt.get_cmap is used because
    # matplotlib.cm.get_cmap has been deprecated and removed.
    current_cmap = plt.get_cmap('RdBu_r')
    current_cmap.set_bad(color=[0.8, 0.8, 0.8])

    im = ax.imshow(pooled_mat.T,
                   vmin=-max_resp,
                   vmax=max_resp,
                   interpolation='nearest',
                   aspect='auto',
                   cmap=current_cmap,
                   origin='lower')

    ax.set_xlabel('Direction (deg)', fontsize=14)
    ax.set_ylabel('Contrast (%)', fontsize=14)
    ax.set_yticks(np.arange(len(contrasts)))
    ax.set_yticklabels([str(int(100 * x)) for x in contrasts], fontsize=10)
    ax.set_xticks(np.arange(len(directions)))
    ax.set_xticklabels(x_tick_labels, fontsize=10)
    ax.set_title(shorthand(cre) + ' population',
                 fontsize=16,
                 color=cre_colors[cre])

    cbar = plt.colorbar(
        im,
        ax=ax,
        ticks=[-max_resp, -max_resp / 2.0, 0.0, max_resp / 2.0, max_resp])
    cbar.set_label('Event magnitude per second (%), blank subtracted',
                   rotation=270,
                   labelpad=15.0)

    plt.savefig(savepath + shorthand(area) + '_' + shorthand(cre) + '_' +
                pass_str + '_summed_tuning.svg',
                format='svg')
    plt.close()
def plot_SbC_stats(df, savepath):
    """Bar-plot, per (area, Cre line), the percentage of cells below the SbC threshold.

    For every combination with at least one session, ``test_SbC`` is called
    per session; its returned p-values are counted and compared with 0.05.
    One bar per combination is drawn (labelled with the Cre shorthand and
    annotated with the cell count), together with a dashed line at the
    threshold percentage, and the figure is saved as ``SbC_stats.svg``.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix passed to ``test_SbC`` and used for the figure.
    """
    SbC_THRESH = 0.05
    cre_colors = get_cre_colors()
    _directions, _contrasts = grating_params()  # result is not used here
    areas, cres = dataset_params()

    # One (label, color, percent, n_cells) record per group with sessions.
    bars = []
    for area in areas:
        for cre in cres:
            sessions = get_sessions(df, area, cre)
            if len(sessions) == 0:
                continue

            n_cells = 0
            n_below = 0
            for session in sessions:
                pvals = test_SbC(session, savepath)
                n_cells += len(pvals)
                n_below += (pvals < SbC_THRESH).sum()

            bars.append((shorthand(cre), cre_colors[cre],
                         100.0 * n_below / n_cells, n_cells))

    labels = [bar[0] for bar in bars]

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    for x, (_label, color, percent, n_cells) in enumerate(bars):
        ax.bar(x, percent, color=color)
        # Keep the sample-size text above very short bars.
        ax.text(x,
                max(percent, 5) + 1,
                '(' + str(n_cells) + ')',
                horizontalalignment='center',
                fontsize=8)

    thresh_percent = 100 * SbC_THRESH
    ax.plot([-1, len(labels)], [thresh_percent, thresh_percent],
            '--k',
            linewidth=2.0)

    ax.set_ylim(0, 30)
    ax.set_xlim(-1, 14)
    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, fontsize=10, rotation=45)
    ax.set_ylabel('% Suppressed by Contrast', fontsize=14)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    plt.savefig(savepath + 'SbC_stats.svg', format='svg')
    plt.close()
def plot_direction_vector_sum_by_contrast(df, savepath):
    """Draw, per (area, Cre line), the resultant preferred-direction vector at each contrast.

    Responses are pooled on the ``'event'`` scale and restricted to cells
    with ``p_all < SIG_THRESH``.  The per-contrast preferred-direction
    distribution is normalized so every contrast column sums to one, a vector
    sum is computed per column, and the result is passed to
    ``radial_direction_figure`` together with the bounds returned by
    ``uniform_direction_vector_sum`` for the same number of cells.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix passed to the helpers and used for the figures.
    """
    areas, cres = dataset_params()
    directions, contrasts = grating_params()

    for area in areas:
        for cre in cres:
            sessions = get_sessions(df, area, cre)
            if len(sessions) == 0:
                continue

            pooled, _blank, pvals = pool_sessions(sessions,
                                                  area + '_' + cre,
                                                  savepath,
                                                  scale='event')
            sig_resp = pooled[pvals < SIG_THRESH]

            dir_dist = calc_pref_direction_dist_by_contrast(sig_resp)
            dir_dist = dir_dist / np.sum(dir_dist, axis=0, keepdims=True)

            vector_sums = [
                calc_vector_sum(dir_dist[:, i_con])
                for i_con in range(len(contrasts))
            ]
            resultant_mag = [mag for mag, _theta in vector_sums]
            resultant_theta = [theta for _mag, theta in vector_sums]

            # Bounds for the uniform case (original note: bootstrap CI for
            # the distribution at 5% contrast).
            num_cells = len(sig_resp)
            uniform_LB, uniform_UB = uniform_direction_vector_sum(num_cells)

            # The outline coordinates are all zeros: only the resultant
            # vectors and the CI ring are shown in this figure.
            outline_x = np.zeros((len(directions), ))
            outline_y = np.zeros((len(directions), ))
            figure_name = shorthand(area) + '_' + shorthand(cre) + '_combined'
            radial_direction_figure(outline_x, outline_y, resultant_mag,
                                    resultant_theta, uniform_LB, uniform_UB,
                                    cre, num_cells, figure_name, savepath)
def plot_contrast_CoM(df, savepath, curve='cdf'):
    """Plot the contrast center-of-mass distribution for every Cre line in VISp.

    For each Cre line from ``dataset_params()`` the VISp sessions are pooled
    on the ``'event'`` scale and cells with ``p_all < SIG_THRESH`` are kept.
    ``center_of_mass_for_list`` turns the per-line responses into the metric
    handed to ``plot_cdf``; the x axis is labelled at the log of the contrast
    percentages.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix passed to the helpers and to ``plot_cdf``.
        curve: only used as a suffix of the saved figure name.
    """
    areas, cres = dataset_params()
    cre_colors = get_cre_colors()
    area = 'VISp'

    pooled_resp, colors, cre_labels = [], [], []
    for cre in cres:
        sessions = get_sessions(df, area, cre)
        resp, _blank, pvals = pool_sessions(sessions,
                                            area + '_' + cre,
                                            savepath,
                                            scale='event')
        pooled_resp.append(resp[pvals < SIG_THRESH])
        colors.append(cre_colors[cre])
        cre_labels.append(shorthand(cre))
    alphas = [1.0] * len(cre_labels)

    center_of_mass = center_of_mass_for_list(pooled_resp)

    contrasts = [5, 10, 20, 40, 60, 80]
    tick_labels = [str(contrast) for contrast in contrasts]
    plot_cdf(metric=center_of_mass,
             metric_labels=cre_labels,
             colors=colors,
             alphas=alphas,
             hist_range=(np.log(5.0), np.log(70.0)),
             hist_bins=200,
             x_label='Contrast (CoM)',
             x_ticks=np.log(contrasts),
             x_tick_labels=tick_labels,
             save_name=shorthand(area) + '_contrast_' + curve,
             savepath=savepath,
             do_legend=True)
def model_GLM(df, savepath):
    """Fit an L1-regularized Poisson GLM per Cre line in VISp and plot the results.

    For each of three hard-coded Cre lines that has sessions: the
    regularization strength is obtained from ``K_session_cross_validation``,
    the pooled design matrix and response are built with
    ``construct_pooled_Xy``, a Poisson GLM is fit with elastic-net
    regularization (``L1_wt=1.0``, then refit), and the prediction,
    parameter confidence intervals and parameter heatmaps are plotted.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix passed to the helpers and plotting functions.
    """
    glm_areas = ['VISp']
    glm_cres = ['Vip-IRES-Cre', 'Sst-IRES-Cre', 'Cux2-CreERT2']

    for area in glm_areas:
        for cre in glm_cres:
            session_IDs = get_sessions(df, area, cre)
            savename = shorthand(area) + '_' + shorthand(cre)
            if len(session_IDs) == 0:
                continue

            lamb = K_session_cross_validation(session_IDs, savename, savepath)
            print(shorthand(cre) + ' lambda used: ' + str(lamb))

            X, y = construct_pooled_Xy(session_IDs, savepath)
            poisson_model = sm.GLM(y, X, family=sm.families.Poisson())
            res = poisson_model.fit_regularized(
                method='elastic_net',
                alpha=lamb,
                maxiter=200,
                L1_wt=1.0,  # 1.0: all L1, 0.0: all L2
                refit=True)

            model_params = np.array(res.params)

            # Prediction: exponential of the linear predictor (row-wise
            # weighted sum of the design matrix columns).
            weights_row = model_params.reshape(1, len(model_params))
            linear_predictor = np.sum(weights_row * X, axis=1)
            y_hat = np.exp(linear_predictor)

            plot_y(X, y_hat, area, cre, 'predicted', savepath)
            plot_param_CI(res, area, cre, savepath)
            plot_param_heatmaps(model_params, cre, savename, savepath)
def plot_DSI_distribution(df, savepath, curve='cdf'):
    """Plot the DSI distribution for every Cre line in VISp.

    For each Cre line from ``dataset_params()`` the VISp sessions are pooled
    on the ``'event'`` scale, cells with ``p_all < SIG_THRESH`` are kept and
    ``calc_DSI`` is applied to them.  The per-line results are handed to
    ``plot_cdf`` over the range 0..1.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix passed to the helpers and to ``plot_cdf``.
        curve: only used as a suffix of the saved figure name.
    """
    areas, cres = dataset_params()
    cre_colors = get_cre_colors()
    area = 'VISp'

    pooled_DSI, colors, cre_labels = [], [], []
    for cre in cres:
        sessions = get_sessions(df, area, cre)
        resp, _blank, pvals = pool_sessions(sessions,
                                            area + '_' + cre,
                                            savepath,
                                            scale='event')
        pooled_DSI.append(calc_DSI(resp[pvals < SIG_THRESH]))
        colors.append(cre_colors[cre])
        cre_labels.append(shorthand(cre))
    alphas = [1.0] * len(cre_labels)

    xticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    tick_labels = [str(tick) for tick in xticks]
    plot_cdf(metric=pooled_DSI,
             metric_labels=cre_labels,
             colors=colors,
             alphas=alphas,
             hist_range=(0.0, 1.0),
             hist_bins=200,
             x_label='DSI',
             x_ticks=xticks,
             x_tick_labels=tick_labels,
             save_name='V1_DSI_' + curve,
             savepath=savepath,
             do_legend=False)
def plot_from_curve_dict(curve_dict, pass_type, area, cre, num_sessions,
                         savepath):
    """Produce the pooled heatmaps and run/stationary tuning plots for one group.

    Two ``plot_pooled_mat`` heatmaps are drawn (running and stationary; each
    is the NaN-ignoring mean over the first axis of the ``'conXdir'``
    entry), followed by five ``make_errorbars_plot_running`` figures:
    contrast, direction (low), aligned direction (low), and the same two
    direction plots for ``'high'`` drawn in inset style.

    Args:
        curve_dict: container queried through ``get_from_curve_dict``.
        pass_type: prefix of the ``'<pass_type>_run'`` / ``'<pass_type>_stat'``
            keys; also part of the output file names.
        area: area name forwarded to the plotting helpers.
        cre: Cre line name forwarded to the plotting helpers.
        num_sessions: session count forwarded to the plotting helper.
        savepath: path prefix for all figures.
    """
    run_key = pass_type + '_run'
    stat_key = pass_type + '_stat'

    def fetch(curve_type, state_key, level=''):
        return get_from_curve_dict(curve_dict, curve_type, state_key, level)

    for state_key in (run_key, stat_key):
        plot_pooled_mat(np.nanmean(fetch('conXdir', state_key), axis=0), area,
                        cre, state_key, savepath)

    run_contrast = fetch('contrast', run_key)
    stat_contrast = fetch('contrast', stat_key)

    run_dir_low = fetch('direction', run_key, 'low')
    stat_dir_low = fetch('direction', stat_key, 'low')
    run_aligned_low = fetch('aligned', run_key, 'low')
    stat_aligned_low = fetch('aligned', stat_key, 'low')

    run_dir_high = fetch('direction', run_key, 'high')
    stat_dir_high = fetch('direction', stat_key, 'high')
    run_aligned_high = fetch('aligned', run_key, 'high')
    stat_aligned_high = fetch('aligned', stat_key, 'high')

    # The first x position of every axis carries the 'blank' label.
    x_con = np.log([0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8])
    contrast_labels = ['blank', '5', '10', '20', '40', '60', '80']

    x_dir = [-45, 0, 45, 90, 135, 180, 225, 270, 315]
    direction_labels = [
        'blank', '-135', '-90', '-45', '0', '45', '90', '135', '180'
    ]
    pref_labels = list(direction_labels)

    inset_x_dir = [-45, 45, 135, 225, 315]
    inset_direction_labels = ['blank', '-90', '0', '90', '180']
    inset_x_pref = list(inset_x_dir)
    inset_pref_labels = list(inset_direction_labels)

    base_name = shorthand(area) + '_' + shorthand(cre) + '_' + pass_type
    low_name = base_name + '_low'
    high_name = base_name + '_high'

    make_errorbars_plot_running(run_contrast, stat_contrast, x_con,
                                contrast_labels, 'Contrast', area, cre,
                                num_sessions, base_name, 'contrast', savepath)

    make_errorbars_plot_running(run_dir_low, stat_dir_low, x_dir,
                                direction_labels, 'Direction', area, cre,
                                num_sessions, low_name, 'direction', savepath)

    make_errorbars_plot_running(run_aligned_low, stat_aligned_low, x_dir,
                                pref_labels, 'Direction - Peak', area, cre,
                                num_sessions, low_name, 'preferred_direction',
                                savepath)

    make_errorbars_plot_running(run_dir_high,
                                stat_dir_high,
                                x_dir,
                                inset_direction_labels,
                                'Direction',
                                area,
                                cre,
                                num_sessions,
                                high_name,
                                'direction',
                                savepath,
                                as_inset=True,
                                inset_x_ticks=inset_x_dir)

    make_errorbars_plot_running(run_aligned_high,
                                stat_aligned_high,
                                x_dir,
                                inset_pref_labels,
                                'Direction - Peak',
                                area,
                                cre,
                                num_sessions,
                                high_name,
                                'preferred_direction',
                                savepath,
                                as_inset=True,
                                inset_x_ticks=inset_x_pref)
def make_errorbars_plot_running(run_responses,
                                stat_responses,
                                x_values,
                                x_tick_labels,
                                x_label,
                                area,
                                cre,
                                num_sessions,
                                savename,
                                plot_type,
                                savepath,
                                as_inset=False,
                                inset_x_ticks=None):
    """Plot running vs. stationary tuning curves with error bars and save as SVG.

    The running curve is drawn opaque and the stationary curve at half alpha,
    both in the Cre line's color.  The first x position is treated as the
    blank condition: it gets its own error bar plus a dotted horizontal
    reference line across the plot.

    Args:
        run_responses: 2-D array (cells x conditions) for the running state.
        stat_responses: matching array for the stationary state.
        x_values: x position of every condition (first entry = blank).
        x_tick_labels: tick labels for the x axis.
        x_label: axis label; a unit suffix is appended for 'Contrast' and for
            any label containing 'Direction'.
        area: not used inside this function.
        cre: Cre line name (color lookup and legend placement).
        num_sessions: shown in the 'n = ...' annotation of contrast plots.
        savename: file name stem.
        plot_type: file name suffix; ``'contrast'`` adds the annotations.
        savepath: path prefix for the figure.
        as_inset: use the larger fonts / fewer ticks of the inset style and
            omit the y label.
        inset_x_ticks: x tick locations used when ``as_inset`` is true.
    """
    color = get_cre_colors()[cre]

    if as_inset:
        num_y_ticks, x_tick_loc = 3, inset_x_ticks
        label_font_size, tick_font_size = 22, 17
    else:
        num_y_ticks, x_tick_loc = 5, x_values
        label_font_size, tick_font_size = 14, 10

    min_y = 0.0
    max_y = 0.04
    y_ticks = np.round(np.linspace(min_y, max_y, num=num_y_ticks), decimals=3)

    (num_cells, _num_conditions) = np.shape(run_responses)

    run_means, run_errors = compute_error_curve(run_responses)
    stat_means, stat_errors = compute_error_curve(stat_responses)

    # Pad the x range by half of the first step on either side.
    half_step = 0.5 * (x_values[1] - x_values[0])
    min_x = x_values[0] - half_step
    max_x = x_values[-1] + half_step

    plt.figure(figsize=(4.2, 4))
    ax = plt.subplot(111)

    # (means, errors, extra style) for running, then stationary.
    series = ((run_means, run_errors, {}),
              (stat_means, stat_errors, {'alpha': 0.5}))
    bar_style = dict(color=color,
                     linewidth=3,
                     capsize=5,
                     elinewidth=2,
                     markeredgewidth=2)

    # Dotted reference lines at the blank response.
    for means, _errors, extra in series:
        ax.plot([x_values[0], x_values[-1]], [means[0], means[0]],
                color=color,
                linestyle='dotted',
                linewidth=2.0,
                **extra)

    # Blank condition drawn on its own so it is not joined to the curve.
    for means, errors, extra in series:
        ax.errorbar([x_values[0]], [means[0]],
                    yerr=errors[:, 0].reshape(2, 1),
                    **dict(bar_style, **extra))

    # Remaining conditions form the tuning curve.
    for means, errors, extra in series:
        ax.errorbar(x_values[1:],
                    means[1:],
                    yerr=errors[:, 1:],
                    **dict(bar_style, **extra))

    if x_label == 'Contrast':
        x_label = 'Contrast (%)'
    elif 'Direction' in x_label:
        x_label = x_label + ' (deg)'

    if not as_inset:
        ax.set_ylabel('Mean event magnitude (a.u.)', fontsize=label_font_size)
    ax.set_xlabel(x_label, fontsize=label_font_size)

    ax.set_xticks(x_tick_loc)
    ax.set_xticklabels(x_tick_labels, fontsize=tick_font_size)
    ax.set_xlim(min_x, max_x)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([str(tick) for tick in y_ticks],
                       fontsize=tick_font_size)
    ax.set_ylim(min_y, max_y)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    if plot_type == 'contrast':
        ax.text(x_values[4],
                0.9 * max_y,
                'n = ' + str(num_cells) + ' (' + str(num_sessions) + ')',
                fontsize=10,
                horizontalalignment='center')

        # The run/stat legend goes to the left for Sst, to the right otherwise.
        legend_x = x_values[0] if shorthand(cre) == 'Sst' else x_values[-2]
        ax.text(legend_x, 0.024, 'run', fontsize=14, color=color)
        ax.text(legend_x, 0.020, 'stat', fontsize=14, color=color, alpha=0.5)

    ax.set_aspect('auto')
    plt.tight_layout()
    plt.savefig(savepath + savename + '_' + plot_type + '_errorbars.svg',
                format='svg')
    plt.close()
def plot_single_cell_example(df,
                             savepath,
                             cre,
                             example_cell,
                             example_session_idx=0):
    """Save a three-panel tuning figure for one example cell of a VISp session.

    Panels: the cell's full direction-by-contrast response heatmap, its
    contrast tuning at its peak direction, and its direction tuning at its
    peak contrast (both with SEM error bars).  Only cells with
    ``p_all < SIG_THRESH`` are kept, so ``example_cell`` indexes into that
    filtered set.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix for loading session data and saving the figure.
        cre: Cre line name.
        example_cell: index of the cell among the significant cells.
        example_session_idx: index into the session list for this Cre line.
    """
    directions, contrasts = grating_params()

    session_ID = get_sessions(df, 'VISp', cre)[example_session_idx]
    mse = load_mean_sweep_events(savepath, session_ID)
    sweep_table = load_sweep_table(savepath, session_ID)

    mean_resp, __ = compute_mean_condition_responses(sweep_table, mse)
    sem_resp, __ = compute_SEM_condition_responses(sweep_table, mse)
    p_all = chi_square_all_conditions(sweep_table, mse, session_ID, savepath)

    is_sig = p_all < SIG_THRESH
    sig_resp = mean_resp[is_sig]
    sig_SEM = sem_resp[is_sig]

    # Shift zero to the center of the direction axis.
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    sig_resp = center_direction_zero(sig_resp)
    sig_SEM = center_direction_zero(sig_SEM)

    contrast_labels = [str(int(100 * x)) for x in contrasts]
    direction_labels = [str(x) for x in directions]
    direction_pos = np.arange(len(directions))

    def hide_frame(axis):
        axis.spines['right'].set_visible(False)
        axis.spines['top'].set_visible(False)

    plt.figure(figsize=(6, 4))

    # Full direction-by-contrast response heatmap.
    ax = plt.subplot2grid((5, 5), (0, 0), rowspan=5, colspan=2)
    ax.imshow(sig_resp[example_cell],
              vmin=0.0,
              interpolation='nearest',
              aspect='auto',
              cmap='plasma')
    ax.set_ylabel('Direction (deg)', fontsize=14)
    ax.set_xlabel('Contrast (%)', fontsize=14)
    ax.set_xticks(np.arange(len(contrasts)))
    ax.set_xticklabels(contrast_labels, fontsize=10)
    ax.set_yticks(direction_pos)
    ax.set_yticklabels(direction_labels, fontsize=10)

    peak_dir_idx, peak_con_idx = get_peak_conditions(sig_resp)
    cell_peak_dir = peak_dir_idx[example_cell]
    cell_peak_con = peak_con_idx[example_cell]

    # Contrast tuning at the peak direction.
    ax = plt.subplot2grid((5, 5), (0, 3), rowspan=2, colspan=2)
    ax.errorbar(np.log(contrasts), sig_resp[example_cell, cell_peak_dir, :],
                sig_SEM[example_cell, cell_peak_dir, :])
    ax.set_xticks(np.log(contrasts))
    ax.set_xticklabels(contrast_labels, fontsize=10)
    ax.set_xlabel('Contrast (%)', fontsize=14)
    ax.set_ylabel('Response', fontsize=14)
    hide_frame(ax)

    # Direction tuning at the peak contrast.
    ax = plt.subplot2grid((5, 5), (3, 3), rowspan=2, colspan=2)
    ax.errorbar(direction_pos, sig_resp[example_cell, :, cell_peak_con],
                sig_SEM[example_cell, :, cell_peak_con])
    ax.set_xlim(-0.07, 7.07)
    ax.set_xticks(direction_pos)
    ax.set_xticklabels(direction_labels, fontsize=10)
    ax.set_xlabel('Direction (deg)', fontsize=14)
    ax.set_ylabel('Response', fontsize=14)
    hide_frame(ax)

    plt.tight_layout(w_pad=-5.5, h_pad=0.1)
    plt.savefig(savepath + shorthand(cre) + '_example_cell.svg', format='svg')
    plt.close()
def radial_direction_figure(x_coor,
                            y_coor,
                            resultant_mag,
                            resultant_theta,
                            CI_LB,
                            CI_UB,
                            cre,
                            num_cells,
                            savename,
                            savepath,
                            max_radius=0.75):
    """Save a radial plot of direction tuning with per-contrast resultant arrows.

    All data coordinates are divided by ``max_radius`` so that ``max_radius``
    maps onto the unit circle drawn as the outer ring.  A grey annulus
    between ``CI_LB`` and ``CI_UB`` is drawn behind the spokes, and one arrow
    per entry of ``resultant_mag`` / ``resultant_theta`` is drawn from the
    origin.

    Args:
        x_coor, y_coor: coordinates of the outline curve drawn in the Cre color.
        resultant_mag: resultant vector length per contrast.
        resultant_theta: resultant vector angle per contrast, in radians; the
            sign is flipped when drawing.
        CI_LB, CI_UB: inner and outer radius of the shaded annulus.
        cre: Cre line name (color lookup and corner label).
        num_cells: number shown in the '(n=...)' label.
        savename: file name stem.
        savepath: path prefix for the figure.
        max_radius: data radius that is mapped onto the outer ring.
    """
    color = get_cre_colors()[cre]
    directions, contrasts = grating_params()

    # Upper half of the unit circle; the lower half is its mirror image.
    unit_circle_x = np.linspace(-1.0, 1.0, 100)
    unit_circle_y = np.sqrt(1.0 - unit_circle_x**2)

    def draw_ring(scale, style, linewidth):
        ax.plot(scale * unit_circle_x, scale * unit_circle_y, style,
                linewidth=linewidth)
        ax.plot(scale * unit_circle_x, -scale * unit_circle_y, style,
                linewidth=linewidth)

    plt.figure(figsize=(4, 4))
    ax = plt.subplot(111)

    # Grey disc up to the upper bound with a white disc on top of it up to
    # the lower bound leaves a grey annulus.
    outer_CI = Circle((0, 0), CI_UB / max_radius, facecolor=[0.6, 0.6, 0.6])
    inner_CI = Circle((0, 0), CI_LB / max_radius, facecolor=[1.0, 1.0, 1.0])
    ax.add_patch(outer_CI)
    ax.add_patch(inner_CI)

    # spokes
    for direction in directions:
        angle = np.pi * direction / 180.0
        ax.plot([0, np.cos(angle)], [0, np.sin(angle)], 'k--', linewidth=1.0)

    # outer ring, dashed guide rings at data radii 0.25 and 0.5, and the
    # outer ring once more (kept from the original so the output is unchanged)
    draw_ring(1.0, 'k', 2.0)
    draw_ring(0.25 / max_radius, '--k', 1.0)
    draw_ring(0.5 / max_radius, '--k', 1.0)
    draw_ring(1.0, 'k', 2.0)

    # center
    draw_ring(1.0 / 200.0, 'k', 2.0)

    ax.plot(np.array(x_coor) / max_radius,
            np.array(y_coor) / max_radius,
            color=color,
            linewidth=2.0)

    # Arrows are drawn in reverse input order so the first entry ends up on
    # top; magnitudes and angles are paired directly rather than indexed via
    # len(contrasts), which is equivalent whenever there is one angle per
    # contrast.
    contrast_colors = get_contrast_colors()
    reversed_pairs = zip(resultant_mag[::-1], resultant_theta[::-1])
    for i, (mag, theta) in enumerate(reversed_pairs):
        ax.arrow(0,
                 0,
                 mag * np.cos(-theta) / max_radius,
                 mag * np.sin(-theta) / max_radius,
                 color=contrast_colors[i],
                 linewidth=2.0,
                 head_width=0.03)

    # labels
    ax.text(0, 1.02, 'U', fontsize=12,
            horizontalalignment='center', verticalalignment='bottom')
    ax.text(0, -1.02, 'D', fontsize=12,
            horizontalalignment='center', verticalalignment='top')
    ax.text(1.02, 0, 'T', fontsize=12,
            verticalalignment='center', horizontalalignment='left')
    ax.text(-1.02, 0, 'N', fontsize=12,
            verticalalignment='center', horizontalalignment='right')

    ax.text(-1, 0.99, shorthand(cre), fontsize=16, horizontalalignment='left')
    ax.text(0.73, 0.99, '(n=' + str(num_cells) + ')', fontsize=10,
            horizontalalignment='left')

    ax.text(.73, -.73, '45', fontsize=10,
            horizontalalignment='left', verticalalignment='top')
    ax.text(-.78, -.75, '135', fontsize=10,
            horizontalalignment='right', verticalalignment='top')
    ax.text(-.73, .73, '-135', fontsize=10,
            verticalalignment='bottom', horizontalalignment='right')
    ax.text(.73, .73, '-45', fontsize=10,
            verticalalignment='bottom', horizontalalignment='left')

    # Raw string: '\c' is not a valid escape sequence, so the non-raw
    # literal triggers an invalid-escape warning; the text is unchanged.
    degree_sign = r'$^\circ$'
    ax.text(.81, -.71, degree_sign, fontsize=18,
            horizontalalignment='left', verticalalignment='top')
    ax.text(-.69, -.73, degree_sign, fontsize=18,
            horizontalalignment='right', verticalalignment='top')
    ax.text(-.64, .69, degree_sign, fontsize=18,
            verticalalignment='bottom', horizontalalignment='right')
    ax.text(.85, .69, degree_sign, fontsize=18,
            verticalalignment='bottom', horizontalalignment='left')

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    plt.axis('equal')
    plt.axis('off')

    plt.savefig(savepath + savename + '_radial_direction_tuning.svg',
                format='svg')
    plt.close()
def plot_full_waterfall(cells_by_condition,
                        cre,
                        save_name,
                        scale,
                        savepath,
                        do_colorbar=False):
    """Save a cells-by-conditions heatmap with contrast blocks separated by lines.

    Columns are treated as consecutive blocks of ``len(directions)``
    conditions, one block per contrast; a vertical line is drawn between
    neighbouring blocks and each block is labelled with its contrast
    percentage.  Colors are clipped to +/-4.

    Args:
        cells_by_condition: 2-D array, one row per cell.
        cre: Cre line name (y label text and color).
        save_name: file name stem.
        scale: appended to the file name.
        savepath: path prefix for the figure.
        do_colorbar: add a horizontal colorbar whose ticks come from
            ``percentile_to_NLL`` and are labelled in percent.
    """
    resp_min, resp_max = -4.0, 4.0

    directions, contrasts = grating_params()
    num_directions = len(directions)
    num_contrasts = len(contrasts)
    (num_cells, _num_conditions) = np.shape(cells_by_condition)
    last_row = num_cells - 1

    cre_colors = get_cre_colors()

    plt.figure(figsize=(10, 4))
    ax = plt.subplot(111)
    im = ax.imshow(cells_by_condition,
                   vmin=resp_min,
                   vmax=resp_max,
                   interpolation='nearest',
                   aspect='auto',
                   cmap='RdBu_r')

    # Dividing lines between contrasts, placed on the pixel boundary.
    for block in range(1, num_contrasts):
        boundary = block * num_directions - 0.5
        ax.plot([boundary, boundary], [0, last_row], 'k', linewidth=2.0)

    ax.set_ylabel(shorthand(cre) + ' cell number',
                  fontsize=14,
                  color=cre_colors[cre],
                  labelpad=-6)
    ax.set_xlabel('Contrast (%)', fontsize=14, labelpad=-5)

    # One tick at the middle of every contrast block.
    block_centers = (num_directions * np.arange(num_contrasts) +
                     (num_directions / 2) - 0.5)
    ax.set_xticks(block_centers)
    ax.set_xticklabels([str(int(100 * x)) for x in contrasts], fontsize=12)
    ax.set_yticks([0, last_row])
    ax.set_yticklabels(['0', str(last_row)], fontsize=12)

    if do_colorbar:
        percentile_ticks = [
            0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999
        ]
        NLL_ticks = percentile_to_NLL(percentile_ticks, num_shuffles=200000)
        cbar = plt.colorbar(im,
                            ax=ax,
                            ticks=NLL_ticks,
                            orientation='horizontal')
        cbar.ax.set_xticklabels([str(100 * x) for x in percentile_ticks],
                                fontsize=12)
        cbar.set_label('Response Percentile',
                       fontsize=16,
                       rotation=0,
                       labelpad=15.0)

    plt.tight_layout()
    plt.savefig(savepath + save_name + '_' + scale + '.svg', format='svg')
    plt.close()
def decode_direction_from_running(df, savepath, save_format='svg'):
    """Decode grating direction from running speed per session and plot the result.

    For every Cre line with VISp sessions, per-session decoding performance
    is either loaded from a cached ``.npy`` file under ``savepath`` or
    computed with ``decode_direction`` from the mean running speed of the
    low-contrast sweeps (``Contrast < 0.2``) and then cached.  The summary
    figure shows one dot per session and a bar at the mean for each Cre
    line, with a dashed reference line at 12.5%.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix for session data, the cache files and the figure.
        save_format: ``'svg'`` saves an SVG; anything else saves a 300 dpi PNG.
    """
    running_dict = {}
    areas, cres = dataset_params()

    for area in ['VISp']:
        for cre in cres:
            celltype = shorthand(area) + ' ' + shorthand(cre)
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            savename = (shorthand(area) + '_' + shorthand(cre) +
                        '_running_direction_decoder.npy')
            cache_path = savepath + savename

            if os.path.isfile(cache_path):
                running_performance = np.load(cache_path)
            else:
                running_performance = []
                for session_ID in session_IDs:
                    mean_sweep_running = load_mean_sweep_running(
                        session_ID, savepath)
                    sweep_table = load_sweep_table(savepath, session_ID)

                    # Blank sweeps have no orientation; give them their own
                    # bin (360 deg) before converting to integers.
                    is_blank = sweep_table['Ori'].isnull().values
                    sweep_categories = sweep_table['Ori'].values.copy()
                    sweep_categories[is_blank] = 360

                    # Integer category per 45-degree step.  The builtin int
                    # replaces the removed ``np.int`` alias, and floor
                    # division keeps the labels integral on Python 3 (true
                    # division would turn them into floats).
                    sweep_categories = sweep_categories.astype(int) // 45

                    is_low = sweep_table['Contrast'].values < 0.2
                    sweeps_included = np.flatnonzero(is_low)
                    sweep_categories = sweep_categories[sweeps_included]
                    mean_sweep_running = mean_sweep_running[sweeps_included]

                    # Running speed is the single feature: one column.
                    running_performance.append(
                        decode_direction(
                            mean_sweep_running.reshape(
                                len(sweeps_included), 1), sweep_categories))

                running_performance = np.array(running_performance)
                np.save(cache_path, running_performance)

            print(celltype + ': ' + str(np.mean(running_performance)))
            running_dict[shorthand(cre)] = running_performance

    cre_colors = get_cre_colors()

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    ax.plot([-1, 6], [12.5, 12.5], 'k--')

    # Only Cre lines that actually produced results are plotted; indexing
    # running_dict with every entry of ``cres`` raised KeyError whenever a
    # line had no VISp sessions.
    plotted_cres = [cre for cre in cres if shorthand(cre) in running_dict]
    label_loc = []
    labels = []
    for i, cre in enumerate(plotted_cres):
        session_performance = running_dict[shorthand(cre)]
        mean_percent = 100.0 * session_performance.mean()
        ax.plot(i * np.ones((len(session_performance), )),
                100.0 * session_performance,
                '.',
                markersize=4.0,
                color=cre_colors[cre])
        ax.plot([i - 0.4, i + 0.4], [mean_percent, mean_percent],
                color=cre_colors[cre],
                linewidth=3)
        label_loc.append(i)
        labels.append(shorthand(cre))

    ax.set_xticks(label_loc)
    ax.set_xticklabels(labels, rotation=45, fontsize=10)
    ax.set_ylim(0, 25)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlim(-1, 14)
    ax.set_ylabel('Decoding performance (%)', fontsize=14)

    if save_format == 'svg':
        plt.savefig(savepath + 'running_decoder.svg', format='svg')
    else:
        plt.savefig(savepath + 'running_decoder.png', dpi=300)
    plt.close()
def plot_param_CI(res,
                  area,
                  cre,
                  savepath,
                  PARAM_TOL=1E-2,
                  save_format='svg'):
    """Plot fitted GLM weights with their 95% confidence intervals.

    Parameters and interval bounds are split into named terms with
    ``unpack_params`` and drawn group by group along the x axis in a fixed
    order.  For the two direction-by-contrast interaction terms only the
    entries whose interval lies entirely below ``-PARAM_TOL`` or entirely
    above ``PARAM_TOL`` are drawn, with double spacing, and the group is
    given a fixed width of six slots.

    Args:
        res: fitted results object providing ``params`` and ``conf_int``.
        area: area name (file name).
        cre: Cre line name (color lookup and file name).
        savepath: path prefix for the figure.
        PARAM_TOL: tolerance used to decide which interaction terms are shown.
        save_format: ``'svg'`` saves an SVG; anything else saves a 300 dpi PNG.
    """
    model_params = np.array(res.params)
    terms = unpack_params(model_params)

    CI = np.array(res.conf_int(alpha=0.05))
    CI_lb = unpack_params(CI[:, 0])
    CI_ub = unpack_params(CI[:, 1])

    plot_order = [
        'blank', 'run', 'dir', 'con', 'dirXrun', 'conXrun', 'dirXcon',
        'dirXconXrun'
    ]
    direction_terms = ('dir', 'dirXrun', 'dirXcon', 'dirXconXrun')
    interaction_terms = ('dirXcon', 'dirXconXrun')

    directions, contrasts = grating_params()
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    num_contrasts = len(contrasts)

    direction_labels = [str(x) for x in directions]
    contrast_labels = [str(int(100 * x)) for x in contrasts]
    x_labels = {
        'blank': ['blank'],
        'blankXrun': ['blank X run'],
        'run': ['run'],
        'dir': direction_labels,
        'con': contrast_labels,
        'dirXrun': list(direction_labels),
        'conXrun': list(contrast_labels),
    }

    plt.figure(figsize=(20, 4.5))
    savename = shorthand(area) + '_' + shorthand(cre)
    cre_colors = get_cre_colors()
    ax = plt.subplot(111)

    curr_x = 0
    x_ticks = []
    x_ticklabels = []
    for param_name in plot_order:
        param_means = terms[param_name]
        param_CI_lb = CI_lb[param_name]
        param_CI_ub = CI_ub[param_name]

        # center directions on zero
        if param_name in direction_terms:
            param_means = center_dir_on_zero(param_means)
            param_CI_lb = center_dir_on_zero(param_CI_lb)
            param_CI_ub = center_dir_on_zero(param_CI_ub)

        # Bring everything to 1-D.  Scalars are detected by dimensionality
        # instead of an exact ``type(...) == np.float64`` match, so Python
        # floats and 0-d arrays are handled as well.
        if np.ndim(param_means) == 0:
            param_means = np.atleast_1d(param_means)
            param_CI_lb = np.atleast_1d(param_CI_lb)
            param_CI_ub = np.atleast_1d(param_CI_ub)
        elif param_name in interaction_terms:
            param_means = param_means.flatten()
            param_CI_lb = param_CI_lb.flatten()
            param_CI_ub = param_CI_ub.flatten()

        param_errs = CI_to_errorbars(param_means, param_CI_lb, param_CI_ub)
        num_params = np.shape(param_errs)[1]

        if param_name in interaction_terms:
            # for dirXcon terms, only plot non-zero values
            non_zero_idx = np.argwhere((param_CI_ub < -PARAM_TOL)
                                       | (param_CI_lb > PARAM_TOL))[:, 0]
            num_params = len(non_zero_idx)

            cond_tick_labels = []
            if num_params > 0:
                param_means = param_means[non_zero_idx]
                param_errs = param_errs[:, non_zero_idx]
                for idx in non_zero_idx:
                    # Flat index -> (direction, contrast); assumes the
                    # flattened matrix is direction-major with one column
                    # per contrast (the original hard-coded 6 columns).
                    i_dir, i_con = divmod(int(idx), num_contrasts)
                    cond_tick_labels.append(
                        str(int(100 * contrasts[i_con])) + '%,' +
                        str(int(directions[i_dir])))

            # pad params to make all plots equal size
            ticks_to_plot = 6
            cond_tick_labels.extend([''] * (ticks_to_plot - num_params))
            x_labels[param_name] = cond_tick_labels

            # double spacing
            x_values = np.arange(curr_x, curr_x + 2 * num_params, 2)
        else:
            ticks_to_plot = num_params
            x_values = np.arange(curr_x, curr_x + num_params)

        if num_params > 0:
            ax.errorbar(x_values,
                        param_means,
                        yerr=param_errs,
                        fmt='o',
                        color=cre_colors[cre],
                        linewidth=3,
                        capsize=5,
                        elinewidth=2,
                        markeredgewidth=2)
            for i_x, x in enumerate(x_values):
                x_ticks.append(x)
                x_ticklabels.append(x_labels[param_name][i_x])

        # Leave one empty slot between groups.
        curr_x += ticks_to_plot + 1

    ax.plot([-1, curr_x], [0, 0], 'k', linewidth=1.0)
    ax.set_ylabel('Weight', fontsize=14)
    ax.set_xlim([-1, curr_x])
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticklabels)

    y_max = 2.5
    y_min = -1.5
    y_ticks = [-1, 0, 1, 2]
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([str(int(y)) for y in y_ticks])
    ax.set_ylim([y_min, y_max])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    if save_format == 'svg':
        plt.savefig(savepath + savename + '_param_CI.svg', format='svg')
    else:
        plt.savefig(savepath + savename + '_param_CI.png', dpi=300)
    plt.close()
def plot_peak_response_distribution(run_aligned_pooled_low,
                                    stat_aligned_pooled_low,
                                    run_aligned_pooled_high,
                                    stat_aligned_pooled_high, area, cre,
                                    savename, savepath):
    """Swarm-plot per-cell responses for blank and each aligned direction.

    Groups are laid out left to right as: running blank, stationary blank,
    then eight directions for each of run/5%, run/80%, stat/5%, stat/80%,
    with an empty spacer category between sections.  Non-finite responses
    are dropped per group, and ``plot_quartiles`` is called for every group
    at its x position.

    Args:
        run_aligned_pooled_low: aligned responses, running, low contrast;
            column 0 is read as the blank condition and columns 1..8 as the
            directions (via ``add_group_to_dict``).
        stat_aligned_pooled_low: same for stationary, low contrast.
        run_aligned_pooled_high: same for running, high contrast.
        stat_aligned_pooled_high: same for stationary, high contrast.
        area: area name (file name).
        cre: Cre line name (color lookup and file name).
        savename: tag inserted into the file name.
        savepath: path prefix for the figure.
    """
    value_col = 'Response to Preferred Direction'

    plt.figure(figsize=(7, 4))
    ax = plt.subplot(111)

    cre_colors = get_cre_colors()

    BLANK_IDX = 0
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    contrasts = [0.05, 0.8]
    run_states = ['run', 'stat']

    pooled = {
        ('run', 0.05): run_aligned_pooled_low,
        ('run', 0.8): run_aligned_pooled_high,
        ('stat', 0.05): stat_aligned_pooled_low,
        ('stat', 0.8): stat_aligned_pooled_high,
    }

    def group_label(run_state, direction, contrast):
        return (run_state + ' ' + str(direction) + ' ' +
                str(int(100 * contrast)) + '%')

    resp_dict = {}
    resp_dict = add_group_to_dict(resp_dict, run_aligned_pooled_low,
                                  BLANK_IDX, 'run blank')
    resp_dict = add_group_to_dict(resp_dict, stat_aligned_pooled_low,
                                  BLANK_IDX, 'stat blank')
    for run_state in run_states:
        for contrast in contrasts:
            resps = pooled[(run_state, contrast)]
            for i_dir, direction in enumerate(directions):
                resp_dict = add_group_to_dict(
                    resp_dict, resps, 1 + i_dir,
                    group_label(run_state, direction, contrast))

    # Plot order: (group name, hue label); 'spaceN' entries are empty
    # categories that only create gaps between sections.
    plot_order = [('space1', ''), ('run blank', ''), ('stat blank', ''),
                  ('space2', '')]
    curr_space = 3
    for run_state in run_states:
        for contrast in contrasts:
            for direction in directions:
                plot_order.append(
                    (group_label(run_state, direction, contrast), ''))
            plot_order.append(('space' + str(curr_space), ''))
            curr_space += 1

    colors = ['#9f9f9f'] + [cre_colors[cre]] * len(plot_order)  # grey: blanks
    cre_palette = sns.color_palette(colors)

    # Rows are collected in plain lists and turned into a DataFrame once.
    # This replaces a zero-filled frame of fixed size that was filled through
    # chained indexing (``df[col][a:b] = ...``): that pattern does not write
    # through under pandas Copy-on-Write, stores strings in float columns,
    # and capped the plot at an arbitrary number of rows.
    values = []
    cell_types = []
    hues = []

    labels = []
    x_pos = []
    dist = []
    dir_idx = 0
    for line, (group, cre_name) in enumerate(plot_order):
        if 'space' in group:
            # One NaN row keeps the spacer category on the x axis.
            values.append(np.nan)  # np.NaN alias was removed in NumPy 2.0
            cell_types.append(group)
            hues.append('blank')
            continue

        resp_mag = resp_dict[group]
        resp_mag = resp_mag[np.argwhere(np.isfinite(resp_mag))[:, 0]]
        num_cells = len(resp_mag)

        values.extend(resp_mag)
        cell_types.extend([group] * num_cells)
        hues.extend([cre_name] * num_cells)

        x_pos.append(line)
        dist.append(resp_mag)

        if 'blank' in group:
            labels.append('run' if 'run' in group else 'stat')
        else:
            labels.append(str(directions[dir_idx]))
            dir_idx = (dir_idx + 1) % len(directions)

    resp_df = pd.DataFrame(
        {
            value_col: values,
            'cell_type': cell_types,
            'cre': hues,
        },
        columns=(value_col, 'cell_type', 'cre'))

    ax = sns.swarmplot(x='cell_type',
                       y=value_col,
                       hue='cre',
                       size=1.0,
                       palette=cre_palette,
                       data=resp_df)
    ax.set_xticks(np.array(x_pos))
    ax.set_xticklabels(labels, fontsize=4.5, rotation=0)

    # The hue legend carries no information here; guard against it not
    # having been created.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    for i, d in enumerate(dist):
        plot_quartiles(ax, d, x_pos[i])

    ax.set_ylim(-20, 400)
    ax.set_ylabel('Event magnitude per second (%)', fontsize=12)
    ax.set_xlabel(
        'Blank run 5% contrast run 80% contrast stat 5% contrast stat 80% contrast ',
        fontsize=9)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    plt.tight_layout()
    plt.savefig(savepath + shorthand(area) + '_' + shorthand(cre) + '_' +
                savename + '_cell_response_distribution.svg',
                format='svg')
    plt.close()