コード例 #1
0
def plot_all_waterfalls(df, savepath, scale='blank_subtracted_NLL'):
    """Draw a full direction-by-contrast waterfall for every area/Cre pair.

    Pairs without any session are skipped. For the rest, responses are pooled
    across sessions at the requested ``scale``, restricted to cells whose
    pooled p-value is below ``SIG_THRESH``, reordered with the cell order
    returned by ``get_cell_order_direction_sorted`` and handed to
    ``plot_full_waterfall``.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix forwarded to the loading/plotting helpers.
        scale: Response scale forwarded to ``pool_sessions`` and to
            ``plot_full_waterfall``.
    """
    areas, cres = dataset_params()
    for area in areas:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            # Row ordering is computed separately from the plotted scale.
            cell_order = get_cell_order_direction_sorted(
                df, area, cre, savepath)

            # Pool at the requested scale and keep only significant cells.
            pooled, _blank, pvals = pool_sessions(session_IDs,
                                                  area + '_' + cre,
                                                  savepath,
                                                  scale=scale)
            significant = center_direction_zero(pooled)[pvals < SIG_THRESH]

            waterfall = concatenate_contrasts(significant)
            waterfall = move_all_negative_to_bottom(waterfall[cell_order])

            figure_name = shorthand(area) + '_' + shorthand(cre) + '_full'
            plot_full_waterfall(waterfall, cre, figure_name, scale, savepath)
コード例 #2
0
def plot_tuning_split_by_run_state(df, savepath):
    """Plot tuning curves computed separately for running and stationary sweeps.

    Only area/Cre pairs with at least ``MIN_SESSIONS`` sessions are
    considered. Within a pair, a session contributes when at least
    ``MIN_CELLS`` of its cells have a p-value (from
    ``chi_square_all_conditions``) below ``SIG_THRESH``. A figure is produced
    only when at least ``MIN_SESSIONS`` sessions ended up contributing.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix forwarded to the loading/plotting helpers.
    """
    running_threshold = 1.0  # cm/s
    directions, contrasts = grating_params()

    MIN_SESSIONS = 3
    MIN_CELLS = 3  # per session

    areas, cres = dataset_params()
    for area in areas:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) < MIN_SESSIONS:
                continue

            curve_dict = {}
            num_sessions_included = 0
            for session_ID in session_IDs:
                sweep_table = load_sweep_table(savepath, session_ID)
                mse = load_mean_sweep_events(savepath, session_ID)
                mean_resp, _mean_blank = compute_mean_condition_responses(
                    sweep_table, mse)

                pvals = chi_square_all_conditions(sweep_table, mse,
                                                  session_ID, savepath)
                sig_cells = np.argwhere(pvals < SIG_THRESH)[:, 0]

                # A sweep counts as "running" at or above the threshold.
                sweep_speed = load_mean_sweep_running(session_ID, savepath)
                is_run = sweep_speed >= running_threshold

                (run_resp, stat_resp, run_blank,
                 stat_blank) = condition_response_running(
                     sweep_table, mse, is_run)

                mean_resp = center_direction_zero(mean_resp)
                run_resp = center_direction_zero(run_resp)
                stat_resp = center_direction_zero(stat_resp)

                # Peak direction is taken from the responses that are not
                # split by run state.
                peak_dir, __ = get_peak_conditions(mean_resp)

                if len(sig_cells) < MIN_CELLS:
                    continue

                for responses, blank, key in ((run_resp, run_blank,
                                               'all_run'),
                                              (stat_resp, stat_blank,
                                               'all_stat')):
                    curve_dict = populate_curve_dict(curve_dict, responses,
                                                     blank, sig_cells, key,
                                                     peak_dir)
                num_sessions_included += 1

            if num_sessions_included >= MIN_SESSIONS:
                plot_from_curve_dict(curve_dict, 'all', area, cre,
                                     num_sessions_included, savepath)
コード例 #3
0
def plot_SbC_stats(df, savepath):
    """Bar-plot the percentage of cells with ``test_SbC`` p-value < 0.05.

    One bar is drawn per area/Cre pair that has at least one session; the
    pooled cell count is written above each bar in parentheses and a dashed
    horizontal line marks ``100 * SbC_THRESH`` percent. The figure is written
    to ``savepath + 'SbC_stats.svg'``.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix used for loading and for the output file.
    """
    SbC_THRESH = 0.05

    cre_colors = get_cre_colors()
    directions, contrasts = grating_params()

    areas, cres = dataset_params()

    # One (label, color, percent, cell count) record per plotted bar.
    bars = []
    for area in areas:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            session_pvals = [
                test_SbC(session_ID, savepath) for session_ID in session_IDs
            ]
            num_cells = sum(len(pvals) for pvals in session_pvals)
            num_SbC = sum(
                (pvals < SbC_THRESH).sum() for pvals in session_pvals)

            bars.append((shorthand(cre), cre_colors[cre],
                         100.0 * num_SbC / num_cells, num_cells))

    tick_labels = [bar[0] for bar in bars]

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    for x, (_label, color, percent, n_cells) in enumerate(bars):
        ax.bar(x, percent, color=color)
        # Keep the count label above y=5 even for very short bars.
        ax.text(x,
                max(percent, 5) + 1,
                '(' + str(n_cells) + ')',
                horizontalalignment='center',
                fontsize=8)

    threshold_percent = 100 * SbC_THRESH
    ax.plot([-1, len(bars)], [threshold_percent, threshold_percent],
            '--k',
            linewidth=2.0)
    ax.set_ylim(0, 30)
    ax.set_xlim(-1, 14)
    ax.set_xticks(np.arange(len(tick_labels)))
    ax.set_xticklabels(tick_labels, fontsize=10, rotation=45)
    ax.set_ylabel('% Suppressed by Contrast', fontsize=14)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    plt.savefig(savepath + 'SbC_stats.svg', format='svg')
    plt.close()
コード例 #4
0
def get_cell_order_direction_sorted(df, area, cre, savepath):
    """Return the cell ordering used to sort waterfall rows by direction.

    Responses for the area/Cre pair are pooled at ``scale='event'``, limited
    to cells with pooled p-value below ``SIG_THRESH``, reduced to each cell's
    direction tuning at its peak contrast and passed to
    ``sort_by_weighted_peak_direction``.

    Args:
        df: Session table forwarded to ``get_sessions``.
        area: Area name used for session lookup and the pooled-data name.
        cre: Cre line name used for session lookup and the pooled-data name.
        savepath: Directory prefix forwarded to ``pool_sessions``.

    Returns:
        Whatever ``sort_by_weighted_peak_direction`` returns for the
        significant cells.
    """
    pooled, _blank, pvals = pool_sessions(get_sessions(df, area, cre),
                                          area + '_' + cre,
                                          savepath,
                                          scale='event')
    significant = center_direction_zero(pooled)[pvals < SIG_THRESH]

    # Only the peak contrast is needed; the peak direction is discarded.
    _peak_dir, peak_con = get_peak_conditions(significant)
    tuning_at_peak_contrast = select_peak_contrast(significant, peak_con)
    return sort_by_weighted_peak_direction(tuning_at_peak_contrast)
コード例 #5
0
def plot_DSI_distribution(df, savepath, curve='cdf'):
    """Plot the per-Cre-line DSI distributions for area 'VISp' via ``plot_cdf``.

    For every Cre line, responses are pooled at ``scale='event'`` and
    ``calc_DSI`` is applied to the cells with pooled p-value below
    ``SIG_THRESH``.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix forwarded to the loading/plotting helpers.
        curve: Only used as the suffix of the saved figure name
            (``'V1_DSI_' + curve``).
    """
    areas, cres = dataset_params()
    cre_colors = get_cre_colors()

    area = 'VISp'
    pooled_DSI, colors, cre_labels = [], [], []
    for cre in cres:
        resp, _blank, p_all = pool_sessions(get_sessions(df, area, cre),
                                            area + '_' + cre,
                                            savepath,
                                            scale='event')
        pooled_DSI.append(calc_DSI(resp[p_all < SIG_THRESH]))
        colors.append(cre_colors[cre])
        cre_labels.append(shorthand(cre))

    # Every curve is drawn fully opaque.
    alphas = [1.0] * len(cre_labels)

    xticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
    xtick_labels = list(map(str, xticks))

    plot_cdf(metric=pooled_DSI,
             metric_labels=cre_labels,
             colors=colors,
             alphas=alphas,
             hist_range=(0.0, 1.0),
             hist_bins=200,
             x_label='DSI',
             x_ticks=xticks,
             x_tick_labels=xtick_labels,
             save_name='V1_DSI_' + curve,
             savepath=savepath,
             do_legend=False)
コード例 #6
0
def plot_direction_vector_sum_by_contrast(df, savepath):
    """Plot the resultant preferred-direction vector at each contrast.

    For every area/Cre pair with at least one session, the preferred-direction
    distribution of significant cells (pooled p-value below ``SIG_THRESH``) is
    normalised to sum to one within each contrast column, and
    ``calc_vector_sum`` is applied per contrast. The results are drawn with
    ``radial_direction_figure`` together with the bounds returned by
    ``uniform_direction_vector_sum``.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix forwarded to the loading/plotting helpers.
    """
    areas, cres = dataset_params()
    directions, contrasts = grating_params()

    for area in areas:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            resp, _blank, p_all = pool_sessions(session_IDs,
                                                area + '_' + cre,
                                                savepath,
                                                scale='event')
            sig_resp = resp[p_all < SIG_THRESH]

            # Normalise each contrast column so that it sums to one.
            pref_dir_mat = calc_pref_direction_dist_by_contrast(sig_resp)
            column_totals = np.sum(pref_dir_mat, axis=0, keepdims=True)
            pref_dir_mat = pref_dir_mat / column_totals

            vector_sums = [
                calc_vector_sum(pref_dir_mat[:, i_con])
                for i_con in range(len(contrasts))
            ]
            resultant_mag = [mag for mag, _theta in vector_sums]
            resultant_theta = [theta for _mag, theta in vector_sums]

            # Reference bounds for this many cells (described in the original
            # code as a bootstrap CI for the distribution at 5% contrast).
            num_cells = len(sig_resp)
            uniform_LB, uniform_UB = uniform_direction_vector_sum(num_cells)

            # The first two arguments are all-zero placeholders, one entry
            # per direction.
            placeholder_a = np.zeros((len(directions), ))
            placeholder_b = np.zeros((len(directions), ))
            figure_name = shorthand(area) + '_' + shorthand(cre) + '_combined'
            radial_direction_figure(placeholder_a, placeholder_b,
                                    resultant_mag, resultant_theta,
                                    uniform_LB, uniform_UB, cre, num_cells,
                                    figure_name, savepath)
コード例 #7
0
def plot_contrast_CoM(df, savepath, curve='cdf'):
    """Plot the per-Cre-line contrast centre-of-mass distributions for 'VISp'.

    Responses are pooled at ``scale='event'`` for every Cre line, restricted
    to cells with pooled p-value below ``SIG_THRESH`` and converted with
    ``center_of_mass_for_list``. The x-axis ticks are placed at the natural
    log of the contrast values 5, 10, 20, 40, 60 and 80.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix forwarded to the loading/plotting helpers.
        curve: Only used as the suffix of the saved figure name.
    """
    areas, cres = dataset_params()
    cre_colors = get_cre_colors()

    area = 'VISp'
    pooled_resp, colors, cre_labels = [], [], []
    for cre in cres:
        resp, _blank, p_all = pool_sessions(get_sessions(df, area, cre),
                                            area + '_' + cre,
                                            savepath,
                                            scale='event')
        pooled_resp.append(resp[p_all < SIG_THRESH])
        colors.append(cre_colors[cre])
        cre_labels.append(shorthand(cre))

    # Every curve is drawn fully opaque.
    alphas = [1.0] * len(cre_labels)

    center_of_mass = center_of_mass_for_list(pooled_resp)

    contrasts = [5, 10, 20, 40, 60, 80]
    contrast_labels = list(map(str, contrasts))

    plot_cdf(metric=center_of_mass,
             metric_labels=cre_labels,
             colors=colors,
             alphas=alphas,
             hist_range=(np.log(5.0), np.log(70.0)),
             hist_bins=200,
             x_label='Contrast (CoM)',
             x_ticks=np.log(contrasts),
             x_tick_labels=contrast_labels,
             save_name=shorthand(area) + '_contrast_' + curve,
             savepath=savepath,
             do_legend=True)
コード例 #8
0
def model_GLM(df, savepath):
    """Fit and plot an L1-regularised Poisson GLM for three VISp Cre lines.

    For each of 'Vip-IRES-Cre', 'Sst-IRES-Cre' and 'Cux2-CreERT2' that has at
    least one session, the regularisation strength is taken from
    ``K_session_cross_validation``, the model is fitted on the pooled design
    matrix from ``construct_pooled_Xy``, and the prediction, the parameter
    confidence intervals and the parameter heatmaps are plotted.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix forwarded to the loading/plotting helpers.
    """
    area_names = ('VISp', )
    cre_names = ('Vip-IRES-Cre', 'Sst-IRES-Cre', 'Cux2-CreERT2')

    for area in area_names:
        for cre in cre_names:
            session_IDs = get_sessions(df, area, cre)
            savename = shorthand(area) + '_' + shorthand(cre)
            if len(session_IDs) == 0:
                continue

            lamb = K_session_cross_validation(session_IDs, savename, savepath)
            print(shorthand(cre) + ' lambda used: ' + str(lamb))

            X, y = construct_pooled_Xy(session_IDs, savepath)
            poisson_glm = sm.GLM(y, X, family=sm.families.Poisson())
            fit = poisson_glm.fit_regularized(
                method='elastic_net',
                alpha=lamb,
                maxiter=200,
                L1_wt=1.0,  # 1.0: pure L1 penalty, 0.0: pure L2 penalty
                refit=True)

            model_params = np.array(fit.params)

            # Prediction = exp(sum over columns of parameter * regressor).
            weighted_regressors = model_params.reshape(
                1, len(model_params)) * X
            linear_predictor = np.sum(weighted_regressors, axis=1)
            y_hat = np.exp(linear_predictor)

            plot_y(X, y_hat, area, cre, 'predicted', savepath)
            plot_param_CI(fit, area, cre, savepath)
            plot_param_heatmaps(model_params, cre,
                                shorthand(area) + '_' + shorthand(cre),
                                savepath)
コード例 #9
0
def plot_single_cell_example(df,
                             savepath,
                             cre,
                             example_cell,
                             example_session_idx=0):
    """Plot one example cell: response heatmap plus two tuning curves.

    The figure has three panels: the direction-by-contrast response heatmap,
    the contrast tuning at the cell's peak direction and the direction tuning
    at the cell's peak contrast (both with SEM error bars). It is saved as
    ``savepath + shorthand(cre) + '_example_cell.svg'``.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix used for loading and for the output file.
        cre: Cre line whose 'VISp' sessions are searched.
        example_cell: Index of the cell among the session's significant cells
            (p-value below ``SIG_THRESH``).
        example_session_idx: Index into the list returned by ``get_sessions``.
    """

    def _hide_top_right_spines(axis):
        # Drop the top and right frame lines of a tuning-curve panel.
        for side in ('right', 'top'):
            axis.spines[side].set_visible(False)

    _grating_directions, contrasts = grating_params()

    session_ID = get_sessions(df, 'VISp', cre)[example_session_idx]

    mse = load_mean_sweep_events(savepath, session_ID)
    sweep_table = load_sweep_table(savepath, session_ID)

    condition_responses, __ = compute_mean_condition_responses(
        sweep_table, mse)
    condition_SEM, __ = compute_SEM_condition_responses(sweep_table, mse)
    p_all = chi_square_all_conditions(sweep_table, mse, session_ID, savepath)

    is_significant = p_all < SIG_THRESH

    # Direction labels after shifting zero to the centre of the axis.
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    sig_resp = center_direction_zero(condition_responses[is_significant])
    sig_SEM = center_direction_zero(condition_SEM[is_significant])

    # Panel 1: full direction-by-contrast response heatmap.
    plt.figure(figsize=(6, 4))
    heat_ax = plt.subplot2grid((5, 5), (0, 0), rowspan=5, colspan=2)
    heat_ax.imshow(sig_resp[example_cell],
                   vmin=0.0,
                   interpolation='nearest',
                   aspect='auto',
                   cmap='plasma')
    heat_ax.set_ylabel('Direction (deg)', fontsize=14)
    heat_ax.set_xlabel('Contrast (%)', fontsize=14)
    contrast_labels = [str(int(100 * c)) for c in contrasts]
    heat_ax.set_xticks(np.arange(len(contrasts)))
    heat_ax.set_xticklabels(contrast_labels, fontsize=10)
    direction_labels = [str(d) for d in directions]
    direction_ticks = np.arange(len(directions))
    heat_ax.set_yticks(direction_ticks)
    heat_ax.set_yticklabels(direction_labels, fontsize=10)

    peak_dir_idx, peak_con_idx = get_peak_conditions(sig_resp)
    best_dir = peak_dir_idx[example_cell]

    # Panel 2: contrast tuning at the peak direction (log-spaced x axis).
    contrast_means = sig_resp[example_cell, best_dir, :]
    contrast_SEMs = sig_SEM[example_cell, best_dir, :]
    con_ax = plt.subplot2grid((5, 5), (0, 3), rowspan=2, colspan=2)
    log_contrasts = np.log(contrasts)
    con_ax.errorbar(log_contrasts, contrast_means, contrast_SEMs)
    con_ax.set_xticks(log_contrasts)
    con_ax.set_xticklabels(contrast_labels, fontsize=10)
    con_ax.set_xlabel('Contrast (%)', fontsize=14)
    con_ax.set_ylabel('Response', fontsize=14)
    _hide_top_right_spines(con_ax)

    # Panel 3: direction tuning at the peak contrast.
    best_con = peak_con_idx[example_cell]
    direction_means = sig_resp[example_cell, :, best_con]
    direction_SEMs = sig_SEM[example_cell, :, best_con]
    dir_ax = plt.subplot2grid((5, 5), (3, 3), rowspan=2, colspan=2)
    dir_ax.errorbar(direction_ticks, direction_means, direction_SEMs)
    dir_ax.set_xlim(-0.07, 7.07)
    dir_ax.set_xticks(direction_ticks)
    dir_ax.set_xticklabels(direction_labels, fontsize=10)
    dir_ax.set_xlabel('Direction (deg)', fontsize=14)
    dir_ax.set_ylabel('Response', fontsize=14)
    _hide_top_right_spines(dir_ax)

    plt.tight_layout(w_pad=-5.5, h_pad=0.1)
    plt.savefig(savepath + shorthand(cre) + '_example_cell.svg', format='svg')
    plt.close()
コード例 #10
0
def _decode_session_direction_from_running(session_ID, savepath):
    """Return ``decode_direction``'s result for one session's running speed.

    Only sweeps whose 'Contrast' value is below 0.2 are used. The single
    feature is the mean running speed per sweep; the label is the sweep
    direction divided by 45 (blank sweeps, i.e. null 'Ori', map to 8).
    """
    mean_sweep_running = load_mean_sweep_running(session_ID, savepath)
    sweep_table = load_sweep_table(savepath, session_ID)

    is_blank = sweep_table['Ori'].isnull().values
    sweep_categories = sweep_table['Ori'].values.copy()
    # Blank sweeps have no direction; give them their own category (360 -> 8)
    # before the integer cast so that no NaN is converted.
    sweep_categories[is_blank] = 360
    # The builtin ``int`` replaces the ``np.int`` alias that NumPy removed,
    # and floor division keeps the labels integral (``/`` yields floats under
    # Python 3).
    sweep_categories = sweep_categories.astype(int) // 45

    is_low = sweep_table['Contrast'].values < 0.2
    sweeps_included = np.argwhere(is_low)[:, 0]

    sweep_categories = sweep_categories[sweeps_included]
    mean_sweep_running = mean_sweep_running[sweeps_included]

    return decode_direction(
        mean_sweep_running.reshape(len(sweeps_included), 1),
        sweep_categories)


def decode_direction_from_running(df, savepath, save_format='svg'):
    """Decode grating direction from running speed and plot the performance.

    For every Cre line with at least one 'VISp' session, the per-session
    decoding performance is loaded from
    ``<area>_<cre>_running_direction_decoder.npy`` under ``savepath`` when
    that file exists; otherwise it is computed session by session and saved
    there. The mean performance is printed per Cre line, and a summary figure
    (one dot per session, a bar at the mean, dashed line at 12.5%) is saved
    as ``running_decoder.svg`` or ``running_decoder.png``.

    Cre lines without any session are left out of the figure instead of
    raising ``KeyError`` while plotting.

    Args:
        df: Session table forwarded to ``get_sessions``.
        savepath: Directory prefix used for loading, caching and the figure.
        save_format: ``'svg'`` saves an SVG; any other value saves a 300 dpi
            PNG.
    """
    running_dict = {}

    areas, cres = dataset_params()
    for area in ['VISp']:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            celltype = shorthand(area) + ' ' + shorthand(cre)
            cache_file = (savepath + shorthand(area) + '_' + shorthand(cre) +
                          '_running_direction_decoder.npy')

            if os.path.isfile(cache_file):
                running_performance = np.load(cache_file)
            else:
                running_performance = np.array([
                    _decode_session_direction_from_running(
                        session_ID, savepath) for session_ID in session_IDs
                ])
                np.save(cache_file, running_performance)

            print(celltype + ': ' + str(np.mean(running_performance)))
            running_dict[shorthand(cre)] = running_performance

    cre_colors = get_cre_colors()

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    ax.plot([-1, 6], [12.5, 12.5], 'k--')
    label_loc = []
    labels = []
    for cre in cres:
        label = shorthand(cre)
        if label not in running_dict:
            # No sessions for this Cre line, so there is nothing to draw.
            continue

        x = len(label_loc)  # plotted Cre lines are packed left to right
        session_performance = running_dict[label]
        mean_percent = 100.0 * session_performance.mean()

        ax.plot(x * np.ones((len(session_performance), )),
                100.0 * session_performance,
                '.',
                markersize=4.0,
                color=cre_colors[cre])
        ax.plot([x - 0.4, x + 0.4], [mean_percent, mean_percent],
                color=cre_colors[cre],
                linewidth=3)
        label_loc.append(x)
        labels.append(label)

    ax.set_xticks(label_loc)
    ax.set_xticklabels(labels, rotation=45, fontsize=10)
    ax.set_ylim(0, 25)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlim(-1, 14)
    ax.set_ylabel('Decoding performance (%)', fontsize=14)

    if save_format == 'svg':
        plt.savefig(savepath + 'running_decoder.svg', format='svg')
    else:
        plt.savefig(savepath + 'running_decoder.png', dpi=300)

    plt.close()