コード例 #1
0
def plot_tuning_split_by_run_state(df, savepath):
    """Pool running vs. stationary tuning curves per area/Cre line and plot them.

    For every (area, cre) pair from ``dataset_params()`` with at least
    ``MIN_SESSIONS`` sessions, each session's sweeps are split by running
    state: a sweep counts as running when its mean running speed is at least
    1.0 cm/s. Sessions with at least ``MIN_CELLS`` cells where
    p < SIG_THRESH (from ``chi_square_all_conditions``) are accumulated into a
    curve dictionary under 'all_run' and 'all_stat'. The dictionary is passed
    to ``plot_from_curve_dict`` when at least ``MIN_SESSIONS`` sessions
    qualified.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix used to load cached session data and passed on
            to the plotting helper.
    """
    RUN_SPEED_THRESHOLD = 1.0  # cm/s
    MIN_SESSIONS = 3
    MIN_CELLS = 3  # per session

    # The grating parameters are not used below; the call is kept as-is.
    _directions, _contrasts = grating_params()

    areas, cres = dataset_params()
    for area in areas:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) < MIN_SESSIONS:
                continue

            curve_dict = {}
            sessions_used = 0
            for session_ID in session_IDs:
                sweep_table = load_sweep_table(savepath, session_ID)
                mse = load_mean_sweep_events(savepath, session_ID)
                all_cond_resp, _blank = compute_mean_condition_responses(
                    sweep_table, mse)

                p_values = chi_square_all_conditions(sweep_table, mse,
                                                     session_ID, savepath)
                responsive_idx = np.argwhere(p_values < SIG_THRESH)[:, 0]

                running_speed = load_mean_sweep_running(session_ID, savepath)
                is_run = running_speed >= RUN_SPEED_THRESHOLD

                (run_resp, stat_resp, run_blank,
                 stat_blank) = condition_response_running(
                     sweep_table, mse, is_run)

                # Re-center the direction axis of all three response arrays.
                all_cond_resp = center_direction_zero(all_cond_resp)
                run_resp = center_direction_zero(run_resp)
                stat_resp = center_direction_zero(stat_resp)

                # Peak direction comes from the responses pooled over both states.
                peak_dir, _ = get_peak_conditions(all_cond_resp)

                if len(responsive_idx) < MIN_CELLS:
                    continue

                curve_dict = populate_curve_dict(curve_dict, run_resp,
                                                 run_blank, responsive_idx,
                                                 'all_run', peak_dir)
                curve_dict = populate_curve_dict(curve_dict, stat_resp,
                                                 stat_blank, responsive_idx,
                                                 'all_stat', peak_dir)
                sessions_used += 1

            if sessions_used >= MIN_SESSIONS:
                plot_from_curve_dict(curve_dict, 'all', area, cre,
                                     sessions_used, savepath)
コード例 #2
0
def calc_DSI(condition_responses):
    """Compute a direction selectivity index (DSI) for every cell.

    For each cell, responses are taken at the cell's peak contrast (from
    ``get_peak_conditions``). The response at the peak direction ("pref") is
    compared with the response at the opposite direction ("null"). The
    opposite direction is the peak direction index shifted by half the number
    of directions, wrapping around:

        DSI = (pref - null) / (pref + null)

    Args:
        condition_responses: array of shape
            (num_cells, num_directions, num_contrasts). This assumes the
            direction axis samples the full circle at even spacing, so that
            index + num_directions // 2 is the opposite direction.
            TODO confirm for all callers.

    Returns:
        1-D float64 array of length num_cells. Entries are non-finite (with
        a NumPy RuntimeWarning) where pref + null == 0.

    Raises:
        ValueError: if the number of directions is odd, because no exactly
            opposite direction exists.
    """
    num_cells, num_directions, _num_contrasts = np.shape(condition_responses)
    if num_directions % 2:
        raise ValueError(
            f"calc_DSI needs an even number of directions, got {num_directions}")

    peak_dir, peak_con = get_peak_conditions(condition_responses)
    peak_dir = np.asarray(peak_dir, dtype=int)
    peak_con = np.asarray(peak_con, dtype=int)

    cells = np.arange(num_cells)
    # Two advanced indices separated by a slice -> shape (num_cells, num_directions):
    # each row is that cell's direction tuning at its own peak contrast.
    tuning = condition_responses[cells, :, peak_con].astype(float)

    # Opposite direction, derived from the array instead of a hard-coded 8.
    null_dir = (peak_dir + num_directions // 2) % num_directions
    pref_resp = tuning[cells, peak_dir]
    null_resp = tuning[cells, null_dir]

    return (pref_resp - null_resp) / (pref_resp + null_resp)
コード例 #3
0
def get_cell_order_direction_sorted(df, area, cre, savepath):
    """Order the responsive cells pooled over one area/Cre line by direction tuning.

    Sessions for (area, cre) are pooled with ``pool_sessions`` (event scale).
    The direction axis is re-centered with ``center_direction_zero`` and only
    cells with p < SIG_THRESH are kept. Each kept cell's direction tuning at
    its peak contrast is passed to ``sort_by_weighted_peak_direction``, and
    that function's result is returned.
    """
    pooled_resp, _pooled_blank, p_values = pool_sessions(
        get_sessions(df, area, cre),
        area + '_' + cre,
        savepath,
        scale='event',
    )

    centered = center_direction_zero(pooled_resp)
    responsive = centered[p_values < SIG_THRESH]

    # Only the peak contrast is needed; the peak direction is discarded.
    _peak_dir, peak_con = get_peak_conditions(responsive)
    tuning_at_peak_contrast = select_peak_contrast(responsive, peak_con)

    return sort_by_weighted_peak_direction(tuning_at_peak_contrast)
コード例 #4
0
def calc_OSI(condition_responses):
    """Compute an orientation selectivity index (OSI) for every cell.

    For each cell, responses at the cell's peak contrast (from
    ``get_peak_conditions``) are normalized to sum to 1 across directions.
    They are then projected onto the doubled-angle unit circle, and the OSI
    is the length of the resulting vector sum:

        OSI = sqrt( (sum_k r_k cos 2θ_k)^2 + (sum_k r_k sin 2θ_k)^2 ) / sum_k r_k

    Doubling the angle maps opposite directions (θ and θ + 180°) to the same
    point, so the index reflects orientation rather than direction preference.

    Args:
        condition_responses: array of shape
            (num_cells, num_directions, num_contrasts). Its direction axis
            must line up with the directions (in degrees) returned by
            ``grating_params()``.

    Returns:
        1-D float array of length num_cells. Cells whose summed response at
        the peak contrast is 0 give non-finite values (with a NumPy
        RuntimeWarning).
    """
    num_cells, _num_directions, _num_contrasts = np.shape(condition_responses)
    _peak_dir, peak_con = get_peak_conditions(condition_responses)
    peak_con = np.asarray(peak_con, dtype=int)

    directions, _contrasts = grating_params()
    doubled_angle = 2.0 * np.deg2rad(np.asarray(directions, dtype=float))

    cells = np.arange(num_cells)
    # Shape (num_cells, num_directions): each cell's tuning at its peak contrast.
    tuning = condition_responses[cells, :, peak_con].astype(float)
    normalized = tuning / tuning.sum(axis=1, keepdims=True)

    x_sum = (normalized * np.cos(doubled_angle)).sum(axis=1)
    y_sum = (normalized * np.sin(doubled_angle)).sum(axis=1)

    return np.sqrt(x_sum**2 + y_sum**2)
コード例 #5
0
def plot_single_cell_example(df,
                             savepath,
                             cre,
                             example_cell,
                             example_session_idx=0):
    """Save an SVG summarizing the tuning of one responsive VISp cell.

    The figure has three panels:
      * a direction x contrast heatmap of mean responses;
      * contrast tuning (mean with SEM error bars) at the cell's peak direction;
      * direction tuning (mean with SEM error bars) at the cell's peak contrast.

    Args:
        df: session table passed to ``get_sessions``.
        savepath: path prefix for cached inputs and for the output file.
        cre: Cre line. It selects the sessions and, via ``shorthand``, names
            the output file.
        example_cell: index among the cells with p < SIG_THRESH in the chosen
            session, not an index into all cells.
        example_session_idx: which of the VISp sessions for ``cre`` to use.
    """
    def _hide_top_right_spines(axis):
        axis.spines['right'].set_visible(False)
        axis.spines['top'].set_visible(False)

    _grating_directions, contrasts = grating_params()
    contrast_labels = [str(int(100 * c)) for c in contrasts]

    session_ID = get_sessions(df, 'VISp', cre)[example_session_idx]

    mse = load_mean_sweep_events(savepath, session_ID)
    sweep_table = load_sweep_table(savepath, session_ID)

    mean_resp, _ = compute_mean_condition_responses(sweep_table, mse)
    sem_resp, _ = compute_SEM_condition_responses(sweep_table, mse)
    p_all = chi_square_all_conditions(sweep_table, mse, session_ID, savepath)

    is_sig = p_all < SIG_THRESH
    cell_mean = mean_resp[is_sig]
    cell_sem = sem_resp[is_sig]

    # Shift 0 deg toward the middle of the direction axis. The labels are
    # written for the order produced by center_direction_zero.
    centered_directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    direction_labels = [str(d) for d in centered_directions]
    cell_mean = center_direction_zero(cell_mean)
    cell_sem = center_direction_zero(cell_sem)

    # Panel 1: full direction x contrast heatmap.
    plt.figure(figsize=(6, 4))
    heat_ax = plt.subplot2grid((5, 5), (0, 0), rowspan=5, colspan=2)
    heat_ax.imshow(cell_mean[example_cell],
                   vmin=0.0,
                   interpolation='nearest',
                   aspect='auto',
                   cmap='plasma')
    heat_ax.set_ylabel('Direction (deg)', fontsize=14)
    heat_ax.set_xlabel('Contrast (%)', fontsize=14)
    heat_ax.set_xticks(np.arange(len(contrasts)))
    heat_ax.set_xticklabels(contrast_labels, fontsize=10)
    heat_ax.set_yticks(np.arange(len(centered_directions)))
    heat_ax.set_yticklabels(direction_labels, fontsize=10)

    peak_dir_idx, peak_con_idx = get_peak_conditions(cell_mean)
    best_dir = peak_dir_idx[example_cell]
    best_con = peak_con_idx[example_cell]

    # Panel 2: contrast tuning at the peak direction (log-contrast x axis).
    log_contrasts = np.log(contrasts)
    con_ax = plt.subplot2grid((5, 5), (0, 3), rowspan=2, colspan=2)
    con_ax.errorbar(log_contrasts,
                    cell_mean[example_cell, best_dir, :],
                    cell_sem[example_cell, best_dir, :])
    con_ax.set_xticks(log_contrasts)
    con_ax.set_xticklabels(contrast_labels, fontsize=10)
    con_ax.set_xlabel('Contrast (%)', fontsize=14)
    con_ax.set_ylabel('Response', fontsize=14)
    _hide_top_right_spines(con_ax)

    # Panel 3: direction tuning at the peak contrast.
    dir_ax = plt.subplot2grid((5, 5), (3, 3), rowspan=2, colspan=2)
    dir_ax.errorbar(np.arange(len(centered_directions)),
                    cell_mean[example_cell, :, best_con],
                    cell_sem[example_cell, :, best_con])
    dir_ax.set_xlim(-0.07, 7.07)
    dir_ax.set_xticks(np.arange(len(centered_directions)))
    dir_ax.set_xticklabels(direction_labels, fontsize=10)
    dir_ax.set_xlabel('Direction (deg)', fontsize=14)
    dir_ax.set_ylabel('Response', fontsize=14)
    _hide_top_right_spines(dir_ax)

    plt.tight_layout(w_pad=-5.5, h_pad=0.1)
    plt.savefig(savepath + shorthand(cre) + '_example_cell.svg', format='svg')
    plt.close()
def plot_rasters_at_peak(sweep_table,
                         sweep_events,
                         condition_responses,
                         cell_idx,
                         session_ID,
                         cre,
                         savepath,
                         figure_format,
                         max_sweeps=15):
    """Save a raster figure for one cell around its peak stimulus condition.

    Layout:
      * top row of a 2 x 10 grid, starting at column 2: one raster per
        contrast at the cell's peak direction;
      * bottom row of the 2 x 10 grid: one raster per direction at the cell's
        peak contrast, rotated by ``dir_shift`` so 0 deg is not at the left
        edge;
      * last column of a 1 x 10 grid (full height): rasters for blank sweeps.

    Args:
        sweep_table, sweep_events: per-sweep stimulus table and event data,
            forwarded to the raster helpers.
        condition_responses: responses passed to ``cu.get_peak_conditions``
            to find each cell's peak direction and contrast.
        cell_idx: index of the cell to plot.
        session_ID, cre: used only to build the output filename.
        savepath: output path prefix.
        figure_format: 'svg' saves an SVG; any other value saves a 300-dpi PNG.
        max_sweeps: not used by this function; kept for interface
            compatibility.
    """
    peak_dir, peak_con = cu.get_peak_conditions(condition_responses)
    directions, contrasts = cu.grating_params()

    # Tick labels indexed like `directions`, with directions past 180 deg
    # shown as negative angles.
    # NOTE(review): assumes cu.grating_params() returns 0, 45, ..., 315 deg in
    # that order; confirm if the stimulus set changes.
    direction_str = ['0', '45', '90', '135', '180', '-135', '-90', '-45']
    dir_shift = 3
    # Raw string: '\c' is not a valid escape sequence, so a plain literal
    # warns on modern Python. The resulting text is the same.
    degree = r'$^\circ$'

    cell_peak_dir = peak_dir[cell_idx]
    peak_contrast = contrasts[peak_con[cell_idx]]

    fig = plt.figure(figsize=(15, 5))
    try:
        cell_max = get_max_event_magnitude(sweep_events, cell_idx)

        # Blank sweeps: last column of a 1 x 10 grid, i.e. full height.
        ax = plt.subplot(1, 10, 10)
        plot_blank_raster(ax, sweep_table, sweep_events, cell_idx, cell_max)
        ax.set_xlabel('Time (s)', fontsize=16)
        ax.set_title('Blanks', fontsize=16)

        # Top row: every contrast at the cell's peak direction.
        for i_con, contrast in enumerate(contrasts):
            ax = plt.subplot(2, 10, 2 + i_con)

            plot_single_raster(ax, sweep_table, sweep_events, cell_idx,
                               contrast, directions[cell_peak_dir], cell_max)

            ax.set_title(str(int(100 * contrast)) + '%', fontsize=16)
            ax.set_xlabel('Time (s)', fontsize=16)

            if i_con == 0:
                ax.set_ylabel('Direction: ' + direction_str[cell_peak_dir] +
                              degree,
                              fontsize=16)

        # Bottom row: every direction at the cell's peak contrast.
        for i_dir, direction in enumerate(directions):

            # Shift 0 degrees toward the center of the row.
            plot_dir = np.mod(i_dir + dir_shift, len(directions))

            ax = plt.subplot(2, 10, 11 + plot_dir)

            plot_single_raster(ax, sweep_table, sweep_events, cell_idx,
                               peak_contrast, direction, cell_max)

            ax.set_title(' ' + direction_str[i_dir] + degree, fontsize=16)
            ax.set_xlabel('Time (s)', fontsize=16)

            if plot_dir == 0:
                ax.set_ylabel('Contrast: ' + str(int(100 * peak_contrast)) +
                              '%',
                              fontsize=16)

        # NOTE(review): tight_layout() probably recomputes hspace, which would
        # make this call a no-op; confirm before removing.
        fig.subplots_adjust(hspace=0)
        plt.tight_layout()

        out_base = (savepath + cre + '_' + str(session_ID) + '_cell_' +
                    str(cell_idx) + '_rasters_at_peak')
        if figure_format == 'svg':
            plt.savefig(out_base + '.svg', format='svg')
        else:
            plt.savefig(out_base + '.png', dpi=300)
    finally:
        # Close even when plotting fails so figures do not pile up in pyplot.
        plt.close(fig)
コード例 #7
0
def LP_HP_model_selection(session_ID,savepath,do_plot=False):
    """Fit low-, high- and band-pass contrast-response models for each cell.

    Results are cached in four .npy files:
    ``<savepath><session_ID>_model_AIC.npy``, ``..._LP_params.npy``,
    ``..._HP_params.npy`` and ``..._BP_params.npy``. If all four exist they
    are loaded. Otherwise the fits are computed and all four files are
    written.

    For each direction in ``grating_params()``, each cell with p < SIG_THRESH
    whose peak direction is that direction is fitted on its single-sweep
    responses at that direction. Sweeps at contrast 0.6 are excluded, and
    contrast is passed in percent. Each of the three models is fitted with
    ``select_over_initial_conditions``. Rows for cells that are never fitted
    stay NaN.

    Args:
        session_ID: session identifier; converted with str() for filenames.
        savepath: path prefix (including any trailing separator) for inputs
            and cache files.
        do_plot: if True, show a diagnostic plot of the three fits for every
            fitted cell.

    Returns:
        (LP_c50, HP_c50, BP_rise_c50, BP_fall_c50, model_AIC):
          * column 0 of the LP parameters and column 0 of the HP parameters;
          * columns 0 and 1 of the BP parameters;
          * the (num_cells, 3) AIC array, with columns ordered LP, HP, BP.
    """
    prefix = savepath + str(session_ID)
    aic_path = prefix + '_model_AIC.npy'
    lp_path = prefix + '_LP_params.npy'
    hp_path = prefix + '_HP_params.npy'
    bp_path = prefix + '_BP_params.npy'

    # Require every cache file: previously only the AIC file was checked, so a
    # partial cache crashed in np.load instead of being recomputed.
    if all(os.path.isfile(p) for p in (aic_path, lp_path, hp_path, bp_path)):
        model_AIC = np.load(aic_path)
        low_pass_params = np.load(lp_path)
        high_pass_params = np.load(hp_path)
        band_pass_params = np.load(bp_path)
    else:
        directions, __ = grating_params()

        sweep_table = load_sweep_table(savepath, session_ID)
        mean_sweep_events = load_mean_sweep_events(savepath, session_ID)

        _num_sweeps, num_cells = np.shape(mean_sweep_events)

        condition_responses, __ = compute_mean_condition_responses(
            sweep_table, mean_sweep_events)

        p_all = chi_square_all_conditions(sweep_table, mean_sweep_events,
                                          session_ID, savepath)
        sig_cells = p_all < SIG_THRESH

        peak_dir_idx, __ = get_peak_conditions(condition_responses)
        peak_directions = directions[peak_dir_idx]

        # np.nan, not the np.NaN alias (removed in NumPy 2.0), marks cells
        # that are never fitted.
        low_pass_params = np.full((num_cells, 3), np.nan)
        high_pass_params = np.full((num_cells, 3), np.nan)
        band_pass_params = np.full((num_cells, 4), np.nan)
        model_AIC = np.full((num_cells, 3), np.nan)

        # Loop-invariant, and positional: np.argwhere below yields row
        # positions, which must index the contrasts the same way they index
        # mean_sweep_events, whatever the DataFrame's index labels are.
        contrast_values = sweep_table['Contrast'].values
        not_60_contrast = contrast_values != 0.6

        for direction in directions:
            cells_pref_dir = (peak_directions == direction) & sig_cells

            is_direction = (sweep_table['Ori'] == direction).values
            sweeps_with_dir = np.argwhere(is_direction & not_60_contrast)[:, 0]
            sweep_contrasts = 100 * contrast_values[sweeps_with_dir]

            for cell in np.argwhere(cells_pref_dir)[:, 0]:

                sweep_responses = mean_sweep_events[sweeps_with_dir, cell]

                lp_params, lp_aic = select_over_initial_conditions(
                    sweep_responses, sweep_contrasts, 'LP')
                hp_params, hp_aic = select_over_initial_conditions(
                    sweep_responses, sweep_contrasts, 'HP')
                bp_params, bp_aic = select_over_initial_conditions(
                    sweep_responses, sweep_contrasts, 'BP')

                low_pass_params[cell, :] = lp_params
                model_AIC[cell, 0] = lp_aic
                high_pass_params[cell, :] = hp_params
                model_AIC[cell, 1] = hp_aic
                band_pass_params[cell, :] = bp_params
                model_AIC[cell, 2] = bp_aic

                if do_plot:
                    x_sample = np.linspace(np.log(5.0), np.log(80.0))
                    plt.figure()
                    plt.plot(np.log(sweep_contrasts), sweep_responses, 'ko')
                    plt.plot(x_sample,
                             high_pass(x_sample, *high_pass_params[cell]), 'r')
                    plt.plot(x_sample,
                             low_pass(x_sample, *low_pass_params[cell]), 'b')
                    plt.plot(x_sample,
                             band_pass(x_sample, *band_pass_params[cell]), 'g')
                    plt.show()
                    # Release the figure so repeated plots do not accumulate.
                    plt.close()

        np.save(aic_path, model_AIC)
        np.save(lp_path, low_pass_params)
        np.save(hp_path, high_pass_params)
        np.save(bp_path, band_pass_params)

    LP_c50 = low_pass_params[:, 0]
    HP_c50 = high_pass_params[:, 0]
    BP_rise_c50 = band_pass_params[:, 0]
    BP_fall_c50 = band_pass_params[:, 1]

    return LP_c50, HP_c50, BP_rise_c50, BP_fall_c50, model_AIC