예제 #1
0
def make_direction_legend(savepath):
    """Save an SVG legend showing one direction arrow per (contrast, direction) column.

    An all-zero square image with len(directions) * len(contrasts) rows and
    columns (sizes taken from grating_params()) is drawn first. Then, for every
    contrast, a run of arrows is drawn along the image's vertical midline, one
    arrow per column. Arrow angles come from a fixed list of eight angles, which
    replaces the directions returned by grating_params() for drawing.

    Args:
        savepath: string prefix; the figure is written to
            savepath + 'arrow_legend.svg'.
    """
    directions, contrasts = grating_params()
    num_conditions = len(directions) * len(contrasts)

    # Fixed drawing order of arrow angles (degrees) within each contrast block.
    directions = [225, 270, 315, 0, 45, 90, 135, 180]
    arrow_length = 0.4
    blank_image = np.zeros((num_conditions, num_conditions))

    plt.figure(figsize=(10, 10))
    ax = plt.subplot(111)
    ax.imshow(blank_image,
              vmin=-1.0,
              vmax=1.0,
              interpolation='none',
              aspect='auto',
              cmap='RdBu_r')

    y_center = np.shape(blank_image)[0] / 2
    block_width = len(directions)
    for i_con in range(len(contrasts)):
        for i_dir, angle in enumerate(directions):
            x_center = i_dir + i_con * block_width

            dx = arrow_length * np.cos(angle * np.pi / 180.)
            dy = arrow_length * np.sin(angle * np.pi / 180.)

            # Start half an arrow behind the centre so each arrow is centred on its column.
            ax.arrow(x_center - dx / 2.0, y_center - dy / 2.0, dx, dy,
                     head_width=0.2)

    plt.savefig(savepath + 'arrow_legend.svg', format='svg')
    plt.close()
def condition_vars(direction, contrast, run_state):
    """Return the design-matrix row (as a numpy array) describing one stimulus condition.

    A ``contrast`` of None marks a blank sweep. The entries, in order, are:
      * blank indicator (``contrast is None``),
      * one indicator per direction from grating_params(),
      * one indicator per contrast,
      * one indicator per (direction, contrast) pair, direction-major,
      * the same four groups again, each combined with ``run_state`` via ``and``,
      * ``run_state`` itself,
      * a constant True.
    """
    directions, contrasts = grating_params()

    stationary_terms = [contrast is None]
    stationary_terms += [direction == d for d in directions]
    stationary_terms += [contrast == c for c in contrasts]
    stationary_terms += [
        direction == d and contrast == c for d in directions for c in contrasts
    ]

    # Same terms gated by running; ``(a and b) and run`` equals ``a and b and run``.
    running_terms = [term and run_state for term in stationary_terms]

    return np.array(stationary_terms + running_terms + [run_state, True])
def compute_blank_subtracted_NLL(session_ID,savepath,num_shuffles=200000):
    """Score each condition's blank-subtracted mean response against a resampling null.

    For every cell and (direction, contrast) condition, the observed difference
    (condition mean response - blank mean response) is compared with a null
    distribution of differences. Each null difference is (mean of ``trial_count``
    sweeps drawn with replacement from all sweeps) minus (mean of ``num_blanks``
    sweeps drawn the same way). Here ``trial_count`` is that condition's own number
    of trials. The fraction of null differences below the observed one is passed
    to percentile_to_NLL.

    Blank sweeps get the analogous score for an observed difference of 0. Their
    null is the difference between a row-resampled copy of the blank null
    distribution and the blank null distribution itself.

    Results are cached as .npy files and reloaded on later calls.

    Args:
        session_ID: session identifier; its str() is used in the cache filenames.
        savepath: string prefix for all loaded/saved files.
        num_shuffles: number of null draws.

    Returns:
        condition_NLL: array of shape (num_cells, num_directions, num_contrasts).
        blank_NLL: whatever percentile_to_NLL returns for the per-cell blank
            percentiles (one value per cell).
    """
    condition_path = savepath + str(session_ID) + '_blank_subtracted_NLL.npy'
    blank_path = savepath + str(session_ID) + '_blank_subtracted_blank_NLL.npy'

    if os.path.isfile(condition_path):
        # Cached array is stored in (directions, contrasts, cells) order.
        condition_NLL = np.load(condition_path)
        blank_NLL = np.load(blank_path)
    else:
        sweep_table = load_sweep_table(savepath, session_ID)
        mean_sweep_events = load_mean_sweep_events(savepath, session_ID)

        (num_sweeps, num_cells) = np.shape(mean_sweep_events)

        condition_responses, blank_sweep_responses = compute_mean_condition_responses(
            sweep_table, mean_sweep_events)
        # Reorder to (directions, contrasts, cells); assumes the helper returns
        # (cells, directions, contrasts) -- TODO confirm.
        condition_responses = np.swapaxes(condition_responses, 0, 2)
        condition_responses = np.swapaxes(condition_responses, 0, 1)

        directions, contrasts = grating_params()
        num_dir = len(directions)
        num_con = len(contrasts)

        # Different conditions can have different numbers of trials, so a null
        # distribution is built separately for each distinct trial count.
        trials_per_condition, num_blanks = compute_trials_per_condition(sweep_table)
        unique_trial_counts = np.unique(trials_per_condition.flatten())

        # Null distribution of blank means: num_shuffles means of num_blanks random sweeps.
        blank_shuffle_sweeps = np.random.choice(num_sweeps, size=(num_shuffles * num_blanks,))
        blank_shuffle_responses = mean_sweep_events[blank_shuffle_sweeps].reshape(
            num_shuffles, num_blanks, num_cells)
        blank_null_dist = blank_shuffle_responses.mean(axis=1)

        # Observed condition-minus-blank differences, shape (dir, con, 1, cells).
        # They do not depend on the trial count, so they are computed once.
        actual_diffs = (condition_responses.reshape(num_dir, num_con, 1, num_cells)
                        - blank_sweep_responses.reshape(1, 1, 1, num_cells))

        condition_NLL = np.zeros((num_dir, num_con, num_cells))
        for trial_count in unique_trial_counts:

            # Create the null distribution for this trial count and score conditions.
            shuffle_sweeps = np.random.choice(num_sweeps, size=(num_shuffles * trial_count,))
            shuffle_responses = mean_sweep_events[shuffle_sweeps].reshape(
                num_shuffles, trial_count, num_cells)

            null_diff_dist = shuffle_responses.mean(axis=1) - blank_null_dist
            resp_above_null = null_diff_dist.reshape(1, 1, num_shuffles, num_cells) < actual_diffs
            percentile = resp_above_null.mean(axis=2)
            NLL = percentile_to_NLL(percentile, num_shuffles)

            # Only conditions with exactly this many trials take scores from this null.
            has_count = (trials_per_condition == trial_count)[:, :, np.newaxis]
            condition_NLL = np.where(has_count, NLL, condition_NLL)

        # Repeat for blank sweeps, whose observed blank-minus-blank difference is 0.
        blank_null_dist_2 = blank_null_dist[np.random.choice(num_shuffles, size=num_shuffles), :]
        blank_null_diff_dist = blank_null_dist_2 - blank_null_dist
        blank_percentile = (blank_null_diff_dist < 0.0).mean(axis=0)
        blank_NLL = percentile_to_NLL(blank_percentile, num_shuffles)

        np.save(condition_path, condition_NLL)
        np.save(blank_path, blank_NLL)

    # (directions, contrasts, cells) -> (cells, directions, contrasts)
    condition_NLL = np.swapaxes(condition_NLL, 0, 2)
    condition_NLL = np.swapaxes(condition_NLL, 1, 2)

    return condition_NLL, blank_NLL
예제 #4
0
def calc_pref_direction_dist_by_contrast(condition_responses):
    """Count cells by preferred direction, separately for each contrast.

    ``condition_responses`` is indexed (cell, direction, contrast). At each
    contrast, a cell with a unique maximum adds 1 to its argmax direction. A cell
    whose maximum is shared by k directions adds 1/k to each of them. Cells
    with no entry equal to their maximum (e.g. all-NaN rows) are not counted.

    Returns:
        Array of shape (num_directions, num_contrasts) of (fractional) cell counts.
    """
    directions, contrasts = grating_params()
    num_dirs = len(directions)

    pref_dir_mat = np.zeros((num_dirs, len(contrasts)))
    for i_con in range(len(contrasts)):
        con_resps = condition_responses[:, :, i_con]
        peak = con_resps.max(axis=1)
        at_peak = con_resps == peak[:, np.newaxis]
        num_at_peak = at_peak.sum(axis=1)

        # Tied peaks: split the cell evenly across all tied directions.
        for cell in np.flatnonzero(num_at_peak > 1):
            tied_dirs = np.flatnonzero(at_peak[cell])
            pref_dir_mat[tied_dirs, i_con] += 1.0 / len(tied_dirs)

        # Unique peaks: one whole count at the argmax direction.
        single_peak_dirs = con_resps[num_at_peak == 1].argmax(axis=1)
        for i_dir in range(num_dirs):
            pref_dir_mat[i_dir, i_con] += np.sum(single_peak_dirs == i_dir)

    return pref_dir_mat
예제 #5
0
def make_radial_plot_legend(savepath,
                            legend_savename='contrast_vector_legend.svg'):
    """Save a small SVG legend of coloured arrows, one per contrast.

    Contrasts from grating_params() are listed in reverse order, one row per
    contrast, spaced 1/6 of the axis height apart starting at y = 0.07. Each row
    has a horizontal arrow and a "<N>% contrast" label.

    Args:
        savepath: string prefix for the output file.
        legend_savename: filename appended to savepath.
    """
    directions, contrasts = grating_params()
    contrast_colors = get_contrast_colors()

    plt.figure(figsize=(2, 2))
    ax = plt.subplot(111)
    # NOTE(review): contrasts are walked in reverse, but colours are indexed by
    # the row number, so the label for contrasts[-1] is paired with
    # contrast_colors[0]. Confirm this matches how the described plots colour
    # each contrast.
    for row, contrast in enumerate(contrasts[::-1]):
        row_y = 0.07 + row / 6.0
        ax.arrow(0, row_y, 0.2, 0,
                 color=contrast_colors[row],
                 linewidth=2.0)
        label = str(int(100 * contrast)) + '% contrast'
        ax.text(0.3, row_y, label,
                fontsize=10,
                verticalalignment='center',
                horizontalalignment='left')
    plt.axis('off')
    plt.savefig(savepath + legend_savename, format='svg')
    plt.close()
예제 #6
0
def plot_tuning_split_by_run_state(df, savepath):
    """Plot pooled tuning curves for running vs. stationary sweeps, per area and cre line.

    Groups with fewer than MIN_SESSIONS sessions are skipped. Within a group, a
    session contributes only if it has at least MIN_CELLS cells with a
    chi-square p-value below SIG_THRESH. Sweeps are classed as running when the
    mean running value is >= 1.0 (cm/s). A group is plotted only if at least
    MIN_SESSIONS sessions contributed.
    """
    running_threshold = 1.0  # cm/s
    directions, contrasts = grating_params()

    MIN_SESSIONS = 3
    MIN_CELLS = 3  # per session

    areas, cres = dataset_params()
    for area in areas:
        for cre in cres:

            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) < MIN_SESSIONS:
                continue

            curve_dict = {}
            num_sessions_included = 0
            for session_ID in session_IDs:

                sweep_table = load_sweep_table(savepath, session_ID)
                mse = load_mean_sweep_events(savepath, session_ID)
                condition_responses, blank_responses = compute_mean_condition_responses(
                    sweep_table, mse)

                p_all = chi_square_all_conditions(sweep_table, mse,
                                                  session_ID, savepath)
                sig_cells = np.argwhere(p_all < SIG_THRESH)[:, 0]

                running_speed = load_mean_sweep_running(session_ID, savepath)
                is_run = running_speed >= running_threshold

                run_responses, stat_responses, run_blank, stat_blank = condition_response_running(
                    sweep_table, mse, is_run)

                condition_responses = center_direction_zero(condition_responses)
                run_responses = center_direction_zero(run_responses)
                stat_responses = center_direction_zero(stat_responses)

                peak_dir, __ = get_peak_conditions(condition_responses)

                if len(sig_cells) < MIN_CELLS:
                    continue

                curve_dict = populate_curve_dict(curve_dict, run_responses,
                                                 run_blank, sig_cells,
                                                 'all_run', peak_dir)
                curve_dict = populate_curve_dict(curve_dict, stat_responses,
                                                 stat_blank, sig_cells,
                                                 'all_stat', peak_dir)
                num_sessions_included += 1

            if num_sessions_included >= MIN_SESSIONS:
                plot_from_curve_dict(curve_dict, 'all', area, cre,
                                     num_sessions_included, savepath)
예제 #7
0
def get_contrast_colors():
    """Return one RGBA colour per contrast, sampled from the 'plasma' colormap.

    Colour i is cmap(i / num_contrasts) for i in range(num_contrasts), where
    num_contrasts = len(contrasts) from grating_params(). Samples therefore start
    at 0.0 and stay below 1.0.

    Returns:
        List of RGBA tuples, in the same positional order as the contrasts.
    """
    directions, contrasts = grating_params()

    try:
        # matplotlib >= 3.5; matplotlib.cm.get_cmap was removed in 3.9.
        cmap = matplotlib.colormaps['plasma']
    except AttributeError:
        cmap = matplotlib.cm.get_cmap('plasma')

    num_contrasts = float(len(contrasts))
    return [cmap(i / num_contrasts) for i in range(len(contrasts))]
예제 #8
0
def plot_SbC_stats(df, savepath):
    """Bar plot of the percentage of suppressed-by-contrast (SbC) cells per group.

    For every (area, cre) pair with at least one session, the p-values from
    test_SbC are pooled across sessions. The bar height is
    100 * (number with p < SbC_THRESH) / (number of cells). The pooled cell count
    is written above each bar. A dashed reference line is drawn at
    100 * SbC_THRESH %. Output: savepath + 'SbC_stats.svg'.
    """
    SbC_THRESH = 0.05

    cre_colors = get_cre_colors()

    areas, cres = dataset_params()
    percent_SbC = []
    labels = []
    colors = []
    sample_size = []
    for area in areas:
        for cre in cres:

            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            num_cells = 0
            num_SbC = 0
            for session_ID in session_IDs:
                SbC_pval = test_SbC(session_ID, savepath)
                num_cells += len(SbC_pval)
                num_SbC += (SbC_pval < SbC_THRESH).sum()

            labels.append(shorthand(cre))
            colors.append(cre_colors[cre])
            percent_SbC.append(100.0 * num_SbC / num_cells)
            sample_size.append(num_cells)

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    for x, (percent, color, n_cells) in enumerate(zip(percent_SbC, colors, sample_size)):
        ax.bar(x, percent, color=color)
        # Keep the count label readable above very short bars.
        ax.text(x,
                max(percent, 5) + 1,
                '(' + str(n_cells) + ')',
                horizontalalignment='center',
                fontsize=8)
    ax.plot([-1, len(labels)], [100 * SbC_THRESH, 100 * SbC_THRESH],
            '--k',
            linewidth=2.0)
    ax.set_ylim(0, 30)
    # Keep the original 14-unit width, but widen it so no group's bar is clipped.
    ax.set_xlim(-1, max(14, len(labels)))
    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, fontsize=10, rotation=45)
    ax.set_ylabel('% Suppressed by Contrast', fontsize=14)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.savefig(savepath + 'SbC_stats.svg', format='svg')
    plt.close()
def pool_sessions(session_IDs,pool_name,savepath,scale='blank_subtracted_NLL'):
    """Concatenate per-cell condition responses, blank responses and p-values across sessions.

    Results are cached under savepath with pool_name and scale in the filenames.
    If the condition-response cache exists, all three arrays are loaded from
    disk instead of being recomputed.

    Args:
        session_IDs: sessions to pool, in order.
        pool_name: name used in the cache filenames.
        savepath: string prefix for loaded/saved files.
        scale: 'event' (compute_mean_condition_responses) or
            'blank_subtracted_NLL' (compute_blank_subtracted_NLL).

    Returns:
        (pooled_condition_responses, pooled_blank_responses, p_vals_all) with
        shapes (num_cells, num_directions, num_contrasts), (num_cells,) and
        (num_cells,).

    Raises:
        ValueError: if scale is not one of the supported values and no cache exists.
    """
    condition_path = savepath + pool_name + '_condition_responses_' + scale + '.npy'
    blank_path = savepath + pool_name + '_blank_responses_' + scale + '.npy'
    p_vals_path = savepath + pool_name + '_chisq_all.npy'

    if os.path.isfile(condition_path):
        pooled_condition_responses = np.load(condition_path)
        pooled_blank_responses = np.load(blank_path)
        p_vals_all = np.load(p_vals_path)
        return pooled_condition_responses, pooled_blank_responses, p_vals_all

    # Fail fast instead of hitting an unbound variable on the first session.
    if scale not in ('event', 'blank_subtracted_NLL'):
        raise ValueError("scale must be 'event' or 'blank_subtracted_NLL', got "
                         + repr(scale))

    print('Pooling sessions for ' + pool_name)

    directions, contrasts = grating_params()
    condition_shape = (len(directions), len(contrasts))

    # Seed each list with an empty array so pooling zero sessions still works.
    condition_chunks = [np.zeros((0,) + condition_shape)]
    blank_chunks = [np.zeros((0,))]
    p_val_chunks = [np.zeros((0,))]
    for session_ID in session_IDs:

        print(str(session_ID))

        mse = load_mean_sweep_events(savepath, session_ID)
        sweep_table = load_sweep_table(savepath, session_ID)

        if scale == 'event':
            condition_responses, blank_responses = compute_mean_condition_responses(sweep_table, mse)
        else:
            condition_responses, blank_responses = compute_blank_subtracted_NLL(session_ID, savepath)

        p_all = chi_square_all_conditions(sweep_table, mse, session_ID, savepath)
        session_cells = len(p_all)

        # broadcast_to enforces the same shape agreement as the former slice assignment.
        condition_chunks.append(np.array(
            np.broadcast_to(condition_responses, (session_cells,) + condition_shape),
            dtype=float))
        blank_chunks.append(np.array(
            np.broadcast_to(blank_responses, (session_cells,)), dtype=float))
        p_val_chunks.append(np.array(p_all, dtype=float))

    pooled_condition_responses = np.concatenate(condition_chunks, axis=0)
    pooled_blank_responses = np.concatenate(blank_chunks, axis=0)
    p_vals_all = np.concatenate(p_val_chunks, axis=0)

    np.save(condition_path, pooled_condition_responses)
    np.save(blank_path, pooled_blank_responses)
    np.save(p_vals_path, p_vals_all)

    print('Done.')

    return pooled_condition_responses, pooled_blank_responses, p_vals_all
def compute_trials_per_condition(sweep_table):
    """Count sweeps per (direction, contrast) condition and the number of blank sweeps.

    A sweep belongs to a condition when its 'Ori' equals the direction and its
    'Contrast' equals the contrast. Blank sweeps are those whose 'Ori' is NaN.

    Returns:
        trials_per_condition: int array of shape (num_directions, num_contrasts).
        num_blanks: number of sweeps with NaN 'Ori'.
    """
    directions, contrasts = grating_params()
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    trials_per_condition = np.zeros((len(directions), len(contrasts)), dtype=int)

    contrast_masks = [sweep_table['Contrast'] == contrast for contrast in contrasts]
    for i_dir, direction in enumerate(directions):
        is_direction = sweep_table['Ori'] == direction
        for i_con, is_contrast in enumerate(contrast_masks):
            is_condition = (is_direction & is_contrast).values
            trials_per_condition[i_dir, i_con] = is_condition.sum()

    num_blanks = np.isnan(sweep_table['Ori'].values).sum()

    return trials_per_condition, num_blanks
def construct_session_Xy(sweep_table, is_run, mean_sweep_events):
    """Build the GLM design matrix X and response matrix y for one session.

    Rows are conditions. For run_state False and then True, there is one blank
    row followed by one row per (direction, contrast) pair in grating_params()
    order, giving 2 * (1 + num_dir * num_con) rows.

    Each row of X is condition_vars(direction, contrast, run_state). Its
    2 * (1 + num_dir + num_con + num_dir * num_con) + 2 columns are ordered:
    blank, each direction, each contrast, each direction x contrast, then the
    same four groups gated by running (blankXrun, dirXrun, conXrun,
    dirXconXrun), then run, then a constant. For example, with 8 directions and
    6 contrasts there are 128 columns:
    0 blank, 1-8 dir, 9-14 con, 15-62 dirXcon, 63 blankXrun, 64-71 dirXrun,
    72-77 conXrun, 78-125 dirXconXrun, 126 run, 127 const.

    Each row of y holds condition_y_separate(...) for that condition, one column
    per cell.

    Returns:
        (X as a float array, y).
    """
    directions, contrasts = grating_params()

    num_dir = len(directions)
    num_con = len(contrasts)
    _num_sweeps, num_cells = np.shape(mean_sweep_events)

    num_vars = 2 * (1 + num_dir + num_con + num_dir * num_con) + 2
    num_conditions = 2 * (1 + num_dir * num_con)

    # Builtin bool/float: the np.bool / np.float aliases were removed in NumPy 1.24.
    X = np.zeros((num_conditions, num_vars), dtype=bool)
    y = np.zeros((num_conditions, num_cells))

    i_condition = 0
    for run_state in [False, True]:

        blank_resp = condition_y_separate(None, None, run_state, sweep_table,
                                          is_run, mean_sweep_events)
        X[i_condition] = condition_vars(None, None, run_state)
        y[i_condition] = blank_resp
        i_condition += 1

        for direction in directions:
            for contrast in contrasts:
                this_resp = condition_y_separate(direction, contrast,
                                                 run_state, sweep_table,
                                                 is_run, mean_sweep_events)
                # NOTE: an earlier version set this_resp to NaN for cells with a
                # non-finite blank response here; that masking is disabled.
                X[i_condition] = condition_vars(direction, contrast, run_state)
                y[i_condition] = this_resp
                i_condition += 1

    return X.astype(float), y
def plot_y(X, y, area, cre, savename, savepath):
    """Plot blank-subtracted running and stationary tuning matrices extracted from y.

    The running and stationary tuning curves from extract_tuning_curves(y) are
    reshaped to (1, num_directions, num_contrasts) and passed through
    center_direction_zero. Their blank responses are subtracted, and each is
    saved via plot_pooled_mat as 'GLM_run_' + savename and
    'GLM_stat_' + savename. X is only unpacked as a 2-D shape; its values are
    not used.
    """
    _num_conditions, _num_params = np.shape(X)

    directions, contrasts = grating_params()
    tuning_shape = (1, len(directions), len(contrasts))

    stat_resp, run_resp, stat_blank_resp, run_blank_resp = extract_tuning_curves(y)

    stat_resp = center_direction_zero(stat_resp.reshape(tuning_shape))[0]
    run_resp = center_direction_zero(run_resp.reshape(tuning_shape))[0]

    panels = (
        ('GLM_run_', run_resp, run_blank_resp),
        ('GLM_stat_', stat_resp, stat_blank_resp),
    )
    for prefix, resp, blank_resp in panels:
        plot_pooled_mat(resp - blank_resp, area, cre, prefix + savename, savepath)
def compute_SEM_condition_responses(sweep_table,mean_sweep_events):
    """Standard error of the mean response per cell, for each condition and for blanks.

    SEM = np.std(responses over the condition's sweeps, axis=0) (NumPy default
    ddof=0) / sqrt(number of those sweeps). Conditions are matched on
    sweep_table 'Ori' and 'Contrast'. Blank sweeps are those with NaN 'Ori'.

    Returns:
        condition_sems: shape (num_cells, num_directions, num_contrasts).
        blank_sems: shape (num_cells,).
    """
    _num_sweeps, num_cells = np.shape(mean_sweep_events)

    directions, contrasts = grating_params()

    def sem_over(mask):
        return np.std(mean_sweep_events[mask], axis=0) / np.sqrt(float(mask.sum()))

    condition_sems = np.zeros((num_cells, len(directions), len(contrasts)))
    for i_dir, direction in enumerate(directions):
        is_direction = sweep_table['Ori'] == direction
        for i_con, contrast in enumerate(contrasts):
            is_contrast = sweep_table['Contrast'] == contrast
            condition_sems[:, i_dir, i_con] = sem_over((is_direction & is_contrast).values)

    blank_sems = sem_over(np.isnan(sweep_table['Ori'].values))

    return condition_sems, blank_sems
예제 #14
0
def condition_response_running(sweep_table, mean_sweep_events, is_run):
    """Mean response per condition, split into running and stationary sweeps.

    A condition (or the blank) gets mean responses only if it has at least
    MIN_SWEEPS running sweeps AND at least MIN_SWEEPS stationary sweeps.
    Otherwise both its running and stationary entries are NaN. Blank sweeps are
    those with a null 'Contrast'.

    Args:
        sweep_table: table with 'Ori' and 'Contrast' columns, one row per sweep.
        mean_sweep_events: array of shape (num_sweeps, num_cells).
        is_run: boolean array, one entry per sweep (combined with ``&``/``~``).

    Returns:
        run_resps, stat_resps: shape (num_cells, num_directions, num_contrasts).
        run_blank_resps, stat_blank_resps: shape (num_cells,).
    """
    MIN_SWEEPS = 4
    directions, contrasts = grating_params()
    _num_sweeps, num_cells = np.shape(mean_sweep_events)

    run_resps = np.zeros((num_cells, len(directions), len(contrasts)))
    stat_resps = np.zeros((num_cells, len(directions), len(contrasts)))

    is_blank = sweep_table['Contrast'].isnull().values
    run_blank_sweeps = np.argwhere(is_blank & is_run)[:, 0]
    stat_blank_sweeps = np.argwhere(is_blank & ~is_run)[:, 0]
    if len(run_blank_sweeps) >= MIN_SWEEPS and len(stat_blank_sweeps) >= MIN_SWEEPS:
        run_blank_resps = mean_sweep_events[run_blank_sweeps].mean(axis=0)
        stat_blank_resps = mean_sweep_events[stat_blank_sweeps].mean(axis=0)
    else:
        # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0.
        run_blank_resps = np.full((num_cells, ), np.nan)
        stat_blank_resps = np.full((num_cells, ), np.nan)

    contrast_values = sweep_table['Contrast'].values
    for i_dir, direction in enumerate(directions):
        is_direction = sweep_table['Ori'].values == direction
        for i_con, contrast in enumerate(contrasts):
            is_condition = is_direction & (contrast_values == contrast)

            run_sweeps = np.argwhere(is_condition & is_run)[:, 0]
            stat_sweeps = np.argwhere(is_condition & ~is_run)[:, 0]

            if len(run_sweeps) >= MIN_SWEEPS and len(stat_sweeps) >= MIN_SWEEPS:
                run_resps[:, i_dir, i_con] = mean_sweep_events[run_sweeps].mean(axis=0)
                stat_resps[:, i_dir, i_con] = mean_sweep_events[stat_sweeps].mean(axis=0)
            else:
                run_resps[:, i_dir, i_con] = np.nan
                stat_resps[:, i_dir, i_con] = np.nan

    return run_resps, stat_resps, run_blank_resps, stat_blank_resps
def get_condition_rows():
    """Return the design-matrix row index of every condition.

    Rows are laid out as in construct_session_Xy. For run_state 0 and then 1,
    there is one blank row followed by one row per (direction, contrast) pair,
    direction-major.

    Returns:
        condition_rows: int array of shape (2, num_directions, num_contrasts).
        blank_rows: int array of shape (2,).
    """
    directions, contrasts = grating_params()

    # Builtin int: the np.int alias was removed in NumPy 1.24.
    condition_rows = np.zeros((2, len(directions), len(contrasts)), dtype=int)
    blank_rows = np.zeros((2, ), dtype=int)
    i_condition = 0
    for run_state in [0, 1]:

        # blank sweeps come first within each run state
        blank_rows[run_state] = i_condition
        i_condition += 1

        for i_dir in range(len(directions)):
            for i_con in range(len(contrasts)):
                condition_rows[run_state, i_dir, i_con] = i_condition
                i_condition += 1

    return condition_rows, blank_rows
def calc_OSI(condition_responses):
    """Orientation selectivity index of each cell at its peak contrast.

    For each cell, the direction tuning curve at the contrast given by
    get_peak_conditions is normalised to sum to 1. The OSI is the length of the
    vector sum with each direction placed at twice its angle. A tuning curve
    that sums to 0 yields NaN.

    Args:
        condition_responses: array of shape (num_cells, num_directions, num_contrasts).

    Returns:
        Array of shape (num_cells,).
    """
    (num_cells, _num_directions, _num_contrasts) = np.shape(condition_responses)
    peak_dir, peak_con = get_peak_conditions(condition_responses)

    directions, contrasts = grating_params()
    doubled_angles = 2.0 * directions * (np.pi / 180.0)
    cos_doubled = np.cos(doubled_angles)
    sin_doubled = np.sin(doubled_angles)

    OSI = np.zeros((num_cells,))
    for cell in range(num_cells):
        tuning = condition_responses[cell, :, peak_con[cell]]
        weights = tuning / tuning.sum()
        x_total = (weights * cos_doubled).sum()
        y_total = (weights * sin_doubled).sum()
        OSI[cell] = np.sqrt(x_total**2 + y_total**2)

    return OSI
def plot_pooled_mat(pooled_mat, area, cre, pass_str, savepath):
    """Save a heatmap of a (direction x contrast) response matrix for one area/cre group.

    pooled_mat is drawn transposed, with directions on x and contrasts on y
    (origin at the bottom). It uses a symmetric 'RdBu_r' scale over
    [-80, 80]. Invalid (NaN/masked) entries are drawn light grey.
    Output: savepath + '<area>_<cre>_<pass_str>_summed_tuning.svg'.
    """
    max_resp = 80.0
    cre_colors = get_cre_colors()
    x_tick_labels = ['-135', '-90', '-45', '0', '45', '90', '135', '180']

    directions, contrasts = grating_params()

    plt.figure(figsize=(4.2, 4))
    ax = plt.subplot(111)

    # Use a private colormap instance and pass that object to imshow. Passing the
    # name 'RdBu_r' makes imshow look up the registered map, so set_bad on the
    # object obtained here would be ignored and NaNs would not be grey.
    try:
        current_cmap = matplotlib.colormaps['RdBu_r']  # matplotlib >= 3.5; returns a copy
    except AttributeError:
        import copy
        current_cmap = copy.copy(matplotlib.cm.get_cmap(name='RdBu_r'))
    current_cmap.set_bad(color=[0.8, 0.8, 0.8])
    im = ax.imshow(pooled_mat.T,
                   vmin=-max_resp,
                   vmax=max_resp,
                   interpolation='nearest',
                   aspect='auto',
                   cmap=current_cmap,
                   origin='lower')
    ax.set_xlabel('Direction (deg)', fontsize=14)
    ax.set_ylabel('Contrast (%)', fontsize=14)
    ax.set_yticks(np.arange(len(contrasts)))
    ax.set_yticklabels([str(int(100 * x)) for x in contrasts], fontsize=10)
    ax.set_xticks(np.arange(len(directions)))
    ax.set_xticklabels(x_tick_labels, fontsize=10)
    ax.set_title(shorthand(cre) + ' population',
                 fontsize=16,
                 color=cre_colors[cre])
    cbar = plt.colorbar(
        im,
        ax=ax,
        ticks=[-max_resp, -max_resp / 2.0, 0.0, max_resp / 2.0, max_resp])
    cbar.set_label('Event magnitude per second (%), blank subtracted',
                   rotation=270,
                   labelpad=15.0)
    plt.savefig(savepath + shorthand(area) + '_' + shorthand(cre) + '_' +
                pass_str + '_summed_tuning.svg',
                format='svg')
    plt.close()
예제 #18
0
def plot_direction_vector_sum_by_contrast(df, savepath):
    """Radial plot of the preferred-direction vector sum at each contrast, per area/cre.

    For each group with sessions, event-scale responses are pooled. Only cells
    with p < SIG_THRESH are kept. At each contrast, the fraction of cells
    preferring each direction is computed and reduced to a vector sum
    (magnitude, theta). These are plotted together with bounds from
    uniform_direction_vector_sum for a uniform-preference population of the
    same size.
    """
    areas, cres = dataset_params()
    directions, contrasts = grating_params()

    for area in areas:
        for cre in cres:
            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            resp, blank, p_all = pool_sessions(session_IDs,
                                               area + '_' + cre,
                                               savepath,
                                               scale='event')
            sig_resp = resp[p_all < SIG_THRESH]

            # Column-normalise so each contrast holds fractions of cells.
            counts = calc_pref_direction_dist_by_contrast(sig_resp)
            pref_dir_mat = counts / np.sum(counts, axis=0, keepdims=True)

            vector_sums = [calc_vector_sum(pref_dir_mat[:, i_con])
                           for i_con in range(len(contrasts))]
            resultant_mag = [mag for mag, _ in vector_sums]
            resultant_theta = [theta for _, theta in vector_sums]

            # Null bounds from a population of the same size with uniform preferences.
            num_cells = len(sig_resp)
            uniform_LB, uniform_UB = uniform_direction_vector_sum(num_cells)

            radial_direction_figure(
                np.zeros((len(directions), )),
                np.zeros((len(directions), )),
                resultant_mag, resultant_theta,
                uniform_LB, uniform_UB, cre, num_cells,
                shorthand(area) + '_' + shorthand(cre) + '_combined',
                savepath)
def plot_rasters_across_contrasts(sweep_table, sweep_events, cell_idx,
                                  session_ID, cre, savepath):
    """Save a grid of single-cell rasters: one row per direction, one column per contrast.

    Rows are rotated by 3 positions so that direction index 0 is drawn in row 3.
    The bottom row gets an x label, the top row gets contrast titles, and the
    first column gets direction labels. All panels share cell_max from
    get_max_event_magnitude. Output is a 300-dpi PNG under savepath.
    """
    directions, contrasts = cu.grating_params()

    plt.figure(figsize=(16, 16))

    cell_max = get_max_event_magnitude(sweep_events, cell_idx)

    direction_str = ['0', '45', '90', '135', '180', '-135', '-90', '-45']
    dir_shift = 3
    num_rows = len(directions)
    num_cols = len(contrasts)

    for i_con, contrast in enumerate(contrasts):
        for i_dir, direction in enumerate(directions):

            # shift 0 degrees toward the centre of the grid
            row = np.mod(i_dir + dir_shift, num_rows)
            ax = plt.subplot(num_rows, num_cols, 1 + num_cols * row + i_con)

            plot_single_raster(ax, sweep_table, sweep_events, cell_idx,
                               contrast, direction, cell_max)

            if row == 7:
                ax.set_xlabel('Time from stimulus onset (s)')
            if row == 0:
                ax.set_title(str(int(100 * contrast)) + '% Contrast',
                             fontsize=14)
            if i_con == 0:
                ax.set_ylabel(direction_str[i_dir], fontsize=14)

    plt.tight_layout()
    out_name = (cre + '_' + str(session_ID) + '_cell_' + str(cell_idx)
                + '_rasters.png')
    plt.savefig(savepath + out_name, dpi=300)
    plt.close()
예제 #20
0
def uniform_direction_vector_sum(num_cells, num_shuffles=1000, CI_range=0.9):
    """Confidence bounds on the vector-sum magnitude of a uniform-preference null population.

    In each of num_shuffles draws, num_cells cells are each assigned a direction
    chosen uniformly (with replacement) from grating_params(). The magnitude
    returned by calc_resultant is recorded. The bounds are the sorted magnitudes
    at indices int(num_shuffles * (1 - CI_range) / 2) and
    int(num_shuffles * (1 - (1 - CI_range) / 2)). The upper index is clamped to
    the last shuffle.

    Returns:
        (CI_LB, CI_UB)
    """
    directions, contrasts = grating_params()
    # Fancy indexing below needs an ndarray even if a list is returned.
    directions = np.asarray(directions)

    vector_sum = []
    for _ in range(num_shuffles):
        uniform_directions = directions[np.random.choice(len(directions),
                                                         size=num_cells)]
        magnitude, __ = calc_resultant(uniform_directions)
        vector_sum.append(magnitude)
    sorted_shuffles = np.sort(np.array(vector_sum))

    LB_idx = int(num_shuffles * (1.0 - CI_range) / 2.0)
    UB_idx = int(num_shuffles * (1.0 - (1.0 - CI_range) / 2.0))
    # CI_range=1.0 would otherwise index one past the end.
    UB_idx = min(UB_idx, num_shuffles - 1)

    CI_LB = sorted_shuffles[LB_idx]
    CI_UB = sorted_shuffles[UB_idx]

    return CI_LB, CI_UB
예제 #21
0
def calc_vector_sum(fraction_prefer_directions, directions=None):
    """Vector-sum a distribution of preferred directions.

    Each direction d (degrees) contributes a vector of length
    fraction_prefer_directions[i] at angle -d. The resultant's angle is then
    negated back, so theta is expressed in the same sign convention as the
    directions.

    Args:
        fraction_prefer_directions: weight per direction, aligned with directions.
        directions: direction angles in degrees. Defaults to the directions from
            grating_params().

    Returns:
        (magnitude, theta): length of the resultant, and its angle in radians in
        (-pi, pi]. A population preferring 180 degrees gives theta ~ pi.
    """
    if directions is None:
        directions, _ = grating_params()

    angles = -np.pi * np.asarray(directions, dtype=float) / 180.0
    weights = np.array([fraction_prefer_directions[i_dir]
                        for i_dir in range(len(angles))], dtype=float)

    resultant_x = (weights * np.cos(angles)).sum()
    resultant_y = (weights * np.sin(angles)).sum()
    magnitude = np.sqrt(resultant_x**2 + resultant_y**2)

    # arctan2 keeps the quadrant. arctan(y / x) folded directions in the left
    # half-plane onto the right, e.g. 180 deg -> 0, and the x == 0 case gave
    # +/-pi/4 instead of +/-pi/2.
    theta = -np.arctan2(resultant_y, resultant_x)

    return magnitude, theta
def plot_peak_response_distribution(run_aligned_pooled_low,
                                    stat_aligned_pooled_low,
                                    run_aligned_pooled_high,
                                    stat_aligned_pooled_high, area, cre,
                                    savename, savepath):
    """Swarm plot of per-cell responses for blank and each direction, by run state and contrast.

    Groups are built with add_group_to_dict. The blank groups use column
    BLANK_IDX (0) of the low-contrast arrays. Direction group i uses column
    1 + i of the array for its run state and contrast (5% -> *_low, 80% ->
    *_high). Non-finite values are dropped. Groups are laid out blank, run 5%,
    run 80%, stat 5%, stat 80%, separated by empty spacer columns. Quartiles are
    overlaid with plot_quartiles. Output:
    savepath + '<area>_<cre>_<savename>_cell_response_distribution.svg'.
    """
    plt.figure(figsize=(7, 4))
    plt.subplot(111)  # swarmplot below draws on this current axes

    cre_colors = get_cre_colors()

    resp_dict = {}

    BLANK_IDX = 0
    resp_dict = add_group_to_dict(resp_dict, run_aligned_pooled_low, BLANK_IDX,
                                  'run blank')
    resp_dict = add_group_to_dict(resp_dict, stat_aligned_pooled_low,
                                  BLANK_IDX, 'stat blank')

    # Direction labels (degrees) for columns 1..8 of the aligned arrays, and
    # the two contrasts shown.
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    contrasts = [0.05, 0.8]
    pooled_by_group = {
        ('run', 0.05): run_aligned_pooled_low,
        ('run', 0.8): run_aligned_pooled_high,
        ('stat', 0.05): stat_aligned_pooled_low,
        ('stat', 0.8): stat_aligned_pooled_high,
    }
    for run_state in ['run', 'stat']:
        for contrast in contrasts:
            resps = pooled_by_group[(run_state, contrast)]
            for i_dir, direction in enumerate(directions):
                group_name = run_state + ' ' + str(direction) + ' ' + str(
                    int(100 * contrast)) + '%'
                resp_dict = add_group_to_dict(resp_dict, resps, 1 + i_dir,
                                              group_name)

    plot_order = [('space1', ''), ('run blank', ''), ('stat blank', ''),
                  ('space2', '')]
    curr_space = 3
    for run_state in ['run', 'stat']:
        for contrast in contrasts:
            for direction in directions:
                plot_order.append((run_state + ' ' + str(direction) + ' ' +
                                   str(int(100 * contrast)) + '%', ''))
            plot_order.append(('space' + str(curr_space), ''))
            curr_space += 1

    # Hue levels appear in data order: 'blank' (spacer rows) first, then ''.
    colors = ['#9f9f9f']
    for i in range(len(plot_order)):
        colors.append(cre_colors[cre])
    cre_palette = sns.color_palette(colors)

    # Build the columns as lists and create the DataFrame once. The previous
    # chained assignment (resp_df[col][a:b] = ...) is a silent no-op under
    # pandas Copy-on-Write and capped the data at a fixed MAX_CELLS rows.
    response_chunks = []
    cell_type_col = []
    cre_col = []
    labels = []
    x_pos = []
    dist = []
    dir_idx = 0
    for line, (group, cre_name) in enumerate(plot_order):
        if group.find('space') == -1:
            resp_mag = resp_dict[group]
            resp_mag = resp_mag[np.argwhere(np.isfinite(resp_mag))[:, 0]]
            num_cells = len(resp_mag)
            response_chunks.append(np.asarray(resp_mag, dtype=float))
            cre_col.extend([cre_name] * num_cells)
            cell_type_col.extend([group] * num_cells)
            x_pos.append(line)
            dist.append(resp_mag)

            if group.find('blank') != -1:
                if group.find('run') != -1:
                    labels.append('run')
                else:
                    labels.append('stat')
            else:
                labels.append(str(directions[dir_idx]))
                dir_idx += 1
                if dir_idx == len(directions):
                    dir_idx = 0

        else:
            # One NaN row so the spacer still occupies an x-axis category.
            response_chunks.append(np.array([np.nan]))
            cre_col.append('blank')
            cell_type_col.append(group)

    resp_df = pd.DataFrame({
        'Response to Preferred Direction': np.concatenate(response_chunks),
        'cell_type': cell_type_col,
        'cre': cre_col,
    })

    ax = sns.swarmplot(x='cell_type',
                       y='Response to Preferred Direction',
                       hue='cre',
                       size=1.0,
                       palette=cre_palette,
                       data=resp_df)

    ax.set_xticks(np.array(x_pos))
    ax.set_xticklabels(labels, fontsize=4.5, rotation=0)
    ax.legend_.remove()

    for i, d in enumerate(dist):
        plot_quartiles(ax, d, x_pos[i])

    ax.set_ylim(-20, 400)
    ax.set_ylabel('Event magnitude per second (%)', fontsize=12)
    ax.set_xlabel(
        'Blank     run 5% contrast        run 80% contrast       stat 5% contrast       stat 80% contrast ',
        fontsize=9)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.tight_layout()
    plt.savefig(savepath + shorthand(area) + '_' + shorthand(cre) + '_' +
                savename + '_cell_response_distribution.svg',
                format='svg')
    plt.close()
def plot_param_CI(res, area, cre, savepath, PARAM_TOL=1E-2, save_format='svg'):
    """Plot each fitted GLM weight with its 95% confidence interval.

    Terms are drawn left to right in the order given by ``plot_order``. For the
    direction-by-contrast interaction terms ('dirXcon', 'dirXconXrun') only the
    coefficients whose confidence interval lies entirely above ``PARAM_TOL`` or
    entirely below ``-PARAM_TOL`` are drawn, labelled "<contrast>%,<direction>".

    Args:
        res: fitted model result; ``res.params`` and ``res.conf_int(alpha=0.05)``
            are read (statsmodels-style interface).
        area: brain area name, used (via ``shorthand``) in the output file name.
        cre: Cre line name, used in the file name and to pick the plot color.
        savepath: prefix that the output file name is appended to.
        PARAM_TOL: magnitude a CI bound must exceed for an interaction
            coefficient to be treated as non-zero.
        save_format: 'svg' writes ``..._param_CI.svg``; any other value writes
            a 300 dpi ``..._param_CI.png``.
    """
    model_params = np.array(res.params)
    terms = unpack_params(model_params)

    CI = np.array(res.conf_int(alpha=0.05))
    CI_lb = unpack_params(CI[:, 0])
    CI_ub = unpack_params(CI[:, 1])

    plot_order = [
        'blank', 'run', 'dir', 'con', 'dirXrun', 'conXrun', 'dirXcon',
        'dirXconXrun'
    ]
    interaction_terms = ('dirXcon', 'dirXconXrun')
    direction_terms = ('dir', 'dirXrun') + interaction_terms

    directions, contrasts = grating_params()
    num_contrasts = len(contrasts)
    # tick labels for direction terms after center_dir_on_zero re-orders them
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]

    x_labels = {}
    x_labels['blank'] = ['blank']
    x_labels['blankXrun'] = ['blank X run']
    x_labels['run'] = ['run']
    x_labels['dir'] = [str(x) for x in directions]
    x_labels['con'] = [str(int(100 * x)) for x in contrasts]
    x_labels['dirXrun'] = [str(x) for x in directions]
    x_labels['conXrun'] = [str(int(100 * x)) for x in contrasts]

    # interaction groups reserve at least this many x slots so that figures
    # for different Cre lines line up
    min_interaction_ticks = 6

    plt.figure(figsize=(20, 4.5))

    savename = shorthand(area) + '_' + shorthand(cre)
    cre_colors = get_cre_colors()

    ax = plt.subplot(111)

    curr_x = 0

    x_ticks = []
    x_ticklabels = []
    for param_name in plot_order:

        param_means = terms[param_name]
        param_CI_lb = CI_lb[param_name]
        param_CI_ub = CI_ub[param_name]

        # center directions on zero
        if param_name in direction_terms:
            param_means = center_dir_on_zero(param_means)
            param_CI_lb = center_dir_on_zero(param_CI_lb)
            param_CI_ub = center_dir_on_zero(param_CI_ub)

        # scalar terms (blank, run) become length-1 arrays; 2-D interaction
        # grids are flattened row-major (direction, contrast)
        if np.ndim(param_means) == 0:
            param_means = np.array([param_means])
            param_CI_lb = np.array([param_CI_lb])
            param_CI_ub = np.array([param_CI_ub])
        elif param_name in interaction_terms:
            param_means = param_means.flatten()
            param_CI_lb = param_CI_lb.flatten()
            param_CI_ub = param_CI_ub.flatten()

        param_errs = CI_to_errorbars(param_means, param_CI_lb, param_CI_ub)

        num_params = np.shape(param_errs)[1]

        if param_name in interaction_terms:
            # only plot coefficients whose CI excludes [-PARAM_TOL, PARAM_TOL]
            non_zero_idx = np.argwhere((param_CI_ub < -PARAM_TOL)
                                       | (param_CI_lb > PARAM_TOL))[:, 0]
            num_params = len(non_zero_idx)
            cond_tick_labels = []
            if num_params > 0:
                param_means = param_means[non_zero_idx]
                param_errs = param_errs[:, non_zero_idx]

                for idx in non_zero_idx:
                    # flat index of a row-major (direction, contrast) grid
                    i_dir, i_con = divmod(int(idx), num_contrasts)
                    cond_tick_labels.append(
                        str(int(100 * contrasts[i_con])) + '%,' +
                        str(int(directions[i_dir])))

            x_labels[param_name] = cond_tick_labels
            # ticks are double-spaced; widen the group when needed so its last
            # tick never runs into the next group's ticks
            ticks_to_plot = max(min_interaction_ticks, 2 * (num_params - 1))
            x_values = np.arange(curr_x, curr_x + 2 * num_params, 2)
        else:
            ticks_to_plot = num_params
            x_values = np.arange(curr_x, curr_x + num_params)

        if num_params > 0:

            ax.errorbar(x_values,
                        param_means,
                        yerr=param_errs,
                        fmt='o',
                        color=cre_colors[cre],
                        linewidth=3,
                        capsize=5,
                        elinewidth=2,
                        markeredgewidth=2)

            for i_x, x in enumerate(x_values):
                x_ticks.append(x)
                x_ticklabels.append(x_labels[param_name][i_x])

        # one empty slot between consecutive groups
        curr_x += ticks_to_plot + 1

    ax.plot([-1, curr_x], [0, 0], 'k', linewidth=1.0)

    ax.set_ylabel('Weight', fontsize=14)

    ax.set_xlim([-1, curr_x])
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticklabels)

    y_max = 2.5
    y_min = -1.5
    y_ticks = [-1, 0, 1, 2]

    ax.set_yticks(y_ticks)
    ax.set_yticklabels([str(int(y)) for y in y_ticks])
    ax.set_ylim([y_min, y_max])

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    if save_format == 'svg':
        plt.savefig(savepath + savename + '_param_CI.svg', format='svg')
    else:
        plt.savefig(savepath + savename + '_param_CI.png', dpi=300)

    plt.close()
def plot_param_heatmaps(model_params,
                        cre,
                        savename,
                        savepath,
                        save_format='svg'):
    """Show each group of GLM weights as a heatmap in a single row of nine panels.

    Every panel uses the 'RdBu_r' colormap clipped to [-2, 2]. Weights that do
    not vary along an axis (per-direction terms along contrast, per-contrast
    terms along direction, and the scalar blank/run/const terms) are tiled so
    that every panel is a contrast-by-direction image. A colorbar is attached
    to the 'Constant' panel.

    Args:
        model_params: flat coefficient vector, split via ``unpack_params``.
        cre: Cre line name (not used in the drawing).
        savename: output file stem.
        savepath: prefix the file name is appended to.
        save_format: 'svg' writes ``<stem>_GLM.svg``; anything else writes a
            300 dpi ``<stem>_GLM.png``.
    """
    terms = unpack_params(model_params)

    weight_limit = 2.0

    directions, contrasts = grating_params()
    n_dir = len(directions)
    n_con = len(contrasts)

    font_size = 10  # shared by panel titles and tick labels
    contrast_labels = [str(int(100 * c)) for c in contrasts]

    def draw_panel(slot, image, title, direction_axis=False,
                   contrast_axis=False):
        """Draw ``image`` in panel ``slot`` of a 1x9 grid; return (axes, artist)."""
        panel = plt.subplot(1, 9, slot)
        artist = panel.imshow(image,
                              interpolation='none',
                              origin='lower',
                              cmap='RdBu_r',
                              vmin=-weight_limit,
                              vmax=weight_limit)
        if direction_axis:
            panel.set_xticks(np.arange(n_dir))
            panel.set_xticklabels(['', '-90', '', '0', '', '90', '', '180'],
                                  fontsize=font_size)
        else:
            panel.set_xticks([])
        if contrast_axis:
            panel.set_yticks(np.arange(n_con))
            panel.set_yticklabels(contrast_labels, fontsize=font_size)
        else:
            panel.set_yticks([])
        panel.set_title(title, fontsize=font_size)
        return panel, artist

    plt.figure(figsize=(20, 6))

    # Left-to-right slots: 1 Constant, 2 Blank, 3 Run, 4 Direction, 5 Contrast,
    # 6 Run X Direction, 7 Run X Contrast, 8 Dir X Con, 9 Run X Dir X Con.
    draw_panel(4, np.tile(center_dir_on_zero(terms['dir']), reps=(n_con, 1)),
               'Direction', direction_axis=True)
    draw_panel(5, np.tile(terms['con'], reps=(n_dir, 1)).T,
               'Contrast', contrast_axis=True)
    draw_panel(8, center_dir_on_zero(terms['dirXcon']).T,
               'Direction X Contrast', direction_axis=True,
               contrast_axis=True)
    draw_panel(7, np.tile(terms['conXrun'], reps=(n_dir, 1)).T,
               'Run X Contrast', contrast_axis=True)
    draw_panel(6, np.tile(center_dir_on_zero(terms['dirXrun']),
                          reps=(n_con, 1)),
               'Run X Direction', direction_axis=True)
    draw_panel(3, np.tile(terms['run'], reps=(n_con, n_dir)), 'Run')
    draw_panel(9, center_dir_on_zero(terms['dirXconXrun']).T,
               'Run X Direction X Contrast', direction_axis=True,
               contrast_axis=True)
    draw_panel(2, np.tile(terms['blank'], reps=(n_con, n_dir)), 'Blank')
    const_panel, const_artist = draw_panel(
        1, np.tile(terms['const'], reps=(n_con, n_dir)), 'Constant')

    plt.colorbar(const_artist, ax=const_panel, ticks=[-2, -1, 0, 1, 2])

    if save_format != 'svg':
        plt.savefig(savepath + savename + '_GLM.png', dpi=300)
    else:
        plt.savefig(savepath + savename + '_GLM.svg', format='svg')

    plt.close()
def unpack_params(model_params):
    """Split a flat GLM coefficient vector into a dict of named terms.

    Coefficients are consumed in this order (counts in parentheses, with
    n_dir/n_con the number of grating directions/contrasts):
    blank (1), dir (n_dir), con (n_con), dirXcon (n_dir * n_con, direction
    major), blankXrun (1), dirXrun (n_dir), conXrun (n_con),
    dirXconXrun (n_dir * n_con, direction major), run (1), const (1).
    Entries beyond that count are never read.

    Returns:
        dict mapping term name to a scalar element of ``model_params`` (blank,
        blankXrun, run, const), a 1-D array (dir, con, dirXrun, conXrun) or an
        (n_dir, n_con) array (dirXcon, dirXconXrun).
    """
    directions, contrasts = grating_params()
    n_dir = len(directions)
    n_con = len(contrasts)

    position = 0

    def take(count):
        """Return the next ``count`` coefficients and advance the cursor."""
        nonlocal position
        start = position
        position = start + count
        return [model_params[k] for k in range(start, start + count)]

    terms = {}
    terms['blank'] = take(1)[0]
    # terms without run interaction
    terms['dir'] = np.array(take(n_dir))
    terms['con'] = np.array(take(n_con))
    terms['dirXcon'] = np.array(take(n_dir * n_con)).reshape(n_dir, n_con)
    # terms with run interaction
    terms['blankXrun'] = take(1)[0]
    terms['dirXrun'] = np.array(take(n_dir))
    terms['conXrun'] = np.array(take(n_con))
    terms['dirXconXrun'] = np.array(take(n_dir * n_con)).reshape(
        n_dir, n_con)
    terms['run'] = take(1)[0]
    terms['const'] = take(1)[0]

    return terms
예제 #26
0
def _decode_session_direction_from_running(session_ID, savepath):
    """Return the decoder score for one session using running speed as the only feature.

    Only sweeps whose 'Contrast' is below 0.2 are used. Sweeps with a missing
    'Ori' (blank sweeps) are assigned direction 360 before categorisation.
    """
    mean_sweep_running = load_mean_sweep_running(session_ID, savepath)
    sweep_table = load_sweep_table(savepath, session_ID)

    sweep_directions = sweep_table['Ori'].values.copy()
    sweep_directions[sweep_table['Ori'].isnull().values] = 360
    # integer class label per sweep (direction // 45; blanks -> 8).
    # np.int was removed in NumPy 1.24, so the builtin int is used.
    sweep_categories = sweep_directions.astype(int) // 45

    sweeps_included = np.argwhere(
        sweep_table['Contrast'].values < 0.2)[:, 0]
    sweep_categories = sweep_categories[sweeps_included]
    mean_sweep_running = mean_sweep_running[sweeps_included]

    # running speed is the single feature: shape (num_included_sweeps, 1)
    return decode_direction(
        mean_sweep_running.reshape(len(sweeps_included), 1),
        sweep_categories)


def decode_direction_from_running(df, savepath, save_format='svg'):
    """Decode grating direction from running speed alone and plot accuracy per Cre line.

    For every Cre line in VISp that has at least one session, per-session
    decoder performance is loaded from
    ``<savepath><area>_<cre>_running_direction_decoder.npy`` if that file
    exists; otherwise it is computed with
    ``_decode_session_direction_from_running`` and cached there. The mean is
    printed, and each Cre line's session scores (x100) and their mean are
    plotted next to a dashed reference line at 12.5.

    Args:
        df: session metadata passed to ``get_sessions``.
        savepath: directory prefix for cached results, inputs and the figure.
        save_format: 'svg' writes ``running_decoder.svg``; anything else
            writes a 300 dpi ``running_decoder.png``.
    """
    running_dict = {}

    areas, cres = dataset_params()
    for area in ['VISp']:
        for cre in cres:

            session_IDs = get_sessions(df, area, cre)
            if len(session_IDs) == 0:
                continue

            celltype = shorthand(area) + ' ' + shorthand(cre)
            savename = shorthand(area) + '_' + shorthand(
                cre) + '_running_direction_decoder.npy'
            cache_path = savepath + savename

            if os.path.isfile(cache_path):
                running_performance = np.load(cache_path)
            else:
                running_performance = np.array([
                    _decode_session_direction_from_running(
                        session_ID, savepath) for session_ID in session_IDs
                ])
                np.save(cache_path, running_performance)

            print(celltype + ': ' + str(np.mean(running_performance)))
            running_dict[shorthand(cre)] = running_performance

    cre_colors = get_cre_colors()

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    ax.plot([-1, 6], [12.5, 12.5], 'k--')
    label_loc = []
    labels = []
    for i, cre in enumerate(cres):
        # Cre lines without sessions were never decoded; skip them instead of
        # raising KeyError (their x position is left empty).
        if shorthand(cre) not in running_dict:
            continue
        session_performance = running_dict[shorthand(cre)]
        mean_performance = 100.0 * session_performance.mean()
        ax.plot(i * np.ones((len(session_performance), )),
                100.0 * session_performance,
                '.',
                markersize=4.0,
                color=cre_colors[cre])
        ax.plot([i - 0.4, i + 0.4], [mean_performance, mean_performance],
                color=cre_colors[cre],
                linewidth=3)
        label_loc.append(i)
        labels.append(shorthand(cre))
    ax.set_xticks(label_loc)
    ax.set_xticklabels(labels, rotation=45, fontsize=10)
    ax.set_ylim(0, 25)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlim(-1, 14)
    ax.set_ylabel('Decoding performance (%)', fontsize=14)

    if save_format == 'svg':
        plt.savefig(savepath + 'running_decoder.svg', format='svg')
    else:
        plt.savefig(savepath + 'running_decoder.png', dpi=300)

    plt.close()
예제 #27
0
def plot_full_waterfall(cells_by_condition,
                        cre,
                        save_name,
                        scale,
                        savepath,
                        do_colorbar=False):
    """Save a heatmap with one row per cell and one column per stimulus condition.

    Columns are grouped into blocks of ``len(directions)`` per contrast; blocks
    are separated by vertical black lines and labelled with their contrast in
    percent. Colors use 'RdBu_r' clipped to [-4, 4]. When ``do_colorbar`` is
    set, a horizontal colorbar is added with ticks placed at
    ``percentile_to_NLL`` of a fixed list of percentiles.

    The figure is written to ``<savepath><save_name>_<scale>.svg``.
    """
    color_limit = 4.0

    directions, contrasts = grating_params()
    n_con = len(contrasts)
    n_dir = len(directions)

    n_cells, _n_conditions = np.shape(cells_by_condition)
    last_row = n_cells - 1

    cre_colors = get_cre_colors()

    plt.figure(figsize=(10, 4))
    ax = plt.subplot(111)
    heatmap = ax.imshow(cells_by_condition,
                        vmin=-color_limit,
                        vmax=color_limit,
                        interpolation='nearest',
                        aspect='auto',
                        cmap='RdBu_r')

    # vertical separators between contrast blocks
    for block in range(1, n_con):
        edge = block * n_dir - 0.5
        ax.plot([edge, edge], [0, last_row], 'k', linewidth=2.0)

    ax.set_ylabel(shorthand(cre) + ' cell number',
                  fontsize=14,
                  color=cre_colors[cre],
                  labelpad=-6)
    ax.set_xlabel('Contrast (%)', fontsize=14, labelpad=-5)

    # one tick at the middle of each contrast block
    block_centers = n_dir * np.arange(n_con) + (n_dir / 2) - 0.5
    ax.set_xticks(block_centers)
    ax.set_xticklabels([str(int(100 * c)) for c in contrasts], fontsize=12)

    ax.set_yticks([0, last_row])
    ax.set_yticklabels(['0', str(last_row)], fontsize=12)

    if do_colorbar:
        percentiles = [
            0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999
        ]
        tick_positions = percentile_to_NLL(percentiles, num_shuffles=200000)

        colorbar = plt.colorbar(heatmap,
                                ax=ax,
                                ticks=tick_positions,
                                orientation='horizontal')
        colorbar.ax.set_xticklabels([str(100 * p) for p in percentiles],
                                    fontsize=12)
        colorbar.set_label('Response Percentile',
                           fontsize=16,
                           rotation=0,
                           labelpad=15.0)

    plt.tight_layout()
    plt.savefig(savepath + save_name + '_' + scale + '.svg', format='svg')
    plt.close()
def plot_traces_across_contrasts(sweep_table, traces, cell_idx, session_ID,
                                 cre, savepath):
    """Plot every single-sweep trace of one cell in a direction x contrast grid.

    Rows are directions and columns are contrasts. All panels share one y
    range, spanning the overall minimum and maximum trace value (and always
    including 0), scaled by 100. A grey rectangle shades t = 0..2 s. The
    figure is saved as a 300 dpi PNG named
    ``<cre>_<session_ID>_cell_<cell_idx>_traces.png`` under ``savepath``.

    Args:
        sweep_table, traces: passed to ``get_condition_traces``.
        cell_idx: index of the cell to plot.
        session_ID: session identifier, used in the file name.
        cre: Cre line name, used in the file name.
        savepath: prefix the file name is appended to.
    """
    directions, contrasts = cu.grating_params()
    num_dir = len(directions)
    num_con = len(contrasts)

    plt.figure(figsize=(8, 16))

    # keep the handles so the second pass does not rely on plt.subplot
    # returning an existing axes for repeated arguments
    axes = []
    min_val = 0
    max_val = 0

    for i_con, contrast in enumerate(contrasts):

        for i_dir, direction in enumerate(directions):

            ax = plt.subplot(num_dir, num_con, 1 + num_con * i_dir + i_con)
            axes.append(ax)

            ct, t = get_condition_traces(sweep_table,
                                         traces,
                                         cell_idx=cell_idx,
                                         contrast=contrast,
                                         direction=direction)

            # one line per sweep (rows of ct)
            for sweep_trace in ct:
                ax.plot(t, 100.0 * sweep_trace)

            if i_dir == num_dir - 1:  # bottom row
                ax.set_xlabel('Time from stimulus onset (s)')
            if i_dir == 0:  # top row
                ax.set_title(str(int(100 * contrast)) + '% Contrast')

            if i_con == 0:  # left column
                ax.set_ylabel(str(int(direction)))
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            ax.set_xticks([0, 2])
            ax.set_xticklabels(['0', '2'])

            min_val = min(min_val, ct.min())
            max_val = max(max_val, ct.max())

    # second pass: shared y range and shading once the global range is known
    for ax in axes:
        ax.set_ylim([100 * min_val, 100 * max_val])
        rect = patch.Rectangle((0, 100 * min_val),
                               2,
                               100 * (max_val - min_val),
                               facecolor=(0.8, 0.8, 0.8))
        ax.add_patch(rect)

    plt.tight_layout()
    plt.savefig(savepath + cre + '_' + str(session_ID) + '_cell_' +
                str(cell_idx) + '_traces.png',
                dpi=300)
    plt.close()
예제 #29
0
def radial_direction_figure(x_coor,
                            y_coor,
                            resultant_mag,
                            resultant_theta,
                            CI_LB,
                            CI_UB,
                            cre,
                            num_cells,
                            savename,
                            savepath,
                            max_radius=0.75):
    """Draw a polar direction-tuning summary for one Cre line and save it as SVG.

    All radial quantities are divided by ``max_radius``, so a value equal to
    ``max_radius`` lands on the solid outer circle. Dashed rings mark 0.25 and
    0.5 in the same units. The figure contains:

    * a grey annulus between ``CI_LB`` and ``CI_UB``;
    * dashed spokes at each grating direction;
    * the curve ``(x_coor, y_coor)`` in the Cre line's color;
    * one arrow per contrast, of length ``resultant_mag[k]`` at angle
      ``-resultant_theta[k]`` (radians, passed to np.cos/np.sin).

    The output is written to ``<savepath><savename>_radial_direction_tuning.svg``.
    """
    color = get_cre_colors()[cre]

    directions, contrasts = grating_params()

    unit_circle_x = np.linspace(-1.0, 1.0, 100)
    unit_circle_y = np.sqrt(1.0 - unit_circle_x**2)

    plt.figure(figsize=(4, 4))
    ax = plt.subplot(111)

    def plot_ring(radius, style, linewidth):
        """Draw a circle of ``radius`` as upper and lower half curves."""
        ax.plot(radius * unit_circle_x,
                radius * unit_circle_y,
                style,
                linewidth=linewidth)
        ax.plot(radius * unit_circle_x,
                -radius * unit_circle_y,
                style,
                linewidth=linewidth)

    # CI annulus: grey disc at the upper bound, covered by a white disc at
    # the lower bound
    outer_CI = Circle((0, 0), CI_UB / max_radius, facecolor=[0.6, 0.6, 0.6])
    inner_CI = Circle((0, 0), CI_LB / max_radius, facecolor=[1.0, 1.0, 1.0])
    ax.add_patch(outer_CI)
    ax.add_patch(inner_CI)

    # spokes
    for direction in directions:
        ax.plot([0, np.cos(np.pi * direction / 180.0)],
                [0, np.sin(np.pi * direction / 180.0)],
                'k--',
                linewidth=1.0)

    # outer ring, dashed reference rings, and a tiny center marker
    plot_ring(1.0, 'k', 2.0)
    for ring_value in (0.25, 0.5):
        plot_ring(ring_value / max_radius, '--k', 1.0)
    plot_ring(1.0 / 200.0, 'k', 2.0)

    ax.plot(np.array(x_coor) / max_radius,
            np.array(y_coor) / max_radius,
            color=color,
            linewidth=2.0)

    # Arrows are drawn from the last contrast to the first.
    # NOTE(review): color index i is paired with contrast index
    # len(contrasts) - i - 1; confirm get_contrast_colors() ordering matches.
    contrast_colors = get_contrast_colors()
    for i, mag in enumerate(resultant_mag[::-1]):
        theta = resultant_theta[len(contrasts) - i - 1]
        ax.arrow(0,
                 0,
                 mag * np.cos(-theta) / (max_radius),
                 mag * np.sin(-theta) / (max_radius),
                 color=contrast_colors[i],
                 linewidth=2.0,
                 head_width=0.03)

    # compass letters just outside the unit circle
    for x, y, letter, h_align, v_align in [
        (0, 1.02, 'U', 'center', 'bottom'),
        (0, -1.02, 'D', 'center', 'top'),
        (1.02, 0, 'T', 'left', 'center'),
        (-1.02, 0, 'N', 'right', 'center'),
    ]:
        ax.text(x,
                y,
                letter,
                fontsize=12,
                horizontalalignment=h_align,
                verticalalignment=v_align)

    ax.text(-1, 0.99, shorthand(cre), fontsize=16, horizontalalignment='left')
    ax.text(0.73,
            0.99,
            '(n=' + str(num_cells) + ')',
            fontsize=10,
            horizontalalignment='left')

    # diagonal angle labels and their degree signs (raw strings: '\c' is not
    # a valid Python escape sequence)
    for x, y, text, h_align, v_align in [
        (.73, -.73, '45', 'left', 'top'),
        (-.78, -.75, '135', 'right', 'top'),
        (-.73, .73, '-135', 'right', 'bottom'),
        (.73, .73, '-45', 'left', 'bottom'),
    ]:
        ax.text(x,
                y,
                text,
                fontsize=10,
                horizontalalignment=h_align,
                verticalalignment=v_align)
    for x, y, h_align, v_align in [
        (.81, -.71, 'left', 'top'),
        (-.69, -.73, 'right', 'top'),
        (-.64, .69, 'right', 'bottom'),
        (.85, .69, 'left', 'bottom'),
    ]:
        ax.text(x,
                y,
                r'$^\circ$',
                fontsize=18,
                horizontalalignment=h_align,
                verticalalignment=v_align)

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    plt.axis('equal')
    plt.axis('off')
    plt.savefig(savepath + savename + '_radial_direction_tuning.svg',
                format='svg')
    plt.close()
예제 #30
0
def plot_single_cell_example(df,
                             savepath,
                             cre,
                             example_cell,
                             example_session_idx=0):
    """Plot the full response heatmap and tuning curves of one example cell.

    Only cells whose ``chi_square_all_conditions`` p-value is below
    ``SIG_THRESH`` are kept; ``example_cell`` indexes into that filtered set.
    Responses are indexed as (cell, direction, contrast), with directions
    re-centered by ``center_direction_zero``. The figure has three panels:

    * a direction x contrast heatmap;
    * contrast tuning (mean +/- SEM) at the cell's peak direction;
    * direction tuning (mean +/- SEM) at its peak contrast.

    It is saved as ``<savepath><cre shorthand>_example_cell.svg``.

    Args:
        df: session metadata passed to ``get_sessions``.
        savepath: prefix for loading session data and saving the figure.
        cre: Cre line name.
        example_cell: index of the cell among significantly responsive cells.
        example_session_idx: which VISp session of this Cre line to use.
    """
    directions, contrasts = grating_params()

    session_IDs = get_sessions(df, 'VISp', cre)
    session_ID = session_IDs[example_session_idx]

    mse = load_mean_sweep_events(savepath, session_ID)
    sweep_table = load_sweep_table(savepath, session_ID)

    condition_responses, __ = compute_mean_condition_responses(
        sweep_table, mse)
    condition_SEM, __ = compute_SEM_condition_responses(sweep_table, mse)
    p_all = chi_square_all_conditions(sweep_table, mse, session_ID, savepath)

    # keep only significantly responsive cells
    sig_resp = condition_responses[p_all < SIG_THRESH]
    sig_SEM = condition_SEM[p_all < SIG_THRESH]

    # shift zero to center:
    directions = [-135, -90, -45, 0, 45, 90, 135, 180]
    sig_resp = center_direction_zero(sig_resp)
    sig_SEM = center_direction_zero(sig_SEM)

    # full direction by contrast response heatmap
    plt.figure(figsize=(6, 4))
    ax = plt.subplot2grid((5, 5), (0, 0), rowspan=5, colspan=2)
    ax.imshow(sig_resp[example_cell],
              vmin=0.0,
              interpolation='nearest',
              aspect='auto',
              cmap='plasma')
    ax.set_ylabel('Direction (deg)', fontsize=14)
    ax.set_xlabel('Contrast (%)', fontsize=14)
    ax.set_xticks(np.arange(len(contrasts)))
    ax.set_xticklabels([str(int(100 * x)) for x in contrasts], fontsize=10)
    ax.set_yticks(np.arange(len(directions)))
    ax.set_yticklabels([str(x) for x in directions], fontsize=10)

    peak_dir_idx, peak_con_idx = get_peak_conditions(sig_resp)

    # contrast tuning at peak direction (log-spaced x axis)
    contrast_means = sig_resp[example_cell, peak_dir_idx[example_cell], :]
    contrast_SEMs = sig_SEM[example_cell, peak_dir_idx[example_cell], :]
    ax = plt.subplot2grid((5, 5), (0, 3), rowspan=2, colspan=2)
    ax.errorbar(np.log(contrasts), contrast_means, contrast_SEMs)
    ax.set_xticks(np.log(contrasts))
    ax.set_xticklabels([str(int(100 * x)) for x in contrasts], fontsize=10)
    ax.set_xlabel('Contrast (%)', fontsize=14)
    ax.set_ylabel('Response', fontsize=14)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # direction tuning at peak contrast
    direction_means = sig_resp[example_cell, :, peak_con_idx[example_cell]]
    direction_SEMs = sig_SEM[example_cell, :, peak_con_idx[example_cell]]
    ax = plt.subplot2grid((5, 5), (3, 3), rowspan=2, colspan=2)
    ax.errorbar(np.arange(len(directions)), direction_means, direction_SEMs)
    # small margin beyond the first and last direction index
    ax.set_xlim(-0.07, len(directions) - 1 + 0.07)
    ax.set_xticks(np.arange(len(directions)))
    ax.set_xticklabels([str(x) for x in directions], fontsize=10)
    ax.set_xlabel('Direction (deg)', fontsize=14)
    ax.set_ylabel('Response', fontsize=14)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    plt.tight_layout(w_pad=-5.5, h_pad=0.1)
    plt.savefig(savepath + shorthand(cre) + '_example_cell.svg', format='svg')
    plt.close()