def compute_blank_subtracted_NLL(session_ID,savepath,num_shuffles=200000):
    """Score every condition (and the blank) against a shuffle null, per cell.

    For each (direction, contrast) condition and each cell, the observed
    difference "condition mean response - blank-sweep mean response" is
    compared with a null distribution built by resampling sweeps:

    * ``num_shuffles`` times, ``trial_count`` sweeps are drawn with
      replacement from ALL sweeps of the session and averaged.
    * The mean of ``num_blanks`` sweeps, drawn the same way, is subtracted.

    The fraction of null samples strictly below the observed difference is
    converted with ``percentile_to_NLL``. Conditions are grouped by their trial
    count (from ``compute_trials_per_condition``) so each group is compared
    against a null of matching size.

    The blank itself is scored the same way with an observed difference of
    0.0. Its null is a row-resampled copy of the blank null distribution minus
    the original.

    Results are cached as ``savepath + str(session_ID) +
    '_blank_subtracted_NLL.npy'`` and ``..._blank_subtracted_blank_NLL.npy``.
    ``savepath`` is used as a plain string prefix.

    Args:
        session_ID: session identifier (``str``-converted for file names).
        savepath: prefix for loading inputs and caching outputs.
        num_shuffles: number of null-distribution samples.

    Returns:
        condition_NLL: stored as (directions, contrasts, cells) and returned
            with axes reordered to (cells, directions, contrasts).
        blank_NLL: ``percentile_to_NLL`` of a per-cell percentile vector.
    """
    condition_path = savepath+str(session_ID)+'_blank_subtracted_NLL.npy'
    blank_path = savepath+str(session_ID)+'_blank_subtracted_blank_NLL.npy'

    # Require BOTH cache files: if a previous run stopped between the two
    # np.save calls below, recompute instead of failing on the missing file.
    if os.path.isfile(condition_path) and os.path.isfile(blank_path):
        condition_NLL = np.load(condition_path)
        blank_NLL = np.load(blank_path)
    else:
        sweep_table = load_sweep_table(savepath,session_ID)
        mean_sweep_events = load_mean_sweep_events(savepath,session_ID)

        (num_sweeps,num_cells) = np.shape(mean_sweep_events)

        condition_responses, blank_sweep_responses = compute_mean_condition_responses(sweep_table,mean_sweep_events)
        # Reorder axes so the result is (directions, contrasts, cells), the
        # layout the reshape below relies on.
        condition_responses = np.swapaxes(condition_responses,0,2)
        condition_responses = np.swapaxes(condition_responses,0,1)

        directions, contrasts = grating_params()
        num_dirs = len(directions)
        num_cons = len(contrasts)

        # different conditions can have different number of trials...
        trials_per_condition, num_blanks = compute_trials_per_condition(sweep_table)
        unique_trial_counts = np.unique(trials_per_condition.flatten())

        # Per-condition trial counts laid out as (directions, contrasts, cells).
        trial_count_mat = np.tile(trials_per_condition,reps=(num_cells,1,1))
        trial_count_mat = np.swapaxes(trial_count_mat,0,2)
        trial_count_mat = np.swapaxes(trial_count_mat,0,1)

        # Null distribution of the blank mean: num_blanks random sweeps per shuffle.
        blank_shuffle_sweeps = np.random.choice(num_sweeps,size=(num_shuffles*num_blanks,))
        blank_shuffle_responses = mean_sweep_events[blank_shuffle_sweeps].reshape(num_shuffles,num_blanks,num_cells)
        blank_null_dist = blank_shuffle_responses.mean(axis=1)

        # Observed blank-subtracted responses, shape (directions, contrasts,
        # cells). They do not depend on trial count, so compute them once.
        actual_diffs = (condition_responses.reshape(num_dirs,num_cons,num_cells)
                        - blank_sweep_responses.reshape(1,1,num_cells))

        condition_NLL = np.zeros((num_dirs,num_cons,num_cells))
        for trial_count in unique_trial_counts:

            #create null distribution and compute condition NLL
            shuffle_sweeps = np.random.choice(num_sweeps,size=(num_shuffles*trial_count,))
            shuffle_responses = mean_sweep_events[shuffle_sweeps].reshape(num_shuffles,trial_count,num_cells)
            null_diff_dist = shuffle_responses.mean(axis=1) - blank_null_dist

            # Fraction of null samples strictly below the observed difference.
            # This is evaluated one cell at a time. The fully broadcast
            # comparison would allocate
            # num_dirs*num_cons*num_shuffles*num_cells booleans at once; the
            # values are identical either way.
            percentile = np.zeros((num_dirs,num_cons,num_cells))
            for cell in range(num_cells):
                below = (null_diff_dist[:,cell].reshape(1,1,num_shuffles)
                         < actual_diffs[:,:,cell].reshape(num_dirs,num_cons,1))
                percentile[:,:,cell] = below.mean(axis=2)
            NLL = percentile_to_NLL(percentile,num_shuffles)

            # Keep these NLLs only for the conditions with this trial count.
            has_count = trial_count_mat == trial_count
            condition_NLL = np.where(has_count,NLL,condition_NLL)

        #repeat for blank sweeps: a row-resampled copy of the blank null minus
        #the blank null itself, compared against an observed difference of 0
        blank_null_dist_2 = blank_null_dist[np.random.choice(num_shuffles,size=num_shuffles),:]
        blank_null_diff_dist = blank_null_dist_2 - blank_null_dist
        blank_actual_diff = 0.0
        resp_above_null = blank_null_diff_dist < blank_actual_diff
        percentile = resp_above_null.mean(axis=0)
        blank_NLL = percentile_to_NLL(percentile,num_shuffles)

        np.save(condition_path,condition_NLL)
        np.save(blank_path,blank_NLL)

    # (directions, contrasts, cells) -> (cells, directions, contrasts)
    condition_NLL = np.swapaxes(condition_NLL,0,2)
    condition_NLL = np.swapaxes(condition_NLL,1,2)

    return condition_NLL, blank_NLL
def construct_pooled_Xy(session_IDs, savepath, RUN_THRESH=1.0):
    """Build a design matrix and a pooled target vector across sessions.

    For each session:

    * Only cells with chi-square p-value < 0.01 are kept.
    * A running mask ``mean_sweep_running >= RUN_THRESH`` is formed.
    * ``construct_session_Xy`` produces that session's ``X`` and ``y``.

    The session ``y`` arrays are concatenated along axis 1 (presumably the
    cell axis; TODO confirm) and averaged with ``np.nanmean``. Rows without a
    finite mean are dropped, and the result is multiplied by 3000.

    Args:
        session_IDs: non-empty iterable of session identifiers.
        savepath: prefix passed to the data loaders.
        RUN_THRESH: threshold applied to ``mean_sweep_running``.

    Returns:
        ``(X, y)`` restricted to the rows with a finite pooled mean. ``X`` is
        the design matrix of the LAST session; this assumes every session
        yields the same ``X`` (TODO confirm).

    Raises:
        ValueError: if ``session_IDs`` is empty.
    """
    session_ys = []
    for session_ID in session_IDs:

        sweep_table = load_sweep_table(savepath, session_ID)
        mean_sweep_events = load_mean_sweep_events(savepath, session_ID)

        # Keep only cells with a significant chi-square result.
        pvals = chi_square_all_conditions(sweep_table, mean_sweep_events,
                                          session_ID, savepath)
        sig_cells = pvals < 0.01
        mean_sweep_events = mean_sweep_events[:, sig_cells]

        mean_sweep_running = load_mean_sweep_running(session_ID, savepath)
        is_run = mean_sweep_running >= RUN_THRESH

        # X is overwritten each iteration; see the docstring note.
        X, session_y = construct_session_Xy(sweep_table, is_run,
                                            mean_sweep_events)
        session_ys.append(session_y)

    if not session_ys:
        raise ValueError('construct_pooled_Xy requires at least one session ID')

    # Concatenate once instead of growing an array with np.append each
    # iteration, then average over axis 1 while ignoring NaNs.
    y = np.nanmean(np.concatenate(session_ys, axis=1), axis=1)

    # Rows with no finite value in any session are dropped.
    conditions_sampled = np.flatnonzero(np.isfinite(y))

    # Fixed scale factor kept from the original code (its unit meaning is not
    # visible here).
    y *= 3000

    return X[conditions_sampled], y[conditions_sampled]
# Example no. 3
def plot_tuning_split_by_run_state(df, savepath):
    """Plot running vs. stationary tuning curves for each area / Cre line.

    An (area, cre) pair is processed only if it has at least MIN_SESSIONS
    sessions. For each of its sessions:

    * Cells with chi-square p-value below SIG_THRESH are selected.
    * Sweeps are split into running / stationary by mean running speed
      (``>= 1.0``, noted as cm/s).
    * If at least MIN_CELLS cells are selected, their direction-centred
      running and stationary responses are added to a curve dictionary under
      'all_run' and 'all_stat'. Responses are aligned to each cell's peak
      direction, computed from all sweeps.

    The pair is plotted with ``plot_from_curve_dict`` only when at least
    MIN_SESSIONS sessions contributed.
    """
    speed_cutoff = 1.0  # cm/s
    directions, contrasts = grating_params()

    MIN_SESSIONS = 3
    MIN_CELLS = 3  # per session

    areas, cres = dataset_params()
    for area in areas:
        for cre in cres:

            sessions = get_sessions(df, area, cre)
            if len(sessions) < MIN_SESSIONS:
                continue

            curves = {}
            n_used = 0
            for sid in sessions:

                table = load_sweep_table(savepath, sid)
                events = load_mean_sweep_events(savepath, sid)
                mean_resp, blank_resp = compute_mean_condition_responses(
                    table, events)

                p_vals = chi_square_all_conditions(table, events, sid,
                                                   savepath)
                sig_idx = np.argwhere(p_vals < SIG_THRESH)[:, 0]

                speed = load_mean_sweep_running(sid, savepath)
                running_mask = speed >= speed_cutoff

                (run_resp, stat_resp, run_blank,
                 stat_blank) = condition_response_running(
                     table, events, running_mask)

                mean_resp = center_direction_zero(mean_resp)
                run_resp = center_direction_zero(run_resp)
                stat_resp = center_direction_zero(stat_resp)

                peak_dir, __ = get_peak_conditions(mean_resp)

                if len(sig_idx) < MIN_CELLS:
                    continue

                curves = populate_curve_dict(curves, run_resp, run_blank,
                                             sig_idx, 'all_run', peak_dir)
                curves = populate_curve_dict(curves, stat_resp, stat_blank,
                                             sig_idx, 'all_stat', peak_dir)
                n_used += 1

            if n_used >= MIN_SESSIONS:
                plot_from_curve_dict(curves, 'all', area, cre, n_used,
                                     savepath)
def pool_sessions(session_IDs,pool_name,savepath,scale='blank_subtracted_NLL'):
    """Concatenate per-cell condition responses and p-values across sessions.

    Results are cached as ``savepath + pool_name + '_condition_responses_' +
    scale + '.npy'`` and ``..._blank_responses_<scale>.npy``, plus
    ``savepath + pool_name + '_chisq_all.npy``. The last file has no scale in
    its name because the p-values are computed from the sweep events
    regardless of ``scale``. ``savepath`` is used as a plain string prefix.

    Args:
        session_IDs: iterable of session identifiers, pooled in order.
        pool_name: label used in the cache file names.
        savepath: prefix for loading inputs and caching outputs.
        scale: 'event' to pool ``compute_mean_condition_responses`` output,
            or 'blank_subtracted_NLL' to pool ``compute_blank_subtracted_NLL``
            output.

    Returns:
        pooled_condition_responses: (total_cells, num_directions, num_contrasts)
        pooled_blank_responses: (total_cells,)
        p_vals_all: (total_cells,)

    Raises:
        ValueError: when results must be computed and ``scale`` is not one of
            the supported values.
    """
    condition_path = savepath+pool_name+'_condition_responses_'+scale+'.npy'
    blank_path = savepath+pool_name+'_blank_responses_'+scale+'.npy'
    pvals_path = savepath+pool_name+'_chisq_all.npy'

    # Use the cache only if every file it needs is present.
    if all(os.path.isfile(path) for path in (condition_path,blank_path,pvals_path)):

        pooled_condition_responses = np.load(condition_path)
        pooled_blank_responses = np.load(blank_path)
        p_vals_all = np.load(pvals_path)

    else:

        # Fail fast instead of hitting an unbound local inside the loop.
        if scale not in ('event','blank_subtracted_NLL'):
            raise ValueError("Unsupported scale " + repr(scale)
                             + "; expected 'event' or 'blank_subtracted_NLL'")

        print('Pooling sessions for ' + pool_name)

        directions, contrasts = grating_params()
        response_shape = (len(directions),len(contrasts))

        # Per-session pieces, concatenated once at the end, so there is no cap
        # on the total cell count. The zero-length seeds fix the output shapes
        # and float64 dtype even when session_IDs is empty.
        condition_chunks = [np.zeros((0,)+response_shape)]
        blank_chunks = [np.zeros((0,))]
        pval_chunks = [np.zeros((0,))]
        for session_ID in session_IDs:

            print(str(session_ID))

            mse = load_mean_sweep_events(savepath,session_ID)
            sweep_table = load_sweep_table(savepath,session_ID)

            if scale=='event':
                condition_responses, blank_responses = compute_mean_condition_responses(sweep_table,mse)
            else:
                condition_responses, blank_responses = compute_blank_subtracted_NLL(session_ID,savepath)

            p_all = chi_square_all_conditions(sweep_table,mse,session_ID,savepath)
            session_cells = len(p_all)

            # broadcast_to applies the same broadcasting as assigning into a
            # (session_cells, ...) slice of a preallocated buffer.
            condition_chunks.append(np.broadcast_to(
                np.asarray(condition_responses,dtype=float),(session_cells,)+response_shape))
            blank_chunks.append(np.broadcast_to(
                np.asarray(blank_responses,dtype=float),(session_cells,)))
            pval_chunks.append(np.broadcast_to(
                np.asarray(p_all,dtype=float),(session_cells,)))

        pooled_condition_responses = np.concatenate(condition_chunks)
        pooled_blank_responses = np.concatenate(blank_chunks)
        p_vals_all = np.concatenate(pval_chunks)

        np.save(condition_path,pooled_condition_responses)
        np.save(blank_path,pooled_blank_responses)
        np.save(pvals_path,p_vals_all)

        print('Done.')

    return pooled_condition_responses, pooled_blank_responses, p_vals_all
# Example no. 5
def plot_single_cell_example(df,
                             savepath,
                             cre,
                             example_cell,
                             example_session_idx=0):
    """Save a heatmap + tuning-curve figure for one significant VISp cell.

    Args:
        df: table passed to ``get_sessions`` to find VISp sessions for ``cre``.
        savepath: prefix for loading session data and for the output file.
        cre: Cre line; also used (through ``shorthand``) in the file name.
        example_cell: index into the cells whose chi-square p-value is below
            SIG_THRESH, not into all cells of the session.
        example_session_idx: which of the returned sessions to use.

    The figure has three panels, all using mean responses:

    * left: the full direction x contrast heatmap;
    * top right: contrast tuning at the cell's peak direction, with SEM
      error bars;
    * bottom right: direction tuning at its peak contrast, with SEM error
      bars.

    It is written to ``savepath + shorthand(cre) + '_example_cell.svg'``.
    """
    __, contrasts = grating_params()

    session_ID = get_sessions(df, 'VISp', cre)[example_session_idx]

    mse = load_mean_sweep_events(savepath, session_ID)
    sweep_table = load_sweep_table(savepath, session_ID)

    mean_resp, __ = compute_mean_condition_responses(sweep_table, mse)
    sem_resp, __ = compute_SEM_condition_responses(sweep_table, mse)
    p_all = chi_square_all_conditions(sweep_table, mse, session_ID, savepath)

    # Keep significant cells, then re-centre the direction axis; the tick
    # labels below assume the centred order runs -135 ... 180 degrees.
    is_sig = p_all < SIG_THRESH
    sig_resp = center_direction_zero(mean_resp[is_sig])
    sig_SEM = center_direction_zero(sem_resp[is_sig])
    centred_dirs = [-135, -90, -45, 0, 45, 90, 135, 180]

    contrast_labels = [str(int(100 * c)) for c in contrasts]
    direction_labels = [str(d) for d in centred_dirs]
    cell_resp = sig_resp[example_cell]
    cell_SEM = sig_SEM[example_cell]

    def hide_top_right(axis):
        # Remove the right and top spines of a line-plot panel.
        axis.spines['right'].set_visible(False)
        axis.spines['top'].set_visible(False)

    plt.figure(figsize=(6, 4))

    # Left panel: direction x contrast heatmap.
    heat_ax = plt.subplot2grid((5, 5), (0, 0), rowspan=5, colspan=2)
    heat_ax.imshow(cell_resp,
                   vmin=0.0,
                   interpolation='nearest',
                   aspect='auto',
                   cmap='plasma')
    heat_ax.set_ylabel('Direction (deg)', fontsize=14)
    heat_ax.set_xlabel('Contrast (%)', fontsize=14)
    heat_ax.set_xticks(np.arange(len(contrasts)))
    heat_ax.set_xticklabels(contrast_labels, fontsize=10)
    heat_ax.set_yticks(np.arange(len(centred_dirs)))
    heat_ax.set_yticklabels(direction_labels, fontsize=10)

    peak_dir_idx, peak_con_idx = get_peak_conditions(sig_resp)
    best_dir = peak_dir_idx[example_cell]
    best_con = peak_con_idx[example_cell]

    # Top-right panel: contrast tuning at the peak direction, log-contrast axis.
    log_contrasts = np.log(contrasts)
    con_ax = plt.subplot2grid((5, 5), (0, 3), rowspan=2, colspan=2)
    con_ax.errorbar(log_contrasts, cell_resp[best_dir, :],
                    cell_SEM[best_dir, :])
    con_ax.set_xticks(log_contrasts)
    con_ax.set_xticklabels(contrast_labels, fontsize=10)
    con_ax.set_xlabel('Contrast (%)', fontsize=14)
    con_ax.set_ylabel('Response', fontsize=14)
    hide_top_right(con_ax)

    # Bottom-right panel: direction tuning at the peak contrast.
    dir_ax = plt.subplot2grid((5, 5), (3, 3), rowspan=2, colspan=2)
    dir_ax.errorbar(np.arange(len(centred_dirs)), cell_resp[:, best_con],
                    cell_SEM[:, best_con])
    dir_ax.set_xlim(-0.07, 7.07)
    dir_ax.set_xticks(np.arange(len(centred_dirs)))
    dir_ax.set_xticklabels(direction_labels, fontsize=10)
    dir_ax.set_xlabel('Direction (deg)', fontsize=14)
    dir_ax.set_ylabel('Response', fontsize=14)
    hide_top_right(dir_ax)

    plt.tight_layout(w_pad=-5.5, h_pad=0.1)
    plt.savefig(savepath + shorthand(cre) + '_example_cell.svg', format='svg')
    plt.close()
def plot_single_cell_tuning_curves(session_ID, savepath, cre, example_cell,
                                   plot_path, figure_format):
    """Save a heatmap plus contrast- and direction-tuning curves for one cell.

    Responses are mean sweep events multiplied by 3000 and labelled "Event
    magnitude per second (%)". ``example_cell`` indexes all cells of the
    session. The figure has three panels:

    * top left: contrast tuning at the cell's peak direction;
    * bottom left: direction tuning at the cell's peak contrast;
    * right: the full direction x contrast heatmap with a colorbar.

    Both tuning panels show SEM error bars and the blank-sweep mean as a
    dashed line, and share one y-range.

    Args:
        session_ID: session identifier.
        savepath: prefix passed to the loaders.
        cre: Cre line, used in the output file name.
        example_cell: index of the cell to plot.
        plot_path: prefix for the output file.
        figure_format: 'svg' for an SVG file; anything else writes a
            300-dpi PNG.
    """
    directions, contrasts = cu.grating_params()

    mse = 3000.0 * cu.load_mean_sweep_events(savepath, session_ID)
    sweep_table = cu.load_sweep_table(savepath, session_ID)

    condition_responses, blank_responses = cm.compute_mean_condition_responses(
        sweep_table, mse)
    condition_SEM, __ = cm.compute_SEM_condition_responses(sweep_table, mse)

    #shift zero to center:
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    condition_resp = cu.center_direction_zero(condition_responses)
    condition_SEM = cu.center_direction_zero(condition_SEM)

    contrast_labels = [str(int(100 * x)) for x in contrasts]
    direction_labels = [str(x) for x in directions]

    #full direction by contrast response heatmap
    plt.figure(figsize=(7, 4))
    ax = plt.subplot2grid((5, 5), (0, 3), rowspan=5, colspan=2)
    im = ax.imshow(condition_resp[example_cell],
                   vmin=0.0,
                   interpolation='nearest',
                   aspect='auto',
                   cmap='plasma')
    ax.set_ylabel('Direction (deg)', fontsize=12)
    ax.set_xlabel('Contrast (%)', fontsize=12)
    ax.set_xticks(np.arange(len(contrasts)))
    ax.set_xticklabels(contrast_labels, fontsize=12)
    ax.set_yticks(np.arange(len(directions)))
    ax.set_yticklabels(direction_labels, fontsize=12)
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Event magnitude per second (%)', fontsize=12)

    peak_dir_idx, peak_con_idx = cm.get_peak_conditions(condition_resp)
    peak_dir = peak_dir_idx[example_cell]
    peak_con = peak_con_idx[example_cell]
    blank_level = blank_responses[example_cell]

    contrast_means = condition_resp[example_cell, peak_dir, :]
    contrast_SEMs = condition_SEM[example_cell, peak_dir, :]
    direction_means = condition_resp[example_cell, :, peak_con]
    direction_SEMs = condition_SEM[example_cell, :, peak_con]

    # One y-range shared by both tuning panels. It covers every error bar in
    # either panel and the blank line, so nothing plotted is clipped. If no
    # positive finite maximum exists, leave the top to matplotlib.
    y_max = 1.1 * np.nanmax(np.concatenate([
        np.ravel(contrast_means + contrast_SEMs),
        np.ravel(direction_means + direction_SEMs),
        np.ravel(blank_level),
    ]))
    if not (np.isfinite(y_max) and y_max > 0):
        y_max = None

    #contrast tuning at peak direction
    ax = plt.subplot2grid((5, 5), (0, 0), rowspan=2, colspan=2)
    ax.errorbar(np.log(contrasts),
                contrast_means,
                contrast_SEMs,
                linewidth=0.7,
                color='b')
    ax.plot([np.log(contrasts[0]), np.log(contrasts[-1])],
            [blank_level, blank_level],
            linewidth=0.7,
            linestyle='--',
            color='b')
    ax.set_xticks(np.log(contrasts))
    ax.set_xticklabels(contrast_labels, fontsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.set_xlabel('Contrast (%)', fontsize=12)
    ax.set_ylabel('Event magnitude per second (%)  ', fontsize=12)
    ax.set_ylim([0, y_max])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_title('@ ' + str(directions[peak_dir]) + ' degrees', fontsize=12)

    #direction tuning at peak contrast
    ax = plt.subplot2grid((5, 5), (3, 0), rowspan=2, colspan=2)
    ax.errorbar(np.arange(len(directions)),
                direction_means,
                direction_SEMs,
                linewidth=0.7,
                color='b')
    ax.plot([0, len(directions) - 1],
            [blank_level, blank_level],
            linestyle='--',
            color='b',
            linewidth=0.7)
    # Small margin past the first and last direction index.
    ax.set_xlim(-0.07, len(directions) - 1 + 0.07)
    ax.set_xticks(np.arange(len(directions)))
    ax.set_xticklabels(direction_labels, fontsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.set_xlabel('Direction (deg)', fontsize=12)
    ax.set_ylim([0, y_max])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_title('@ ' + str(int(100 * contrasts[peak_con])) + '% contrast',
                 fontsize=12)

    plt.tight_layout(w_pad=-5.5, h_pad=0.1)

    file_stem = (plot_path + cre + '_' + str(session_ID) + '_cell_' +
                 str(example_cell) + '_tuning_curves')
    if figure_format == 'svg':
        plt.savefig(file_stem + '.svg', format='svg')
    else:
        plt.savefig(file_stem + '.png', dpi=300)
    plt.close()
# Example no. 7
def LP_HP_model_selection(session_ID,savepath,do_plot=False):
    """Fit low-pass, high-pass and band-pass contrast models for each cell.

    Only cells with chi-square p-value below SIG_THRESH are fitted. For each
    such cell:

    * The single-sweep responses at the cell's peak direction (from
      ``get_peak_conditions``) are collected, excluding sweeps with contrast
      0.6.
    * Each model is fitted against contrast in percent using
      ``select_over_initial_conditions``.

    Cells that are not fitted keep NaN parameters and AICs. All outputs are
    cached as ``savepath + str(session_ID) + '_model_AIC.npy'``,
    ``'_LP_params.npy'``, ``'_HP_params.npy'`` and ``'_BP_params.npy'``.

    Args:
        session_ID: session identifier (``str``-converted for file names).
        savepath: prefix for loading inputs and caching outputs.
        do_plot: if True, draw the data and the three fitted curves for every
            fitted cell and call ``plt.show()``.

    Returns:
        LP_c50: column 0 of the low-pass parameters, one value per cell.
        HP_c50: column 0 of the high-pass parameters.
        BP_rise_c50: column 0 of the band-pass parameters.
        BP_fall_c50: column 1 of the band-pass parameters.
        model_AIC: (num_cells, 3) AICs in LP, HP, BP column order.
    """
    prefix = savepath+str(session_ID)
    AIC_path = prefix+'_model_AIC.npy'
    LP_path = prefix+'_LP_params.npy'
    HP_path = prefix+'_HP_params.npy'
    BP_path = prefix+'_BP_params.npy'

    # Use the cache only if all four files are present.
    if all(os.path.isfile(path) for path in (AIC_path,LP_path,HP_path,BP_path)):
        model_AIC = np.load(AIC_path)
        low_pass_params = np.load(LP_path)
        high_pass_params = np.load(HP_path)
        band_pass_params = np.load(BP_path)
    else:

        directions, __ = grating_params()

        sweep_table = load_sweep_table(savepath,session_ID)
        mean_sweep_events = load_mean_sweep_events(savepath,session_ID)

        num_cells = np.shape(mean_sweep_events)[1]

        condition_responses, __ = compute_mean_condition_responses(sweep_table,mean_sweep_events)

        p_all = chi_square_all_conditions(sweep_table,mean_sweep_events,session_ID,savepath)
        sig_cells = p_all < SIG_THRESH

        peak_dir_idx, __ = get_peak_conditions(condition_responses)
        # np.asarray so the fancy indexing also works if directions is a list.
        peak_directions = np.asarray(directions)[peak_dir_idx]

        # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0.
        high_pass_params = np.full((num_cells,3),np.nan)
        low_pass_params = np.full((num_cells,3),np.nan)
        band_pass_params = np.full((num_cells,4),np.nan)
        model_AIC = np.full((num_cells,3),np.nan)

        # Positional arrays, so sweep indices line up with the rows of
        # mean_sweep_events regardless of the sweep_table's index labels.
        sweep_oris = sweep_table['Ori'].values
        sweep_contrast_all = sweep_table['Contrast'].values
        # Sweeps at contrast 0.6 are excluded from the fits (loop-invariant).
        not_60_contrast = sweep_contrast_all != 0.6

        for direction in directions:
            cells_pref_dir = (peak_directions == direction) & sig_cells

            sweeps_with_dir = np.flatnonzero((sweep_oris == direction) & not_60_contrast)
            # Contrast in percent, the x-variable for the fits.
            sweep_contrasts = 100 * sweep_contrast_all[sweeps_with_dir]

            for cell in np.flatnonzero(cells_pref_dir):

                sweep_responses = mean_sweep_events[sweeps_with_dir,cell]

                lp_params, lp_aic = select_over_initial_conditions(sweep_responses,sweep_contrasts,'LP')
                hp_params, hp_aic = select_over_initial_conditions(sweep_responses,sweep_contrasts,'HP')
                bp_params, bp_aic = select_over_initial_conditions(sweep_responses,sweep_contrasts,'BP')

                low_pass_params[cell,:] = lp_params
                model_AIC[cell,0] = lp_aic
                high_pass_params[cell,:] = hp_params
                model_AIC[cell,1] = hp_aic
                band_pass_params[cell,:] = bp_params
                model_AIC[cell,2] = bp_aic

                if do_plot:

                    x_sample = np.linspace(np.log(5.0),np.log(80.0))
                    plt.figure()
                    plt.plot(np.log(sweep_contrasts),sweep_responses,'ko')
                    plt.plot(x_sample,high_pass(x_sample,*high_pass_params[cell]),'r')
                    plt.plot(x_sample,low_pass(x_sample,*low_pass_params[cell]),'b')
                    plt.plot(x_sample,band_pass(x_sample,*band_pass_params[cell]),'g')
                    plt.show()

        np.save(AIC_path,model_AIC)
        np.save(LP_path,low_pass_params)
        np.save(HP_path,high_pass_params)
        np.save(BP_path,band_pass_params)

    LP_c50 = low_pass_params[:,0]
    HP_c50 = high_pass_params[:,0]
    BP_rise_c50 = band_pass_params[:,0]
    BP_fall_c50 = band_pass_params[:,1]

    return LP_c50, HP_c50, BP_rise_c50, BP_fall_c50, model_AIC