Пример #1
0
def plot_behavior_and_pop_avg(dataset, xlim_seconds=None, save_figure=True):
    """
    Plot behavior model weights, behavior, and the population-average trace as
    three vertically stacked panels sharing the x axis.

    :param dataset: ophys dataset object; its ophys_experiment_id and metadata_string are used
    :param xlim_seconds: optional [start, stop] window in seconds; None plots the full session
    :param save_figure: if True, save the figure to the 'timeseries_plots' folder of the QC plots dir
    """
    # Filename suffix describing the plotted window. Any window ending before
    # 2000 s is 'early'; everything else is 'late'. An explicit else guarantees
    # suffix is always bound (e.g. when the bound is NaN).
    if xlim_seconds is None:
        suffix = ''
    elif xlim_seconds[1] < 2000:
        suffix = 'early'
    else:
        suffix = 'late'

    figsize = (15, 8)
    fig, ax = plt.subplots(3, 1, figsize=figsize, sharex=True)
    try:
        ax[0] = plot_behavior_model_weights(dataset,
                                            xlim_seconds=xlim_seconds,
                                            plot_stimuli=True,
                                            ax=ax[0])
    except Exception:
        # Behavior model output is optional; leave the top panel empty when it
        # is missing. Exception (not BaseException) so Ctrl-C still interrupts.
        print('no behavior model output for', dataset.ophys_experiment_id)
    ax[1] = plot_behavior(dataset.ophys_experiment_id,
                          xlim_seconds=xlim_seconds,
                          plot_stimuli=True,
                          ax=ax[1])
    ax[2] = plot_traces(dataset,
                        include_cell_traces=False,
                        plot_stimuli=True,
                        xlim_seconds=xlim_seconds,
                        ax=ax[2])
    plt.subplots_adjust(wspace=0, hspace=0.1)
    ax[0].set_title(dataset.metadata_string)

    if save_figure:
        save_dir = os.path.abspath(
            os.path.join(loading.get_qc_plots_dir(), 'timeseries_plots'))
        utils.save_figure(fig, figsize, save_dir,
                          'behavior_traces_population_average',
                          dataset.metadata_string + '_' + suffix)
Пример #2
0
def plot_average_metric_value_for_experience_levels_across_containers(df, metric, ylim=None, horizontal=True,
                                                                      save_dir=None, folder=None, suffix='', ax=None):
    """
    Plots the average metric value across experience levels for each cre line in color,
    with individual containers shown as connected gray lines

    :param df: dataframe with columns ['cell_type', 'experience_level', 'ophys_container_id', 'ophys_experiment_id']
                and a column with some metric value to compute the mean of, such as 'mean_response' or 'reliability'
                if 'cell_specimen_id' is included in the dataframe, will average across cells per experiment for the plot
    :param metric: name of the column in df to average and plot
    :param ylim: ylimits, in units of metric value provided, to constrain the plot.
    :param horizontal: when ax is None, lay the three panels out in a row (True) or a column (False)
    :param save_dir: directory to save figures to. if None, will not save.
    :param folder: sub folder of save_dir to save figures to
    :param suffix: string starting with '_' to append to end of filename of saved plot
    :param ax: optional sequence of axes, one per cell type; a new figure is created if None
    :return: the sequence of axes
    """
    experience_levels = np.sort(df.experience_level.unique())
    cell_types = np.sort(df.cell_type.unique())[::-1]

    # mean value per experiment; the metric column is selected before averaging so that
    # non-numeric columns in df do not raise (pandas >= 2.0 no longer drops them silently)
    group_columns = ['cell_type', 'experience_level', 'ophys_container_id', 'ophys_experiment_id']
    mean_df = df.groupby(group_columns)[[metric]].mean().reset_index()

    palette = utils.get_experience_level_colors()
    if ax is None:
        format_fig = True
        if horizontal:
            figsize = (10, 4)
            fig, ax = plt.subplots(1, 3, figsize=figsize, sharex=False)
        else:
            figsize = (3.5, 10.5)
            fig, ax = plt.subplots(3, 1, figsize=figsize, sharex=True)
    else:
        format_fig = False
        # caller-supplied axes: recover their figure so save_dir still works
        fig = ax[0].figure
        figsize = tuple(fig.get_size_inches())
    # defined unconditionally because it is needed for saving even with caller-supplied axes
    fig_title = metric + '_across_containers' + suffix

    for i, cell_type in enumerate(cell_types):
        data = mean_df[mean_df.cell_type == cell_type]
        # plot each container as gray lines
        for ophys_container_id in data.ophys_container_id.unique():
            ax[i] = sns.pointplot(data=data[data.ophys_container_id == ophys_container_id], x='experience_level',
                                  y=metric,
                                  color='gray', join=True, markers='.', scale=0.25, errwidth=0.25, ax=ax[i], zorder=500)
        plt.setp(ax[i].collections, alpha=.3)  # for the markers
        plt.setp(ax[i].lines, alpha=.3)
        # plot the population average in color
        ax[i] = sns.pointplot(data=data, x='experience_level', y=metric, hue='experience_level',
                              hue_order=experience_levels, palette=palette, dodge=0, join=False, ax=ax[i])
        ax[i].set_xticklabels(experience_levels, rotation=45)
        ax[i].get_legend().remove()
        ax[i].set_title(cell_type)
        ax[i].set_xlabel('')
        if ylim is not None:
            ax[i].set_ylim(ylim)
    if format_fig:
        fig.tight_layout()
        fig.suptitle(fig_title, x=0.52, y=1.02, fontsize=16)
    if save_dir:
        utils.save_figure(fig, figsize, save_dir, folder, fig_title)
    return ax
def plot_traces_heatmap(dataset, ax=None, save=False, use_events=False):
    """
    Plot the activity of every cell as a heatmap (one row per cell, one column per frame).

    :param dataset: ophys dataset object providing events_array / dff_traces_array,
                    ophys_timestamps and metadata['ophys_frame_rate']
    :param ax: optional Axes to draw on; a new figure is created if None
    :param save: if True, save under dataset.analysis_dir/'experiment_summary'
    :param use_events: plot dataset.events_array (fixed vmax 0.03) instead of dF/F
                       (vmax = 99th percentile of the traces)
    :return: the Axes drawn on
    """
    if use_events:
        traces = dataset.events_array.copy()
        vmax = 0.03
        label = 'event magnitude'
        suffix = '_events'
    else:
        traces = dataset.dff_traces_array
        vmax = np.percentile(traces, 99)
        label = 'dF/F'
        suffix = ''
    if ax is None:
        figsize = (14, 5)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # caller-supplied axes: recover the figure so saving and the colorbar target it
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())
    mesh = ax.pcolormesh(traces, cmap='magma', vmin=0, vmax=vmax)
    ax.set_ylabel('cells')

    # x ticks every 5 minutes, placed in frame units and labelled in seconds
    interval_seconds = 5 * 60
    ophys_frame_rate = int(dataset.metadata['ophys_frame_rate'])
    upper_limit, _, _ = get_upper_limit_and_intervals(traces, dataset.ophys_timestamps, ophys_frame_rate)
    ax.set_xticks(np.arange(0, upper_limit, interval_seconds * ophys_frame_rate))
    ax.set_xticklabels(np.arange(0, upper_limit / ophys_frame_rate, interval_seconds))
    ax.set_xlabel('time (seconds)')

    # attach the colorbar to this axes' figure rather than whichever figure is current
    cb = fig.colorbar(mesh, ax=ax, pad=0.015)
    cb.set_label(label, labelpad=3)
    if save:
        save_figure(fig, figsize, dataset.analysis_dir, 'experiment_summary',
                    str(dataset.experiment_id) + 'traces_heatmap' + suffix)
    return ax
def plot_lick_raster(trials, ax=None, save_dir=None):
    """
    Plot a lick raster with one row per trial. Lick, reward and trial-start times
    are shown relative to each trial's change time; the row is shaded with the
    trial's trial_type_color.

    :param trials: trials dataframe with columns trials_id, start_time, change_time, lick_times,
                   reward_times, trial_type_color and response_window
    :param ax: optional Axes to draw on; a new figure is created if None
    :param save_dir: if truthy, save the figure under its 'behavior' sub folder
    :return: the Axes drawn on
    """
    if ax is None:
        figsize = (5, 10)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())

    trial_data = None
    for trial in trials.trials_id.values:
        # NOTE(review): rows are looked up positionally, which assumes trials_id equals row position
        trial_data = trials.iloc[trial]
        change_time = trial_data.change_time
        # get times relative to change time
        trial_start = trial_data.start_time - change_time
        lick_times = [t - change_time for t in trial_data.lick_times]
        reward_times = [t - change_time for t in trial_data.reward_times]
        # plot trials as colored rows; axhspan's xmin/xmax are axes fractions, so -200..200 spans the full width
        ax.axhspan(trial, trial + 1, -200, 200, color=trial_data.trial_type_color, alpha=.5)
        # plot first reward time
        if reward_times:
            ax.plot(reward_times[0], trial + 0.5, '.', color='b', label='reward', markersize=6)
        ax.vlines(trial_start, trial, trial + 1, color='black', linewidth=1)
        # plot lick times
        ax.vlines(lick_times, trial, trial + 1, color='k', linewidth=1)
        # annotate change time
        ax.vlines(0, trial, trial + 1, color=[.5, .5, .5], linewidth=1)

    # gray bar for response window, taken from the last trial (skipped when there are no trials)
    if trial_data is not None:
        ax.axvspan(trial_data.response_window[0], trial_data.response_window[1], facecolor='gray', alpha=.4,
                   edgecolor='none')
    ax.grid(False)
    ax.set_ylim(0, len(trials))
    ax.set_xlim([-1, 4])
    ax.set_ylabel('trials')
    ax.set_xlabel('time (sec)')
    ax.set_title('lick raster')
    # invert this axes (not whatever is current) so the first trial is at the top
    ax.invert_yaxis()

    if save_dir:
        save_figure(fig, figsize, save_dir, 'behavior', 'lick_raster')
    return ax
def plot_mean_first_flash_response_by_image_block(analysis, save_dir=None, ax=None):
    """
    Plot, per cell, the response to the first flash (repeat == 1) of its preferred
    stimulus across image blocks. Lines are colored by each cell's mean response.

    :param analysis: response analysis object whose stimulus_response_df has columns
                     image_block, repeat, pref_stim, cell_specimen_id and mean_response
    :param save_dir: if truthy, save under its 'first_flash_by_image_block' sub folder
    :param ax: optional Axes to draw on; a new figure is created if None
    :return: the Axes drawn on
    """
    # work on a copy so the analysis object's dataframe is not modified
    fdf = analysis.stimulus_response_df.copy()
    fdf['image_block'] = [int(image_block) for image_block in fdf.image_block.values]
    data = fdf[(fdf.repeat == 1) & (fdf.pref_stim == True)]  # noqa: E712
    mean_response = data.groupby(['cell_specimen_id']).apply(ut.get_mean_sem)
    mean_response = mean_response.unstack()

    # per-cell mean response, indexed by cell_specimen_id
    cell_means = mean_response.mean_response
    # hue_order must hold cell_specimen_id values (the hue column), not positional indices
    cell_order = cell_means.index.values[np.argsort(cell_means.values)]
    if ax is None:
        figsize = (15, 5)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())
    ax = sns.pointplot(data=data, x="image_block", y="mean_response", hue='cell_specimen_id',
                       hue_order=cell_order,
                       palette='Blues', ax=ax)
    ax.legend_.remove()
    # colorbar standing in for the per-cell legend
    norm = plt.Normalize(cell_means.min(), cell_means.max())
    sm = plt.cm.ScalarMappable(cmap="Blues", norm=norm)
    sm.set_array([])
    ax.figure.colorbar(mappable=sm, ax=ax, label='mean response across blocks')
    ax.set_title('mean response to first flash of pref stim across image blocks')
    if save_dir:
        fig.tight_layout()
        save_figure(fig, figsize, save_dir, 'first_flash_by_image_block', analysis.dataset.analysis_folder)
    return ax
def plot_container_overview(experiments_table, project_code,
                            save_dir=r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/qc_plots'):
    """
    Plot one row per ophys container of a project, showing that container's experiments.

    :param experiments_table: table with project_code, ophys_container_id and ophys_experiment_id columns
    :param project_code: project to plot
    :param save_dir: directory to save into (under 'overview_plots'); if falsy, the figure is not saved
    """
    project_expts = experiments_table[experiments_table.project_code == project_code]
    ophys_container_ids = np.sort(project_expts.ophys_container_id.unique())
    # widest container, used so every row shares the same horizontal scale
    max_n_expts = project_expts.groupby('ophys_container_id').count().ophys_experiment_id.max()

    n_containers = len(ophys_container_ids)
    figsize = (15, n_containers)
    # squeeze=False keeps an indexable array of Axes even for a single container
    fig, axes = plt.subplots(n_containers,
                             1,
                             figsize=figsize,
                             squeeze=False,
                             gridspec_kw={
                                 'wspace': 0,
                                 'hspace': 0
                             })
    ax = axes[:, 0]
    for i, ophys_container_id in enumerate(ophys_container_ids):
        # NOTE(review): plot_sorted_datacube_summary calls plot_expts_for_container with
        # (expts, ophys_container_id, experiment_ids_to_highlight, ...); confirm which signature is current.
        ax[i] = plot_expts_for_container(ophys_container_id,
                                         project_expts,
                                         max_n_expts=max_n_expts,
                                         ax=ax[i])
    plt.suptitle('project code: ' + project_code,
                 x=0.5,
                 y=0.89,
                 fontsize=20,
                 horizontalalignment='center')

    if save_dir:
        ut.save_figure(fig, figsize, save_dir, 'overview_plots',
                       project_code + '_containers_chronological')
Пример #7
0
def plot_metrics_distribution(metrics_df, title, folder,
                              save_dir=r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/qc_plots/roi_filtering_validation'):
    """
    Plot, for each ROI metric, the distribution for valid ROIs (blue) versus
    invalid ROIs (red) on a 2 x 5 grid, then save the figure.

    :param metrics_df: dataframe with a boolean-like 'valid_roi' column and one column per metric below
    :param title: figure title, also used as the filename prefix
    :param folder: sub folder of save_dir to save into
    :param save_dir: base directory to save into (defaults to the shared roi_filtering_validation dir)
    """
    metrics = [
        'area', 'ellipseness', 'compactness', 'mean_intensity',
        'max_intensity', 'intensity_ratio', 'soma_minus_np_mean',
        'soma_minus_np_std', 'sig_active_frames_2_5', 'sig_active_frames_4'
    ]
    figsize = (20, 8)
    fig, ax = plt.subplots(2, 5, figsize=figsize)
    ax = ax.ravel()

    # split once instead of re-filtering for every metric
    valid_rois = metrics_df[metrics_df.valid_roi == True]  # noqa: E712
    invalid_rois = metrics_df[metrics_df.valid_roi == False]  # noqa: E712

    for i, metric in enumerate(metrics):
        ax[i] = sns.distplot(valid_rois[metric].values, bins=30, ax=ax[i], color='blue')
        ax[i] = sns.distplot(invalid_rois[metric].values, bins=30, ax=ax[i], color='red')
        ax[i].set_xlabel(metric)
        ax[i].set_ylabel('density')
        ax[i].legend(['valid', 'invalid'],
                     fontsize='x-small',
                     loc='upper right')

    fig.tight_layout()
    fig.suptitle(title, x=0.5, y=1.01, fontsize=16)
    utils.save_figure(fig, figsize, save_dir, folder,
                      title + '_metric_distributions')
def plot_single_cell_activity_and_behavior(dataset, cell_specimen_id, save_figure=True):
    """
    Plot one cell's full dF/F trace above licks, running speed, pupil area and
    face-motion PC0, as five stacked panels sharing the time axis. Useful to see
    whether the dF/F trace follows the behavior variables.

    The pupil and face-motion panels are best-effort: if their data cannot be
    read, a message is printed and the panel is left empty.

    :param dataset: ophys dataset object
    :param cell_specimen_id: row label of the cell in dataset.dff_traces
    :param save_figure: if True, save into the single-cell plots directory and close the current figure
    """
    figsize = (20, 10)
    fig, ax = plt.subplots(5, 1, figsize=figsize, sharex=True)
    palette = sns.color_palette()
    dff_ax, lick_ax, run_ax, pupil_ax, face_ax = ax

    cell_trace = dataset.dff_traces.loc[cell_specimen_id].dff
    dff_ax.plot(dataset.ophys_timestamps, cell_trace, label='mean_trace', color=palette[0])
    dff_ax.set_ylabel('dF/F')

    # licks drawn as tick marks at a constant height
    lick_times = dataset.licks.timestamps.values
    lick_ax.plot(lick_times, np.ones(len(lick_times)), '|', label='licks', color=palette[3])
    lick_ax.set_ylabel('licks')
    lick_ax.set_yticklabels([])

    speed = dataset.running_speed.speed.values
    speed_times = dataset.running_speed.timestamps.values
    run_ax.plot(speed_times, speed, label='running_speed', color=palette[4])
    run_ax.set_ylabel('run speed\n(cm/s)')

    try:
        area = dataset.eye_tracking.pupil_area.values
        area_times = dataset.eye_tracking.timestamps.values
        pupil_ax.plot(area_times, area, label='pupil_area', color=palette[9])
    except Exception:
        print('no pupil for', dataset.ophys_experiment_id)
    pupil_ax.set_ylabel('pupil area\n pixels**2')
    pupil_ax.set_ylim(-50, 30000)

    try:
        pc0 = dataset.behavior_movie_pc_activations[:, 0]
        pc0_times = dataset.timestamps['eye_tracking'].timestamps
        face_ax.plot(pc0_times, pc0, label='face_motion_PC0', color=palette[2])
    except Exception:
        print('no face motion for', dataset.ophys_experiment_id)
    face_ax.set_ylabel('face motion\n PC0 activation')

    # only the bottom panel keeps its x tick labels
    last = len(ax) - 1
    for idx, panel in enumerate(ax):
        panel.tick_params(which='both', bottom=False, top=False, right=False, left=True,
                          labelbottom=(idx == last), labeltop=False, labelright=False, labelleft=True)
    plt.subplots_adjust(wspace=0, hspace=0.1)
    dff_ax.set_title(str(cell_specimen_id) + '_' + dataset.metadata_string)
    if save_figure:
        utils.save_figure(fig, figsize, utils.get_single_cell_plots_dir(), 'dff_trace_and_behavior',
                          str(cell_specimen_id) + '_' + dataset.metadata_string + '_dff_trace_and_behavior')
        plt.close()
Пример #9
0
def save_roi_validation(roi_validation, lims_data):
    """
    Save each ROI validation figure into the 'roi_validation' folder of the
    experiment's analysis directory.

    :param roi_validation: iterable of dicts with keys 'fig', 'index', 'id' and 'cell_index'
    :param lims_data: passed to get_analysis_dir to locate the output directory
    """
    analysis_dir = get_analysis_dir(lims_data)

    for entry in roi_validation:
        figure = entry['fig']
        # filename is '<index>_<id>_<cell_index>'
        file_name = '_'.join(str(entry[key]) for key in ('index', 'id', 'cell_index'))
        save_figure(figure, (20, 10), analysis_dir, 'roi_validation', file_name)
Пример #10
0
def plot_n_segmented_cells(multi_session_df, df_name, horizontal=True, save_dir=None, folder='cell_matching', suffix='', ax=None):
    """
    Plots the total number of cells per experiment across experience levels for each cre line.
    Individual containers are shown as connected gray lines and the average in color.

    :param multi_session_df: dataframe of trial averaged responses for each cell for some set of conditions
    :param df_name: name of the type of response_df used to make multi_session_df, such as 'omission_response_df' or 'stimulus_response_df'
    :param horizontal: when ax is None, lay the three panels out in a row (True) or a column (False)
    :param save_dir: directory to save figures to. if None, will not save.
    :param folder: not used; the figure is saved to the 'n_segmented_cells' sub folder of save_dir
    :param suffix: string starting with '_' to append to end of filename of saved plot
    :param ax: optional sequence of axes, one per cell type; a new figure is created if None
    :return: the sequence of axes
    """
    df = multi_session_df.copy()

    experience_levels = np.sort(df.experience_level.unique())
    cell_types = np.sort(df.cell_type.unique())[::-1]

    # per-experiment cell counts ('total_cells' column) from get_fraction_responsive_cells
    fraction_responsive = get_fraction_responsive_cells(df, conditions=['cell_type', 'experience_level', 'ophys_container_id', 'ophys_experiment_id'])
    fraction_responsive = fraction_responsive.reset_index()

    palette = utils.get_experience_level_colors()
    if ax is None:
        format_fig = True
        if horizontal:
            figsize = (10, 4)
            fig, ax = plt.subplots(1, 3, figsize=figsize, sharex=False)
        else:
            figsize = (3.5, 10.5)
            fig, ax = plt.subplots(3, 1, figsize=figsize, sharex=True)
    else:
        format_fig = False
        # caller-supplied axes: recover their figure so save_dir still works
        fig = ax[0].figure
        figsize = tuple(fig.get_size_inches())

    for i, cell_type in enumerate(cell_types):
        data = fraction_responsive[fraction_responsive.cell_type == cell_type]
        # one faint gray line per container
        for ophys_container_id in data.ophys_container_id.unique():
            ax[i] = sns.pointplot(data=data[data.ophys_container_id == ophys_container_id], x='experience_level', y='total_cells',
                                  color='gray', join=True, markers='.', scale=0.25, errwidth=0.25, ax=ax[i], zorder=500)
        plt.setp(ax[i].collections, alpha=.3)  # for the markers
        plt.setp(ax[i].lines, alpha=.3)
        # average across containers in color
        ax[i] = sns.pointplot(data=data, x='experience_level', y='total_cells', hue='experience_level',
                              hue_order=experience_levels, palette=palette, dodge=0, join=False, ax=ax[i])
        ax[i].set_xticklabels(experience_levels, rotation=45)
        ax[i].get_legend().remove()
        ax[i].set_title(cell_type)
        ax[i].set_ylim(ymin=0)
        ax[i].set_xlabel('')
    if format_fig:
        fig.tight_layout()
    if save_dir:
        fig_title = df_name.split('-')[0] + '_n_total_cells' + suffix
        utils.save_figure(fig, figsize, save_dir, 'n_segmented_cells', fig_title)
    return ax
def plot_sorted_datacube_summary(project_experiments_table,
                                 experiment_ids_to_highlight,
                                 what_is_highlighted_string,
                                 save_dir=None):
    """
    Plot one row per ophys container of a project, grouped by super container and
    ordered by cre line, targeted structure and acquisition date, highlighting the
    given experiments. The figure is always saved.

    :param project_experiments_table: experiments table for a single project
    :param experiment_ids_to_highlight: experiment ids passed through to plot_expts_for_container
    :param what_is_highlighted_string: description used in the title and filename
    :param save_dir: directory to save into (under 'overview_plots'); if None, the shared
                     visual_behavior qc_plots directory is used
    """
    expts = project_experiments_table.copy()
    project_code = expts.project_code.unique()[0]
    # widest container, used so every row shares the same horizontal scale
    max_n_expts = expts.groupby(
        ['super_ophys_container_id',
         'ophys_container_id']).count().ophys_session_id.max()
    # super containers in (cre_line, targeted_structure, date_of_acquisition) order
    super_ophys_container_ids = expts.groupby([
        'cre_line', 'targeted_structure', 'date_of_acquisition',
        'super_ophys_container_id'
    ]).count().reset_index().super_ophys_container_id.unique()
    n_ophys_container_ids = len(expts.ophys_container_id.unique())

    figsize = (15, n_ophys_container_ids)
    # squeeze=False keeps an indexable array of Axes even for a single container
    fig, axes = plt.subplots(n_ophys_container_ids,
                             1,
                             figsize=figsize,
                             squeeze=False,
                             gridspec_kw={
                                 'wspace': 0,
                                 'hspace': 0
                             })
    ax = axes[:, 0]
    i = 0
    for super_ophys_container_id in super_ophys_container_ids:
        super_container_expts = expts[expts.super_ophys_container_id ==
                                      super_ophys_container_id]
        for ophys_container_id in super_container_expts.ophys_container_id.unique():
            ax[i] = plot_expts_for_container(super_container_expts,
                                             ophys_container_id,
                                             experiment_ids_to_highlight,
                                             max_n_expts=max_n_expts,
                                             ax=ax[i])
            i += 1
    plt.suptitle('project code: ' + project_code + ' - ' +
                 what_is_highlighted_string,
                 x=0.3,
                 y=0.9,
                 fontsize=20,
                 horizontalalignment='center')
    fig.subplots_adjust(left=0.2)
    # the caller's save_dir used to be overwritten unconditionally; honour it when given
    if save_dir is None:
        save_dir = r'\\allen\programs\braintv\workgroups\nc-ophys\visual_behavior\qc_plots'
    ut.save_figure(
        fig, figsize, save_dir, 'overview_plots',
        project_code + '_containers_sorted_' + what_is_highlighted_string)
def plot_mean_image_response_heatmap(mean_df, title=None, ax=None, save_dir=None, use_events=False):
    """
    Plot a heatmap of each cell's mean response (rows) to each change image (columns).

    Rows are grouped by the image for which pref_stim is True, in sorted image order,
    and sorted by descending mean response within each group.

    :param mean_df: dataframe with change_image_name, pref_stim and mean_response columns,
                    plus a cell id column ('cell_specimen_id', or 'cell' if that is absent)
    :param title: axes title; defaults to 'mean response by image'
    :param ax: optional Axes to draw on; a new figure is created if None
    :param save_dir: if truthy, save under its 'experiment_summary' sub folder
    :param use_events: use the event-magnitude color scale (vmax 0.03) instead of dF/F (vmax 0.3)
    :return: the Axes drawn on
    """
    df = mean_df.copy()
    images = np.sort(df.change_image_name.unique())
    cell_name = 'cell_specimen_id' if 'cell_specimen_id' in df.keys() else 'cell'

    # row order: cells grouped by preferred image, strongest response first
    cell_list = []
    for image in images:
        tmp = df[(df.change_image_name == image) & (df.pref_stim == True)]  # noqa: E712
        order = np.argsort(tmp.mean_response.values)[::-1]
        cell_list.extend(tmp[cell_name].values[order])

    response_matrix = np.empty((len(cell_list), len(images)))
    for i, cell in enumerate(cell_list):
        responses = []
        for image in images:
            response = df[(df[cell_name] == cell) & (df.change_image_name == image)].mean_response.values[0]
            responses.append(response)
        response_matrix[i, :] = np.asarray(responses)

    if ax is None:
        figsize = (5, 8)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # caller-supplied axes: recover the figure so save_dir still works
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())
    if use_events:
        vmax = 0.03
        label = 'mean event magnitude'
        suffix = '_events'
    else:
        vmax = 0.3
        label = 'mean dF/F'
        suffix = ''
    ax = sns.heatmap(response_matrix, cmap='magma', linewidths=0, linecolor='white', square=False,
                     vmin=0, vmax=vmax, robust=True,
                     cbar_kws={"drawedges": False, "shrink": 1, "label": label}, ax=ax)

    if title is None:
        title = 'mean response by image'
    ax.set_title(title, va='bottom', ha='center')
    ax.set_xticklabels(images, rotation=90)
    ax.set_ylabel('cells')
    interval = 10
    ax.set_yticks(np.arange(0, response_matrix.shape[0], interval))
    ax.set_yticklabels(np.arange(0, response_matrix.shape[0], interval))
    if save_dir:
        fig.tight_layout()
        save_figure(fig, figsize, save_dir, 'experiment_summary', 'mean_image_response_heatmap' + suffix)
    return ax
def plot_mean_trace_heatmap(mean_df, condition='trial_type', condition_values=None, ax=None, save_dir=None,
                            use_events=False, window=None):
    """
    Plot heatmaps of mean traces for preferred-stimulus rows, one panel per value of `condition`.

    The row (cell) order is set by the first non-empty condition (descending mean_response)
    and reused for every panel. A cell missing from a panel is shown as a NaN row.

    :param mean_df: dataframe with pref_stim, cell_specimen_id, mean_response, mean_trace and `condition` columns
    :param condition: column to split panels by
    :param condition_values: values of `condition` to plot; defaults to ['go', 'catch']
    :param ax: optional array of Axes (one per condition value); a new figure is created if None
    :param save_dir: if truthy, save under its 'experiment_summary' sub folder
    :param use_events: use vmax 0.05 (events) instead of 0.5 (dF/F)
    :param window: [start, end] of the trace window in seconds; defaults to [-4, 4]
    :return: the array of Axes
    """
    # None defaults avoid sharing one mutable list between calls
    if condition_values is None:
        condition_values = ['go', 'catch']
    if window is None:
        window = [-4, 4]

    data = mean_df[mean_df.pref_stim == True].copy()  # noqa: E712
    if use_events:
        vmax = 0.05
        suffix = '_events'
    else:
        vmax = 0.5
        suffix = ''
    if ax is None:
        figsize = (3 * len(condition_values), 6)
        fig, ax = plt.subplots(1, len(condition_values), figsize=figsize, sharey=True)
        ax = ax.ravel()
    else:
        fig = ax[0].figure
        figsize = tuple(fig.get_size_inches())

    cells = None
    for i, condition_value in enumerate(condition_values):
        im_df = data[(data[condition] == condition_value)]
        if len(im_df) == 0:
            continue
        if cells is None:
            # NOTE(review): assumes one row per cell in im_df so argsort positions line up with unique()
            order = np.argsort(im_df.mean_response.values)[::-1]
            cells = im_df.cell_specimen_id.unique()[order]
        len_trace = len(im_df.mean_trace.values[0])
        response_array = np.empty((len(cells), len_trace))
        for x, cell in enumerate(cells):
            tmp = im_df[im_df.cell_specimen_id == cell]
            if len(tmp) >= 1:
                trace = tmp.mean_trace.values[0]
            else:
                trace = np.full(len_trace, np.nan)
            response_array[x, :] = trace

        sns.heatmap(data=response_array, vmin=0, vmax=vmax, ax=ax[i], cmap='magma', cbar=False)
        # tick positions from the last trace's length at 31 Hz
        xticks, xticklabels = sf.get_xticks_xticklabels(trace, 31., interval_sec=2, window=window)
        ax[i].set_xticks(xticks)
        ax[i].set_xticklabels([int(x) for x in xticklabels])
        ax[i].set_yticks(np.arange(0, response_array.shape[0], 10))
        ax[i].set_yticklabels(np.arange(0, response_array.shape[0], 10))
        ax[i].set_xlabel('time after change (s)', fontsize=16)
        ax[i].set_title(condition_value)
        ax[0].set_ylabel('cells')

    if save_dir:
        fig.tight_layout()
        save_figure(fig, figsize, save_dir, 'experiment_summary', 'mean_trace_heatmap_' + condition + suffix)
    return ax
def plot_roi_masks(dataset, save=False):
    """
    Show the max intensity projection beside all ROI masks for an experiment.

    :param dataset: ophys dataset object providing max_projection, cell_indices and analysis_folder
    :param save: if True, save to dataset.analysis_dir/'experiment_summary' and dataset.cache_dir/'roi_masks'
    """
    figsize = (20, 10)
    fig, ax = plt.subplots(1, 2, figsize=figsize)
    ax = ax.ravel()
    projection_ax, mask_ax = ax[0], ax[1]

    projection = dataset.max_projection
    projection_ax.imshow(projection, cmap='gray', vmin=0, vmax=np.amax(projection))
    projection_ax.axis('off')
    projection_ax.set_title('max intensity projection')

    # every ROI gets the same placeholder metric value of -1
    cell_list = dataset.cell_indices
    metrics = np.full(len(cell_list), -1.0)
    plot_metrics_mask(dataset, metrics, cell_list, 'roi masks', max_image=True, cmap='hls', ax=mask_ax, save=False,
                      colorbar=False)

    plt.suptitle(dataset.analysis_folder, fontsize=16, x=0.5, y=1., horizontalalignment='center')
    if not save:
        return
    file_name = dataset.analysis_folder + '_roi_masks'
    save_figure(fig, figsize, dataset.analysis_dir, 'experiment_summary', file_name)
    save_figure(fig, figsize, dataset.cache_dir, 'roi_masks', file_name)
def plot_mean_response_across_image_block_sets(data, analysis_folder, save_dir=None, ax=None):
    """
    Plot each cell's mean response across image block sets, colored by its
    early_late_block_ratio (ratio of first to last block).

    :param data: dataframe with image_block, block_set, mean_response, early_late_block_ratio,
                 cell and cell_specimen_id columns
    :param analysis_folder: filename used when saving
    :param save_dir: if truthy, save under its 'first_flash_by_image_block_set' sub folder
    :param ax: optional Axes to draw on; a new figure is created if None
    :return: the Axes drawn on
    """
    first_block = data[data.image_block == 1]
    order = np.argsort(first_block.early_late_block_ratio.values)
    # NOTE(review): hue is 'cell_specimen_id' but the order is taken from the 'cell' column;
    # confirm both hold the same identifiers, otherwise no hue level will match.
    cell_order = first_block.cell.values[order]
    if ax is None:
        figsize = (6, 5)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # caller-supplied axes: recover the figure so save_dir still works
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())
    ax = sns.pointplot(data=data, x="block_set", y="mean_response", palette='RdBu', ax=ax,
                       hue='cell_specimen_id', hue_order=cell_order)
    ax.legend_.remove()
    # colorbar standing in for the per-cell legend
    ratios = data.early_late_block_ratio.unique()
    norm = plt.Normalize(np.amin(ratios), np.amax(ratios))
    sm = plt.cm.ScalarMappable(cmap="RdBu", norm=norm)
    sm.set_array([])
    ax.figure.colorbar(mappable=sm, ax=ax, label='first/last ratio')
    ax.set_title('mean response across image blocks\ncolored by ratio of first to last block')
    if save_dir:
        fig.tight_layout()
        save_figure(fig, figsize, save_dir, 'first_flash_by_image_block_set', analysis_folder)
    return ax
def plot_lick_triggered_average_for_container(container_id, save_figure=True):
    """
    Overlay the population lick-triggered average (event magnitude) of every
    experiment in a container, one color per session number. Experiments with
    5 or fewer cells are skipped.

    :param container_id: container to plot
    :param save_figure: if True, save into the container plots directory
    :raises ValueError: if the filtered experiment table has no experiments for container_id
    """
    experiments = loading.get_filtered_ophys_experiment_table()
    container_data = experiments[experiments.container_id == container_id]
    if len(container_data) == 0:
        # the title is built from a loaded dataset, so there is nothing to plot
        raise ValueError('no experiments found for container_id {}'.format(container_id))

    figsize = (6, 4)
    fig, ax = plt.subplots(figsize=figsize)
    # loop-invariant: one color per session number
    colors = utils.get_colors_for_session_numbers()

    dataset = None
    for session_number in container_data.session_number.unique():
        experiment_ids = container_data[container_data.session_number ==
                                        session_number].index.values
        for experiment_id in experiment_ids:
            dataset = loading.get_ophys_dataset(experiment_id)
            analysis = ResponseAnalysis(
                dataset,
                use_events=True,
                use_extended_stimulus_presentations=False)
            ldf = analysis.get_response_df(
                df_name='lick_triggered_response_df')
            if len(ldf.cell_specimen_id.unique()) > 5:
                ax = plot_lick_triggered_average(
                    ldf,
                    color=colors[session_number - 1],
                    ylabel='pop. avg. \nevent magnitude',
                    legend_label=session_number,
                    ax=ax)
    ax.legend(loc='upper left', fontsize='x-small')
    fig.tight_layout()
    # title metadata comes from the last dataset loaded
    m = dataset.metadata.copy()
    title = '_'.join([
        str(m['mouse_id']),
        m['full_genotype'].split('/')[0],
        m['targeted_structure'],
        str(m['imaging_depth']),
        m['equipment_name'],
        str(m['experiment_container_id']),
    ])
    fig.suptitle(title, x=0.53, y=1.02, fontsize=16)
    if save_figure:
        save_dir = loading.get_container_plots_dir()
        utils.save_figure(fig, figsize, save_dir, 'lick_triggered_average',
                          title)
def plot_experiment_summary_figure(analysis, save_dir=None):
    """
    Build a one-page experiment summary figure. It contains:
    - a metadata table and the cell masks
    - a sorted trace heatmap
    - running speed and hit/false-alarm rates
    - a lick raster
    - mean-trace heatmaps by behavioral response type, and a mean response heatmap by image

    :param analysis: response analysis object exposing use_events, ophys_frame_rate, dataset,
                     trials_response_df and trial_window
    :param save_dir: if truthy, save under its 'experiment_summary_figures' sub folder
    """
    use_events = analysis.use_events
    if use_events:
        traces = analysis.dataset.events_array.copy()
        suffix = '_events'
    else:
        traces = analysis.dataset.dff_traces_array.copy()
        suffix = ''

    interval_seconds = 600
    ophys_frame_rate = int(analysis.ophys_frame_rate)

    figsize = [2 * 11, 2 * 8.5]
    fig = plt.figure(figsize=figsize, facecolor='white')

    # metadata table
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.8, 0.95), yspan=(0, .3))
    table_data = format_table_data(analysis.dataset)
    xtable = ax.table(cellText=table_data.values, cellLoc='left', rowLoc='left', loc='center', fontsize=12)
    xtable.scale(1.5, 3)
    ax.axis('off')

    # cell masks, all drawn with the same placeholder metric of -1
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .22), yspan=(0, .27))
    metrics = np.empty(len(analysis.dataset.cell_indices))
    metrics[:] = -1
    cell_list = analysis.dataset.cell_indices
    plot_metrics_mask(analysis.dataset, metrics, cell_list, 'cell masks', max_image=True, cmap='hls', ax=ax, save=False,
                      colorbar=False)
    ax.set_title(analysis.dataset.experiment_id)
    ax.axis('off')

    upper_limit, time_interval, _ = get_upper_limit_and_intervals(traces,
                                                                  analysis.dataset.ophys_timestamps,
                                                                  analysis.ophys_frame_rate)

    # sorted trace heatmap, x ticks every interval_seconds (frame units, labelled in seconds)
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.22, 0.9), yspan=(0, .3))
    ax = plot_sorted_traces_heatmap(analysis.dataset, analysis, ax=ax, use_events=use_events)
    ax.set_xticks(np.arange(0, upper_limit, interval_seconds * ophys_frame_rate))
    ax.set_xticklabels(np.arange(0, upper_limit / ophys_frame_rate, interval_seconds))
    ax.set_xlabel('time (seconds)')
    ax.set_title(analysis.dataset.analysis_folder)

    # running speed
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.22, 0.8), yspan=(.26, .41))
    ax = plot_run_speed(analysis.dataset.running_speed.running_speed, analysis.dataset.stimulus_timestamps, ax=ax,
                        label=True)
    ax.set_xlim(time_interval[0], np.uint64(upper_limit / ophys_frame_rate))
    ax.set_xticks(np.arange(interval_seconds, upper_limit / ophys_frame_rate, interval_seconds))
    ax.set_xlabel('time (seconds)')

    # hit / false alarm rates
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.22, 0.8), yspan=(.37, .52))
    ax = plot_hit_false_alarm_rates(analysis.dataset.trials, ax=ax)
    ax.set_xlim(time_interval[0], np.uint64(upper_limit / ophys_frame_rate))
    ax.set_xticks(np.arange(interval_seconds, upper_limit / ophys_frame_rate, interval_seconds))
    ax.legend(loc='upper right', ncol=2, borderaxespad=0.)
    ax.set_xlabel('time (seconds)')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .22), yspan=(.25, .8))
    ax = plot_lick_raster(analysis.dataset.trials, ax=ax, save_dir=None)

    # mean trace heatmaps per behavioral response type (best-effort panel)
    ax = placeAxesOnGrid(fig, dim=(1, 4), xspan=(.2, .8), yspan=(.5, .8), wspace=0.35)
    try:
        mdf = ut.get_mean_df(analysis.trials_response_df, analysis,
                             conditions=['cell_specimen_id', 'change_image_name', 'behavioral_response_type'])
        ax = plot_mean_trace_heatmap(mdf, condition='behavioral_response_type',
                                     condition_values=['HIT', 'MISS', 'CR', 'FA'], ax=ax, save_dir=None,
                                     use_events=use_events, window=analysis.trial_window)
    except Exception as e:
        # keep building the rest of the summary, but report why this panel is empty
        print('could not plot mean trace heatmaps by behavioral response type:', e)

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.78, 0.97), yspan=(.3, .8))
    mdf = ut.get_mean_df(analysis.trials_response_df, analysis, conditions=['cell_specimen_id', 'change_image_name'])
    ax = plot_mean_image_response_heatmap(mdf, title=None, ax=ax, save_dir=None, use_events=use_events)

    fig.tight_layout()

    if save_dir:
        save_figure(fig, figsize, save_dir, 'experiment_summary_figures',
                    str(analysis.dataset.experiment_id) + '_experiment_summary' + suffix)
Пример #18
0
def plot_behavior_timeseries(dataset, start_time, duration_seconds=20, xlim_seconds=None, save_dir=None, ax=None):
    """
    Plots licking behavior, rewards, running speed, and pupil area for a defined window of time

    Licks and rewards are drawn as markers at fixed y positions (-2 and -4) on the running speed
    axis; pupil width is drawn on a twin y axis with frames flagged as likely blinks set to NaN.

    :param dataset: ophys dataset providing licks, rewards, running_speed, eye_tracking and metadata
    :param start_time: time in session (seconds); when xlim_seconds is None the window spans from
        duration_seconds / 4 before start_time to 2 * duration_seconds after it
    :param duration_seconds: sets the size of the default window (see start_time)
    :param xlim_seconds: explicit [start, stop] window in seconds; takes precedence over start_time
    :param save_dir: if given, save a .png under save_dir/behavior_timeseries
    :param ax: axis to plot on; a new (15, 2.5) inch figure is created when None
    :return: the running speed axis
    """
    if xlim_seconds is None:
        xlim_seconds = [start_time - (duration_seconds / 4.), start_time + duration_seconds * 2]
    else:
        # the explicit window wins; start_time is only used below to name the saved file
        start_time = xlim_seconds[0]

    lick_timestamps = dataset.licks.timestamps.values
    licks = np.full(len(lick_timestamps), -2.)

    reward_timestamps = dataset.rewards.timestamps.values
    rewards = np.full(len(reward_timestamps), -4.)

    running_speed = dataset.running_speed.speed.values
    running_timestamps = dataset.running_speed.timestamps.values

    eye_tracking = dataset.eye_tracking
    # take a writable float copy so blink frames can be set to NaN without touching the dataset
    pupil_diameter = eye_tracking.pupil_width.to_numpy(dtype=float, copy=True)
    pupil_diameter[eye_tracking.likely_blink == True] = np.nan  # noqa: E712
    pupil_timestamps = eye_tracking.timestamps.values

    if ax is None:
        figsize = (15, 2.5)
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # reuse the caller's figure so that saving also works when an axis is passed in
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())
    colors = sns.color_palette()

    ln0 = ax.plot(lick_timestamps, licks, '|', label='licks', color=colors[3], markersize=10)
    ln1 = ax.plot(reward_timestamps, rewards, 'o', label='rewards', color=colors[9], markersize=10)

    ln2 = ax.plot(running_timestamps, running_speed, label='running_speed', color=colors[2], zorder=100)
    ax.set_ylabel('running speed\n(cm/s)')
    ax.set_ylim(bottom=-8)  # leave room below zero for the lick and reward markers

    ax2 = ax.twinx()
    ln3 = ax2.plot(pupil_timestamps, pupil_diameter, label='pupil_diameter', color=colors[4], zorder=0)
    ax2.set_ylabel('pupil diameter \n(pixels)')

    # one legend for lines on both y axes
    lines = ln0 + ln1 + ln2 + ln3
    labels = [line.get_label() for line in lines]
    ax.legend(lines, labels, bbox_to_anchor=(1, 1), fontsize='small')

    ax = add_stim_color_span(dataset, ax, xlim=xlim_seconds)

    ax.set_xlim(xlim_seconds)
    ax.set_xlabel('time in session (seconds)')
    metadata_string = utils.get_metadata_string(dataset.metadata)
    ax.set_title(metadata_string)

    if save_dir:
        folder = 'behavior_timeseries'
        utils.save_figure(fig, figsize, save_dir, folder, metadata_string + '_' + str(int(start_time)),
                          formats=['.png'])
    return ax
Пример #19
0
def _limit_to_window(timestamps, values, xlim_seconds):
    """Trim a trace to xlim_seconds, from the last sample before the start up to (excluding) the first sample after the stop.

    Falls back to the first / last sample of the trace when the window extends past either end,
    instead of raising IndexError.
    """
    before = np.flatnonzero(timestamps < xlim_seconds[0])
    start_ind = before[-1] if before.size else 0
    after = np.flatnonzero(timestamps > xlim_seconds[1])
    stop_ind = after[0] if after.size else len(timestamps)
    return timestamps[start_ind:stop_ind], values[start_ind:stop_ind]


def plot_behavior_timeseries_stacked(dataset, start_time, duration_seconds=20,
                                     label_changes=True, label_omissions=True,
                                     save_dir=None, ax=None):
    """
    Plots licking behavior, rewards, running speed, and pupil area for a defined window of time.
    Each timeseries gets its own row. If label_changes=True, all flashes are gray, changes are blue.
    If label_changes=False, unique colors are given to each image.
    If label_omissions=True, a dotted line will be plotted at the time of omissions.

    The window spans from duration_seconds / 4 before start_time to 2 * duration_seconds after it.
    Running speed and (median-filtered, blink-masked) pupil width are trimmed to the window so the
    y axes scale to the visible data.

    :param dataset: ophys dataset providing licks, rewards, running_speed, eye_tracking and metadata
    :param start_time: time in session (seconds) around which the window is placed
    :param duration_seconds: sets the window size (see above)
    :param label_changes: passed to add_stim_color_span
    :param label_omissions: passed to add_stim_color_span
    :param save_dir: if given, save .png and .pdf under save_dir/behavior_timeseries_stacked
    :param ax: flat sequence of 4 axes to plot on; a new figure is created when None
    :return: the 4 axes
    """
    # note: this suffix already starts with '_' and is joined with another '_' in the saved
    # file name; kept as-is for compatibility with previously saved files
    suffix = '_changes' if label_changes else '_colors'

    xlim_seconds = [start_time - (duration_seconds / 4.), start_time + duration_seconds * 2]

    lick_timestamps = dataset.licks.timestamps.values
    licks = np.full(len(lick_timestamps), -2.)

    reward_timestamps = dataset.rewards.timestamps.values
    rewards = np.full(len(reward_timestamps), -4.)

    # run speed limited to the window so the y axis scales properly
    running_timestamps, running_speed = _limit_to_window(dataset.running_speed.timestamps.values,
                                                         dataset.running_speed.speed.values,
                                                         xlim_seconds)

    # pupil width: writable float copy, blinks masked, smoothed, then limited to the window
    eye_tracking = dataset.eye_tracking
    pupil_diameter = eye_tracking.pupil_width.to_numpy(dtype=float, copy=True)
    pupil_diameter[eye_tracking.likely_blink == True] = np.nan  # noqa: E712
    from scipy.signal import medfilt
    pupil_diameter = medfilt(pupil_diameter, kernel_size=5)
    pupil_timestamps, pupil_diameter = _limit_to_window(eye_tracking.timestamps.values,
                                                        pupil_diameter, xlim_seconds)

    if ax is None:
        figsize = (15, 5)
        fig, ax = plt.subplots(4, 1, figsize=figsize, sharex=True, gridspec_kw={'height_ratios': [1, 1, 3, 3]})
        ax = ax.ravel()
    else:
        # reuse the caller's figure so layout and saving work when axes are passed in
        fig = ax[0].figure
        figsize = tuple(fig.get_size_inches())

    colors = sns.color_palette()

    ax[0].plot(lick_timestamps, licks, '|', label='licks', color=colors[3], markersize=10)
    ax[0].set_yticklabels([])
    ax[0].set_ylabel('licks', rotation=0, horizontalalignment='right', verticalalignment='center')
    ax[0].tick_params(which='both', bottom=False, top=False, right=False, left=False,
                      labelbottom=False, labeltop=False, labelright=False, labelleft=False)

    ax[1].plot(reward_timestamps, rewards, 'o', label='rewards', color=colors[8], markersize=10)
    ax[1].set_yticklabels([])
    ax[1].set_ylabel('rewards', rotation=0, horizontalalignment='right', verticalalignment='center')
    ax[1].tick_params(which='both', bottom=False, top=False, right=False, left=False,
                      labelbottom=False, labeltop=False, labelright=False, labelleft=False)

    ax[2].plot(running_timestamps, running_speed, label='running_speed', color=colors[2], zorder=100)
    ax[2].set_ylabel('running\nspeed\n(cm/s)', rotation=0, horizontalalignment='right', verticalalignment='center')
    ax[2].set_ylim(bottom=-8)

    ax[3].plot(pupil_timestamps, pupil_diameter, label='pupil_diameter', color=colors[4], zorder=0)
    ax[3].set_ylabel('pupil\ndiameter\n(pixels)', rotation=0, horizontalalignment='right', verticalalignment='center')

    for i in range(4):
        ax[i] = add_stim_color_span(dataset, ax[i], xlim=xlim_seconds, label_changes=label_changes,
                                    label_omissions=label_omissions)
        ax[i].set_xlim(xlim_seconds)
        ax[i].tick_params(which='both', bottom=False, top=False, right=False, left=True,
                          labelbottom=False, labeltop=False, labelright=False, labelleft=True)
        sns.despine(ax=ax[i], bottom=True)

    # label bottom row of plot
    bottom_ax = ax[3]
    sns.despine(ax=bottom_ax, bottom=False)
    bottom_ax.set_xlabel('time in session (seconds)')
    bottom_ax.tick_params(which='both', bottom=True, top=False, right=False, left=True,
                          labelbottom=True, labeltop=False, labelright=False, labelleft=True)
    # add title to top row
    metadata_string = utils.get_metadata_string(dataset.metadata)
    ax[0].set_title(metadata_string)

    fig.subplots_adjust(hspace=0)
    if save_dir:
        folder = 'behavior_timeseries_stacked'
        utils.save_figure(fig, figsize, save_dir, folder, metadata_string + '_' + str(int(start_time)) + '_' + suffix,
                          formats=['.png', '.pdf'])
    return ax
Пример #20
0
def plot_matched_roi_and_trace(ophys_container_id, cell_specimen_id, limit_to_last_familiar_second_novel=True,
                               use_events=False, filter_events=False, save_figure=True):
    """
    Generates plots characterizing single cell activity in response to stimulus, omissions, and changes.
    Compares across all sessions in a container for each cell, including the ROI mask across days.
    Useful to validate cell matching as well as examine changes in activity profiles over days.

    Top row: zoomed max projection with the ROI mask, one column per experiment. Experiments are
    processed in ascending ophys_experiment_id order; where the cell was not detected, the max
    projection is zoomed on the ROI position from the most recently processed experiment in which
    it was found (no mask is drawn). If no such experiment precedes it, only the title is set.
    Bottom row: the cell's responses to image changes (is_change rows of the stimulus response
    dataframe), plotted with utils.plot_mean_trace.

    :param ophys_container_id: container to plot
    :param cell_specimen_id: cell to plot
    :param limit_to_last_familiar_second_novel: if True, keep one session per experience level and
        only containers that have all experience levels
    :param use_events: if True, use events instead of dF/F
    :param filter_events: if True (together with use_events), use filtered events
    :param save_figure: if True, save the figure to the cell_matching plot folder and close it
    """
    experiments_table = loading.get_platform_paper_experiment_table()
    if limit_to_last_familiar_second_novel:  # this ensures only one session per experience level
        experiments_table = utilities.limit_to_last_familiar_second_novel_active(experiments_table)
        experiments_table = utilities.limit_to_containers_with_all_experience_levels(experiments_table)

    container_expts = experiments_table[experiments_table.ophys_container_id == ophys_container_id]
    container_expts = container_expts.sort_values(by=['experience_level'])
    expts = np.sort(container_expts.index.values)
    n = len(expts)
    if n == 0:
        # plt.subplots(2, 0) would fail; nothing to plot
        print('no experiments found for ophys_container_id:', ophys_container_id)
        return

    if use_events:
        suffix = 'filtered_events' if filter_events else 'events'
        ylabel = 'response'
    else:
        suffix = 'dff'
        ylabel = 'dF/F'

    if limit_to_last_familiar_second_novel:
        figsize = (9, 6)
        folder = 'matched_cells_exp_levels'
    else:
        figsize = (20, 6)
        folder = 'matched_cells_all_sessions'
    fig, ax = plt.subplots(2, n, figsize=figsize, sharey='row')
    ax = ax.ravel()

    # state carried across experiments; None until a dataset loads / the cell is found
    metadata_string = None
    roi_masks = None
    cell_roi_id = None
    print('ophys_container_id:', ophys_container_id)
    for i, ophys_experiment_id in enumerate(expts):
        print('ophys_experiment_id:', ophys_experiment_id)
        try:
            dataset = loading.get_ophys_dataset(ophys_experiment_id, get_extended_stimulus_presentations=False)
            experience_level = container_expts.loc[ophys_experiment_id].experience_level
            if cell_specimen_id in dataset.dff_traces.index:

                ct = dataset.cell_specimen_table.copy()
                cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
                # keep these masks to locate the cell in later sessions where it was not detected
                roi_masks = dataset.roi_masks.copy()
                ax[i] = sf.plot_cell_zoom(dataset.roi_masks, dataset.max_projection, cell_roi_id,
                                          spacex=50, spacey=50, show_mask=True, ax=ax[i])
                ax[i].set_title(experience_level)

                analysis = ResponseAnalysis(dataset, use_events=use_events, filter_events=filter_events,
                                            use_extended_stimulus_presentations=False)
                sdf = analysis.get_response_df(df_name='stimulus_response_df')
                cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == True)]  # noqa: E712

                window = rp.get_default_stimulus_response_params()["window_around_timepoint_seconds"]
                ax[i + n] = utils.plot_mean_trace(cell_data.trace.values, cell_data.trace_timestamps.values[0],
                                                  ylabel=ylabel, legend_label=None, color='gray', interval_sec=0.5,
                                                  xlim_seconds=window, plot_sem=True, ax=ax[i + n])

                ax[i + n] = utils.plot_flashes_on_trace(ax[i + n], cell_data.trace_timestamps.values[0],
                                                        change=True, omitted=False,
                                                        alpha=0.15, facecolor='gray')
                ax[i + n].set_title('')
                if i != 0:
                    ax[i + n].set_ylabel('')
            elif roi_masks is not None:
                # plot the max projection image at the xy location of the previously found ROI
                ax[i] = sf.plot_cell_zoom(roi_masks, dataset.max_projection, cell_roi_id,
                                          spacex=50, spacey=50, show_mask=False, ax=ax[i])
                ax[i].set_title(experience_level)
            else:
                # no earlier experiment provided an ROI position to zoom on
                print('cell_specimen_id', cell_specimen_id, 'not found in ophys_experiment_id',
                      ophys_experiment_id, 'and no earlier experiment to locate it from')
                ax[i].set_title(experience_level)

            metadata_string = utils.get_metadata_string(dataset.metadata)
        except Exception as e:
            # best effort: keep plotting the remaining experiments
            print('problem for cell_specimen_id:', cell_specimen_id, ', ophys_experiment_id:', ophys_experiment_id)
            print(e)

    if metadata_string is not None:
        fig.tight_layout()
        fig.suptitle(str(cell_specimen_id) + '_' + metadata_string, x=0.53, y=1.02,
                     horizontalalignment='center', fontsize=16)
    if save_figure:
        if metadata_string is None:
            print('no data could be loaded for cell_specimen_id:', cell_specimen_id, '- figure not saved')
        else:
            save_dir = r'//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/platform_paper_plots/cell_matching'
            utils.save_figure(fig, figsize, save_dir, folder, str(cell_specimen_id) + '_' + metadata_string + '_' + suffix)
        plt.close(fig)
Пример #21
0
def plot_response_heatmaps_for_conditions(multi_session_df, timestamps, data_type, event_type,
                                          row_condition, col_condition, cols_to_sort_by=None, suptitle=None,
                                          microscope=None, vmax=None, xlim_seconds=None, match_cells=False, cbar=True,
                                          save_dir=None, folder=None, suffix='', ax=None):
    """
    Plot a grid of cell response heatmaps, one panel per (row_condition, col_condition) value pair.

    Each panel stacks the 'mean_trace' of every row of multi_session_df matching that pair, sorted by
    cols_to_sort_by if given, otherwise by 'mean_response' (ascending).

    :param multi_session_df: dataframe with columns row_condition, col_condition, 'mean_trace' and
        'mean_response' (plus any columns in cols_to_sort_by)
    :param timestamps: time (seconds) of each sample of mean_trace; x ticks are placed every 30 samples,
        which assumes 30 Hz traces
    :param data_type: used only in the saved file name
    :param event_type: selects the x label ('omission' / 'change' in the string) and is used in the file name
    :param row_condition: column whose unique values define the rows of the grid
    :param col_condition: column whose unique values define the columns of the grid
    :param cols_to_sort_by: column(s) to sort cells by within each panel
    :param suptitle: optional figure title
    :param microscope: passed through to plot_cell_response_heatmap
    :param vmax: colormap maximum used for every panel. When None (default), a per-row value is used:
        0.01 for 'Excitatory', 0.03 for 'Sst Inhibitory', 0.02 otherwise
    :param xlim_seconds: (start, stop) in seconds; each end snaps to the nearest timestamp.
        Defaults to the full range of timestamps
    :param match_cells: if True (and cols_to_sort_by is None), later columns of a row reuse the row order
        of the first column. NOTE(review): this assumes each column lists the same cells in the same order
    :param cbar: passed through to plot_cell_response_heatmap
    :param save_dir: if given, save the figure to save_dir/folder with the given suffix
    :param ax: flat sequence with at least n_rows * n_cols axes; a new figure is created when None
    :return: the axes
    """
    sdf = multi_session_df.copy()

    if 'omission' in event_type:
        xlabel = 'time after omission (s)'
    elif 'change' in event_type:
        xlabel = 'time after change (s)'
    else:
        xlabel = 'time (s)'

    if xlim_seconds is None:
        xlim_seconds = (timestamps[0], timestamps[-1])
    # snap the requested window to the nearest samples; exact float equality is fragile
    timestamp_values = np.asarray(timestamps, dtype=float)
    xlims = [int(np.argmin(np.abs(timestamp_values - xlim_seconds[0]))),
             int(np.argmin(np.abs(timestamp_values - xlim_seconds[1])))]
    # ticks every 1 second, assuming 30Hz traces
    xticks = np.arange(0, len(timestamps), 30)
    xticklabels = [int(t) for t in timestamps[::30]]

    row_conditions = np.sort(sdf[row_condition].unique())
    col_conditions = np.sort(sdf[col_condition].unique())
    n_rows, n_cols = len(row_conditions), len(col_conditions)

    if ax is None:
        figsize = (3 * n_cols, 3 * n_rows)
        # squeeze=False so ravel() also works for a single row / column
        fig, ax = plt.subplots(n_rows, n_cols, figsize=figsize, sharex=True, squeeze=False)
        ax = ax.ravel()
    else:
        # reuse the caller's figure so layout and saving work when axes are passed in
        fig = np.ravel(ax)[0].figure
        figsize = tuple(fig.get_size_inches())

    default_vmax_by_row = {'Excitatory': 0.01, 'Vip Inhibitory': 0.02, 'Sst Inhibitory': 0.03}

    for r, row in enumerate(row_conditions):
        row_sdf = sdf[(sdf[row_condition] == row)]
        row_vmax = vmax if vmax is not None else default_vmax_by_row.get(row, 0.02)
        order = None  # cell order of the first non-empty column, used when match_cells is True
        for c, col in enumerate(col_conditions):
            i = r * n_cols + c
            title = str(row) + '\n' + str(col)

            tmp = row_sdf[(row_sdf[col_condition] == col)].reset_index()
            if tmp.empty:
                # no cells for this pair; np.vstack would fail on an empty sequence
                ax[i].set_title(title)
                ax[i].axis('off')
                continue
            if cols_to_sort_by:
                tmp = tmp.sort_values(by=cols_to_sort_by, ascending=True)
            elif match_cells and order is not None:
                tmp = tmp.loc[order]
            else:
                tmp = tmp.sort_values(by='mean_response', ascending=True)
                order = tmp.index.values
            data = pd.DataFrame(np.vstack(tmp.mean_trace.values), columns=timestamps)
            n_cells = len(data)

            ax[i] = plot_cell_response_heatmap(data, timestamps, vmax=row_vmax, xlabel=xlabel, cbar=cbar,
                                               microscope=microscope, ax=ax[i])
            ax[i].set_title(title)
            # label y with total number of cells
            ax[i].set_yticks([0, n_cells])
            ax[i].set_yticklabels([0, n_cells], fontsize=12)
            ax[i].set_xticks(xticks)
            ax[i].set_xticklabels(xticklabels)
            # set xlims according to input
            ax[i].set_xlim(xlims)
            ax[i].set_ylabel('')
            # only the bottom row gets an x label
            ax[i].set_xlabel(xlabel if r == n_rows - 1 else '')

    # label the first column
    for i in range(0, n_rows * n_cols, n_cols):
        ax[i].set_ylabel('cells')

    if suptitle:
        fig.suptitle(suptitle, x=0.52, y=1.04, fontsize=18)
    fig.tight_layout()

    if save_dir:
        fig_title = event_type + '_response_heatmap_' + data_type + '_' + col_condition + '_' + row_condition + '_' + suffix
        utils.save_figure(fig, figsize, save_dir, folder, fig_title)

    return ax
Пример #22
0
def plot_matched_roi_and_traces_example(cell_metadata, include_omissions=True,
                                        use_events=False, filter_events=False, save_dir=None, folder=None):
    """
    Plots the ROI masks and cell traces for a cell matched across sessions
    Cell_metadata is a subset of the ophys_cells_table limited to the cell_specimen_id of interest
    Masks and traces will be plotted for all ophys_experiment_ids in the cell_metadata table
    To limit to a single session of each type, filter cell_metadata to one experiment per experience_level beforehand
    ROI mask for each ophys_experiment_id in cell_metadata is plotted on its own axis
    Average cell traces across all experiments are plotted on a single axis with each trace colored by its experience_level
    if include_omissions is True, there will be one axis for the change response and one axis for the omission response across sessions
    if include_omissions is False, only change responses will be plotted
    Traces are loaded as dF/F by default, as events if use_events is True, and as filtered events if filter_events is also True
    Only plots data for ophys_experiment_ids where the cell_specimen_id is present, does not plot max projections without an ROI mask for expts in a container where the cell was not detected
    To generate plots showing max projections from experiments in a container where a cell was not detected, use plot_matched_roi_and_trace

    :param cell_metadata: rows of the ophys_cells_table for one cell_specimen_id
    :param include_omissions: add an axis with omission responses
    :param use_events: if True, use events instead of dF/F
    :param filter_events: if True (together with use_events), use filtered events
    :param save_dir: if given, save the figure to save_dir/folder and close it
    :param folder: sub-folder of save_dir
    """

    if len(cell_metadata.cell_specimen_id.unique()) > 1:
        print('There is more than one cell_specimen_id in the provided cell_metadata table')
        print('Please limit input to a single cell_specimen_id')

    # get relevant info for this cell
    cell_metadata = cell_metadata.sort_values(by='experience_level')
    cell_specimen_id = cell_metadata.cell_specimen_id.unique()[0]
    ophys_container_id = cell_metadata.ophys_container_id.unique()[0]
    ophys_experiment_ids = cell_metadata.ophys_experiment_id.unique()
    n_expts = len(ophys_experiment_ids)

    # set up labels for different trace types; the same name selects the trace type that is loaded
    if use_events:
        suffix = 'filtered_events' if filter_events else 'events'
        ylabel = 'response'
    else:
        suffix = 'dff'
        ylabel = 'dF/F'
    data_type = suffix

    # number of columns is one for each experiments ROI mask, plus additional columns for stimulus and omission traces
    n_cols = n_expts + 2 if include_omissions else n_expts + 1

    experience_levels = ['Familiar', 'Novel 1', 'Novel >1']
    colors = utils.get_experience_level_colors()

    figsize = (3 * n_cols, 3)
    fig, ax = plt.subplots(1, n_cols, figsize=figsize)

    window = [-1, 1.5]  # seconds around each change / omission
    metadata_string = None  # set from the last dataset that loads successfully
    print('cell_specimen_id:', cell_specimen_id)
    print('ophys_container_id:', ophys_container_id)
    for i, ophys_experiment_id in enumerate(ophys_experiment_ids):
        print('ophys_experiment_id:', ophys_experiment_id)
        experience_level = \
            cell_metadata[cell_metadata.ophys_experiment_id == ophys_experiment_id].experience_level.values[0]
        color = colors[experience_levels.index(experience_level)]
        try:
            dataset = loading.get_ophys_dataset(ophys_experiment_id, get_extended_stimulus_presentations=False)
            if cell_specimen_id in dataset.dff_traces.index:

                ct = dataset.cell_specimen_table.copy()
                cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
                ax[i] = sf.plot_cell_zoom(dataset.roi_masks, dataset.max_projection, cell_roi_id,
                                          spacex=50, spacey=50, show_mask=True, ax=ax[i])
                ax[i].set_title(experience_level)

                # get change responses and plot on the axis after the ROI masks
                sdf = loading.get_stimulus_response_df(dataset, time_window=window, interpolate=True,
                                                       output_sampling_rate=30,
                                                       data_type=data_type, event_type='changes',
                                                       load_from_file=True)
                cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == True)]  # noqa: E712

                ax[n_expts] = utils.plot_mean_trace(cell_data.trace.values, cell_data.trace_timestamps.values[0],
                                                    ylabel=ylabel, legend_label=None, color=color, interval_sec=1,
                                                    xlim_seconds=window, plot_sem=True, ax=ax[n_expts])
                ax[n_expts] = utils.plot_flashes_on_trace(ax[n_expts], cell_data.trace_timestamps.values[0],
                                                          change=True, omitted=False)
                ax[n_expts].set_title('changes')

                # get omission responses and plot on last axis
                if include_omissions:
                    sdf = loading.get_stimulus_response_df(dataset, time_window=window, interpolate=True,
                                                           output_sampling_rate=30,
                                                           data_type=data_type, event_type='omissions',
                                                           load_from_file=True)
                    cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.omitted == True)]  # noqa: E712

                    ax[n_expts + 1] = utils.plot_mean_trace(cell_data.trace.values,
                                                            cell_data.trace_timestamps.values[0],
                                                            ylabel=ylabel, legend_label=None, color=color,
                                                            interval_sec=1,
                                                            xlim_seconds=window, plot_sem=True, ax=ax[n_expts + 1])
                    ax[n_expts + 1] = utils.plot_flashes_on_trace(ax[n_expts + 1],
                                                                  cell_data.trace_timestamps.values[0],
                                                                  change=False, omitted=True)
                    ax[n_expts + 1].set_title('omissions')

            metadata_string = utils.get_metadata_string(dataset.metadata)
        except Exception as e:
            # best effort: keep plotting the remaining experiments
            print('problem for cell_specimen_id:', cell_specimen_id, ', ophys_experiment_id:', ophys_experiment_id)
            print(e)

    if metadata_string is not None:
        fig.tight_layout()
        fig.suptitle(str(cell_specimen_id) + '_' + metadata_string, x=0.53, y=1.02,
                     horizontalalignment='center', fontsize=16)
    if save_dir:
        if metadata_string is None:
            print('no data could be loaded for cell_specimen_id:', cell_specimen_id, '- figure not saved')
        else:
            utils.save_figure(fig, figsize, save_dir, folder,
                              str(cell_specimen_id) + '_' + metadata_string + '_' + suffix)
        plt.close(fig)
Пример #23
0
        ax[0].set_title('cell_roi_id: ' + str(cell_roi_id) +
                        ', dF/F traces before and after decrosstalk')
        ax[0].legend(loc='upper right')

        ax[1].plot(
            dff_traces_post_decrosstalk[dff_traces_post_decrosstalk.cell_roi_id
                                        == cell_roi_id].dff.values[0],
            color='blue',
            label='with decrosstalk')
        ax[1].set_title('production dF/F trace without decrosstalk')

        ax[2].plot(
            dff_traces_pre_decrosstalk[dff_traces_pre_decrosstalk.cell_roi_id
                                       == cell_roi_id].dff.values[0],
            color='black',
            label='without decrosstalk')
        ax[2].set_title('dev dF/F trace after decrosstalk')

        for i in range(3):
            ax[i].set_xlabel('2P frames')
            ax[i].set_ylabel('dF/F')

        fig.tight_layout()
        title = dataset.metadata_string
        plt.suptitle(title, x=0.53, y=1.02, fontsize=16)

        save_dir = r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/qc_plots/decrosstalk_validation'
        utils.save_figure(fig, figsize, save_dir,
                          'dFF_before_and_after_decrosstalk_comparison',
                          title + '_' + str(cell_roi_id))
        plt.close()
Пример #24
0
def plot_roi_metrics_for_cell(dataset, metrics_df, cell_specimen_id, title):
    """
    Plot QC views of a single ROI and save them to the ROI-filtering validation folder.

    Top row: the ROI mask over the full max projection, zoomed views of the cell with and without
    its mask, and a text panel listing the first seven metrics of the cell from metrics_df.
    Bottom row: the cell's dF/F trace on a single axis spanning the full figure width.

    :param dataset: ophys dataset providing cell_specimen_table, max_projection,
        ophys_timestamps and dff_traces
    :param metrics_df: ROI metrics table with a 'cell_specimen_id' column
    :param cell_specimen_id: cell to plot
    :param title: figure title prefix; '_<cell_specimen_id>' is appended
    """
    cell_table = dataset.cell_specimen_table.copy()
    roi_masks = get_roi_masks_dict(cell_table)

    figsize = (15, 8)
    fig, axes = plt.subplots(figsize=figsize, nrows=2, ncols=4)

    # float overlay: 1 inside the ROI, NaN elsewhere, so only ROI pixels receive a color
    roi_mask = cell_table.loc[cell_specimen_id].roi_mask
    overlay = np.full(roi_mask.shape, np.nan)
    overlay[roi_mask == True] = 1  # noqa: E712

    max_projection = dataset.max_projection.data
    axes[0, 0].imshow(max_projection, cmap='gray')
    axes[0, 0].imshow(overlay, cmap='hsv', vmin=0, vmax=1, alpha=0.5)

    # zoomed views of the cell, first with its mask, then without
    for column, show_mask in ((1, True), (2, False)):
        axes[0, column] = sf.plot_cell_zoom(roi_masks,
                                            max_projection,
                                            cell_specimen_id,
                                            spacex=40,
                                            spacey=40,
                                            show_mask=show_mask,
                                            ax=axes[0, column])

    metric_names = [
        'valid_roi', 'area', 'ellipseness', 'compactness', 'mean_intensity',
        'max_intensity', 'intensity_ratio', 'soma_minus_np_mean',
        'soma_minus_np_std', 'sig_active_frames_2_5', 'sig_active_frames_4'
    ]
    cell_metrics = metrics_df[metrics_df.cell_specimen_id == cell_specimen_id]
    # only the first seven metrics are listed in the text panel
    metrics_text = ''.join(name + ': ' + str(cell_metrics[name].values[0]) + '\n'
                           for name in metric_names[:7])
    axes[0, 3].text(x=0, y=0, s=metrics_text)

    for column in (1, 2, 3):
        axes[0, column].axis('off')

    # replace the bottom row of axes with one axis spanning every column
    grid = axes[1, 0].get_gridspec()
    for bottom_axis in axes[1, :]:
        bottom_axis.remove()
    trace_ax = fig.add_subplot(grid[1, :])
    trace_ax = sf.plot_trace(dataset.ophys_timestamps,
                             dataset.dff_traces.loc[cell_specimen_id].dff,
                             ax=trace_ax,
                             title='cell_specimen_id: ' + str(cell_specimen_id),
                             ylabel='dF/F')

    fig.tight_layout()
    full_title = title + '_' + str(cell_specimen_id)
    plt.suptitle(full_title, x=0.5, y=1.01, fontsize=18)
    save_dir = r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/qc_plots/roi_filtering_validation'
    utils.save_figure(fig, figsize, save_dir, 'single_cell_metrics', full_title)
Пример #25
0
def plot_metric_range_dataset(dataset,
                              cell_specimen_table,
                              max_projection,
                              metrics_df,
                              metric,
                              thresholds,
                              title,
                              less_than=False):
    """
    Plot ROI masks on the max projection for a range of thresholds on one ROI metric, and save the figure.

    The 2 x 4 grid shows all ROIs, valid ROIs (valid_roi == True), then one panel per threshold
    showing the ROIs whose metric value is above it (or below it when less_than is True).

    :param dataset: unused; kept for backward compatibility
    :param cell_specimen_table: table with 'cell_roi_id', 'roi_mask' and 'valid_roi' columns
    :param max_projection: background image passed to plot_metrics_mask
    :param metrics_df: table with 'cell_roi_id' and the metric column
    :param metric: name of the metric column to threshold; also the save sub-folder
    :param thresholds: up to 6 threshold values, one panel each
    :param title: figure title, also used in the saved file name
    :param less_than: if True, keep ROIs with metric < threshold, otherwise metric > threshold
    :raises ValueError: if more thresholds are given than there are free panels
    """
    n_panels = 8  # 2 x 4 grid
    n_fixed_panels = 2  # 'all ROIs' and 'valid ROIs'
    thresholds = list(thresholds)
    if len(thresholds) > n_panels - n_fixed_panels:
        raise ValueError('at most {} thresholds can be plotted, got {}'.format(
            n_panels - n_fixed_panels, len(thresholds)))

    ct = cell_specimen_table.copy()

    figsize = (20, 10)
    fig, ax = plt.subplots(2, 4, figsize=figsize)
    ax = ax.ravel()

    # map each cell_roi_id to its mask once (first row wins, as with a row filter + .values[0])
    mask_by_roi_id = {}
    for roi_id, roi_mask in zip(ct.cell_roi_id.values, ct.roi_mask.values):
        mask_by_roi_id.setdefault(roi_id, roi_mask)

    def _plot_masks(roi_ids, panel_title, axis):
        """Plot the masks of roi_ids on axis with the shared styling of this figure."""
        roi_mask_dict = {roi_id: mask_by_roi_id[roi_id] for roi_id in roi_ids}
        return plot_metrics_mask(roi_mask_dict,
                                 max_projection,
                                 metric_values=None,
                                 title=panel_title,
                                 cmap='hsv',
                                 cmap_range=[0, 1],
                                 ax=axis,
                                 colorbar=False)

    ax[0] = _plot_masks(ct.cell_roi_id.values, 'all ROIs', ax[0])
    ax[1] = _plot_masks(ct[ct.valid_roi == True].cell_roi_id.values, 'valid ROIs', ax[1])  # noqa: E712

    comparison = ' < ' if less_than else ' > '
    for i, threshold in enumerate(thresholds, start=n_fixed_panels):
        if less_than:
            passes = metrics_df[metric] < threshold
        else:
            passes = metrics_df[metric] > threshold
        filtered_roi_ids = metrics_df[passes].cell_roi_id.values
        ax[i] = _plot_masks(filtered_roi_ids, metric + comparison + str(threshold), ax[i])
        ax[i].axis('off')

    # hide panels left unused when fewer thresholds than free panels are given
    for i in range(n_fixed_panels + len(thresholds), n_panels):
        ax[i].axis('off')

    fig.tight_layout()
    fig.suptitle(title, x=0.5, y=1.01, fontsize=18)
    save_dir = r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/qc_plots/roi_filtering_validation'
    utils.save_figure(fig, figsize, save_dir, metric, title + '_' + metric)
def plot_experiment_summary_figure(experiment_id, save_figure=True):
    """
    Plot a multi-panel QC summary figure for a single ophys experiment.

    Panels (ax keys come from make_fig_ax, e.g. '0_0', '1_3:'):
        max intensity projection, valid ROI outlines, valid/invalid ROI overlay,
        motion correction + population average, average image (twice),
        remaining decrosstalk masks (skipped with a message if unavailable),
        behavior timeseries for the whole session, cell SNR distribution,
        traces heatmap, event-based population-average responses to go trials
        and omissions, average running speed around trials and omissions,
        example high/low SNR traces and behavior timeseries for 10-15 min.

    :param experiment_id: ophys_experiment_id to plot
    :param save_figure: if True, save the figure to loading.get_experiment_plots_dir()
                        under 'experiment_summary_figure', named by dataset.metadata_string
    """
    # x-limits (seconds) shared by all event-triggered average panels
    event_xlim = (-2.5, 2.8)
    # example timeseries window: 10 to 15 minutes into the session
    xlim_seconds = [int(10 * 60), int(15 * 60)]

    dataset = loading.get_ophys_dataset(experiment_id)
    analysis = ra.ResponseAnalysis(dataset, use_events=True)

    fig, ax, figsize = make_fig_ax()
    ax['0_0'] = ep.plot_max_intensity_projection_for_experiment(experiment_id,
                                                                ax=ax['0_0'])
    ax['0_0'].set_title('max projection')
    ax['0_1'] = ep.plot_valid_segmentation_mask_outlines_per_cell_for_experiment(
        experiment_id, ax=ax['0_1'])
    ax['0_2'] = ep.plot_valid_and_invalid_segmentation_mask_overlay_per_cell_for_experiment(
        experiment_id, ax=ax['0_2'])
    ax['0_2'].set_title('red = valid ROIs, blue = invalid ROIs')
    ax['0_3:'] = ep.plot_motion_correction_and_population_average(
        experiment_id, ax=ax['0_3:'])

    ax['1_0'] = ep.plot_average_image_for_experiment(experiment_id,
                                                     ax=ax['1_0'])
    ax['1_1'] = ep.plot_average_image_for_experiment(experiment_id,
                                                     ax=ax['1_1'])
    try:
        ax['1_2'] = ep.plot_remaining_decrosstalk_masks_for_experiment(
            experiment_id, ax=ax['1_2'])
    except Exception:
        # decrosstalk outputs may be unavailable for this experiment; skip the
        # panel, but let KeyboardInterrupt / SystemExit propagate
        print('no decrosstalk for experiment', experiment_id)
    ax['1_3:'] = ep.plot_behavior_timeseries_for_experiment(experiment_id,
                                                            ax=ax['1_3:'])

    ax['2_2'] = ep.plot_cell_snr_distribution_for_experiment(experiment_id,
                                                             ax=ax['2_2'])
    ax['2_3:'] = ep.plot_traces_heatmap_for_experiment(experiment_id,
                                                       ax=ax['2_3:'])

    # population average response on go / no-go trials
    df_name = 'trials_response_df'
    df = analysis.get_response_df(df_name)
    mean_df = ut.get_mean_df(df,
                             analysis=analysis,
                             conditions=['cell_specimen_id', 'go'],
                             flashes=False,
                             omitted=False,
                             get_reliability=False,
                             get_pref_stim=False,
                             exclude_omitted_from_pref_stim=True)
    ax['3_0'] = ep.plot_population_average_for_experiment(experiment_id,
                                                          df,
                                                          mean_df,
                                                          df_name,
                                                          color=None,
                                                          label=None,
                                                          ax=ax['3_0'])
    ax['3_0'].set_xlim(*event_xlim)

    # population average response to omissions
    df_name = 'omission_response_df'
    df = analysis.get_response_df(df_name)
    mean_df = ut.get_mean_df(df,
                             analysis=analysis,
                             conditions=['cell_specimen_id'],
                             flashes=False,
                             omitted=True,
                             get_reliability=False,
                             get_pref_stim=False,
                             exclude_omitted_from_pref_stim=False)
    ax['3_1'] = ep.plot_population_average_for_experiment(experiment_id,
                                                          df,
                                                          mean_df,
                                                          df_name,
                                                          color=None,
                                                          label=None,
                                                          ax=ax['3_1'])
    ax['3_1'].set_xlim(*event_xlim)

    # running speed around trials; a constant 'condition' column lets
    # get_mean_df group the whole dataframe as one condition
    df_name = 'trials_run_speed_df'
    df = analysis.get_response_df(df_name)
    df['condition'] = True
    # NOTE(review): this mean_df is computed but not passed below -- the raw df
    # is plotted with trace_type='trace'; confirm this is intended
    mean_df = ut.get_mean_df(df,
                             analysis=analysis,
                             conditions=['condition', 'go'],
                             flashes=False,
                             omitted=True,
                             get_reliability=False,
                             get_pref_stim=False,
                             exclude_omitted_from_pref_stim=False)
    ax['4_0'] = ep.plot_population_average_for_experiment(
        experiment_id,
        df,
        df,
        df_name,
        trace_type='trace',
        color=sns.color_palette()[4],
        label=None,
        ax=ax['4_0'])
    ax['4_0'].set_ylabel('run speed (cm/s)')
    ax['4_0'].set_xlim(*event_xlim)

    # running speed around omissions
    df_name = 'omission_run_speed_df'
    df = analysis.get_response_df(df_name)
    df['condition'] = True
    # NOTE(review): as above, mean_df is computed but the raw df is plotted
    mean_df = ut.get_mean_df(df,
                             analysis=analysis,
                             conditions=['condition'],
                             flashes=False,
                             omitted=False,
                             get_reliability=False,
                             get_pref_stim=True,
                             exclude_omitted_from_pref_stim=True)
    ax['4_1'] = ep.plot_population_average_for_experiment(
        experiment_id,
        df,
        df,
        df_name,
        trace_type='trace',
        color=sns.color_palette()[4],
        label=None,
        ax=ax['4_1'])
    ax['4_1'].set_ylabel('run speed (cm/s)')
    ax['4_1'].set_xlim(*event_xlim)

    # example traces and behavior over the same short window
    ax['3_3:'] = ep.plot_high_low_snr_trace_examples(experiment_id,
                                                     xlim_seconds=xlim_seconds,
                                                     ax=ax['3_3:'])

    ax['4_3:'] = ep.plot_behavior_timeseries_for_experiment(
        experiment_id, xlim_seconds=xlim_seconds, ax=ax['4_3:'])

    fig.tight_layout()
    title = dataset.metadata_string
    plt.suptitle(title, x=0.5, y=.91, fontsize=20)
    if save_figure:
        save_dir = loading.get_experiment_plots_dir()
        utils.save_figure(fig, figsize, save_dir, 'experiment_summary_figure',
                          title)
def plot_matched_roi_and_traces_example_GLM(cell_metadata,
                                            cell_dropouts,
                                            cell_weights,
                                            weights_features,
                                            kernels,
                                            dropout_features,
                                            experiments_table,
                                            data_type,
                                            save_dir=None,
                                            folder=None):
    """
    Plot ROI masks, event-triggered responses and GLM outputs for one cell matched across sessions.

    Panels:
        cell ROI masks matched across sessions for a given cell_specimen_id,
        change and omission triggered average responses across sessions,
        dropout scores across features and sessions as a heatmap,
        kernel weights across sessions for the kernels in weights_features.
    A column is reserved for each of 'running' / 'pupil' when present in
    weights_features; those image-locked averages are not implemented yet.

    cell_metadata is a subset of the ophys_cells_table limited to the cell_specimen_id of interest
    cell_dropouts is a subset of the results_pivoted version of GLM output limited to cell_specimen_id of interest
    cell_weights is a subset of the weights matrix from GLM limited to cell_specimen_id of interest
    all input dataframes must be limited to last familiar and second novel active (i.e. max of one session per type)
    if the cell is not found in one experience level, the max projection around the ROI position from an
    earlier session is plotted without a mask and the traces for that experience level are skipped;
    GLM weights must still be present in cell_weights for every experience level

    :param weights_features: GLM kernel names; the first 8 are averaged into one 'images' panel,
                             each remaining feature gets its own panel
    :param kernels: kernel definitions passed to get_t_array_for_kernel;
                    kernels['omissions']['length'] is used to clip the omission kernel
    :param dropout_features: dropout columns to show in the heatmap, in display order
    :param experiments_table: must contain one experiment per experience level for the cell's container
    :param data_type: trace type passed to loading.get_stimulus_response_df; 'dff' labels y-axes as dF/F
    :param save_dir: if provided, the figure is saved there
    :param folder: sub-folder of save_dir to save the figure into
    """

    if len(cell_metadata.cell_specimen_id.unique()) > 1:
        print(
            'There is more than one cell_specimen_id in the provided cell_metadata table'
        )
        print('Please limit input to a single cell_specimen_id')

    # set up plotting for each experience level
    experience_levels = ['Familiar', 'Novel 1', 'Novel >1']
    colors = utils.get_experience_level_colors()
    n_exp_levels = len(experience_levels)
    # get relevant info for this cell
    cell_metadata = cell_metadata.sort_values(by='experience_level')
    cell_specimen_id = cell_metadata.cell_specimen_id.unique()[0]
    ophys_container_id = cell_metadata.ophys_container_id.unique()[0]
    # need to get all experiments for this container, not just for this cell
    ophys_experiment_ids = experiments_table[
        experiments_table.ophys_container_id ==
        ophys_container_id].index.values
    n_expts = len(ophys_experiment_ids)
    if n_expts > 3:
        print(
            'There are more than 3 experiments for this cell. There should be a max of 1 experiment per experience level'
        )
        print('Please limit input to only one experiment per experience level')

    # set up labels for different trace types
    if data_type == 'dff':
        ylabel = 'dF/F'
    else:
        ylabel = 'response'

    # one column per experience level (ROI masks), plus change and omission
    # columns, plus one reserved column each for running and pupil averages
    extra_cols = 2
    if 'running' in weights_features:
        extra_cols += 1
    if 'pupil' in weights_features:
        extra_cols += 1
    n_cols = n_exp_levels + extra_cols
    print(extra_cols, 'extra cols')

    # flat (raveled, row-major) axis layout:
    #   row 0: ROI masks [0, n_exp_levels), changes, omissions, reserved columns
    #   row 1: dropout heatmap, 'images' kernel, then the remaining kernels
    # indices derive from n_exp_levels (which also sets the column count), so
    # an unexpected number of experiments in the container cannot shift panels
    change_ax = n_exp_levels
    omission_ax = n_exp_levels + 1
    dropout_ax = n_exp_levels + extra_cols
    first_kernel_ax = dropout_ax + 1

    figsize = (3.5 * n_cols, 6)
    fig, ax = plt.subplots(2, n_cols, figsize=figsize)
    ax = ax.ravel()

    # ROI masks / id from the most recent session where the cell was found; used
    # to show where the ROI would be in a session where it was not matched
    roi_masks = None
    cell_roi_id = None

    print('cell_specimen_id:', cell_specimen_id)
    # loop through experience levels for this cell
    for e, experience_level in enumerate(experience_levels):
        print('experience_level:', experience_level)

        # experiments_table must only include one experiment per experience level for a given container
        ophys_experiment_id = experiments_table[
            (experiments_table.ophys_container_id == ophys_container_id)
            & (experiments_table.experience_level == experience_level
               )].index.values[0]
        print('ophys_experiment_id:', ophys_experiment_id)
        color = colors[e]

        # load dataset for this experiment
        dataset = loading.get_ophys_dataset(
            ophys_experiment_id, get_extended_stimulus_presentations=False)

        try:  # plots for this cell at this experience level; skipped if the cell is absent
            ct = dataset.cell_specimen_table.copy()
            # typically fails here if the cell_specimen_id isn't in this session
            cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
            roi_masks = dataset.roi_masks.copy()
            ax[e] = sf.plot_cell_zoom(dataset.roi_masks,
                                      dataset.max_projection,
                                      cell_roi_id,
                                      spacex=50,
                                      spacey=50,
                                      show_mask=True,
                                      ax=ax[e])
            ax[e].set_title(experience_level, color=color)

            # change responses
            window = [-1, 1.5]  # window around event, in seconds
            sdf = loading.get_stimulus_response_df(dataset,
                                                   time_window=window,
                                                   interpolate=True,
                                                   output_sampling_rate=30,
                                                   data_type=data_type,
                                                   event_type='changes',
                                                   load_from_file=True)
            cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id)
                            & (sdf.is_change == True)]

            ax[change_ax] = utils.plot_mean_trace(
                cell_data.trace.values,
                cell_data.trace_timestamps.values[0],
                ylabel=ylabel,
                legend_label=None,
                color=color,
                interval_sec=1,
                xlim_seconds=window,
                plot_sem=True,
                ax=ax[change_ax])
            ax[change_ax] = utils.plot_flashes_on_trace(
                ax[change_ax],
                cell_data.trace_timestamps.values[0],
                change=True,
                omitted=False)
            ax[change_ax].set_title('changes')

            # omission responses
            sdf = loading.get_stimulus_response_df(dataset,
                                                   time_window=window,
                                                   interpolate=True,
                                                   output_sampling_rate=30,
                                                   data_type=data_type,
                                                   event_type='omissions',
                                                   load_from_file=True)
            cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id)
                            & (sdf.omitted == True)]

            ax[omission_ax] = utils.plot_mean_trace(
                cell_data.trace.values,
                cell_data.trace_timestamps.values[0],
                ylabel=ylabel,
                legend_label=None,
                color=color,
                interval_sec=1,
                xlim_seconds=window,
                plot_sem=True,
                ax=ax[omission_ax])
            ax[omission_ax] = utils.plot_flashes_on_trace(
                ax[omission_ax],
                cell_data.trace_timestamps.values[0],
                change=False,
                omitted=True)
            ax[omission_ax].set_title('omissions')

            # TODO: image-locked running / pupil averages for the reserved columns

        except Exception:  # show where the ROI would have been if it was in this session
            print('no cell ROI for', experience_level)
            if roi_masks is not None:
                ax[e] = sf.plot_cell_zoom(roi_masks,
                                          dataset.max_projection,
                                          cell_roi_id,
                                          spacex=50,
                                          spacey=50,
                                          show_mask=False,
                                          ax=ax[e])
            else:
                # no earlier session provided an ROI position to center on
                ax[e].axis('off')
            ax[e].set_title(experience_level)

        # GLM kernel weights for this experience level
        exp_weights = cell_weights[cell_weights.experience_level ==
                                   experience_level]

        # average the image kernels (first 8 features) into a single trace
        image_features = weights_features[:8]
        image_weights = [exp_weights[feature + '_weights'].values[0]
                         for feature in image_features]
        mean_image_weights = np.mean(image_weights, axis=0)

        # NOTE(review): GLM output is described as resampled to 30Hz, yet 31 is
        # used here -- confirm which rate get_t_array_for_kernel expects
        frame_rate = 31
        i = first_kernel_ax
        t_array = get_t_array_for_kernel(kernels, image_features[-1], frame_rate)
        ax[i].plot(t_array, mean_image_weights, color=color)
        ax[i].set_ylabel('weight')
        ax[i].set_title('images')
        ax[i].set_xlabel('time (s)')
        ax_to_share = i

        i += 1
        # all other kernels, each on its own axis sharing y with 'images'
        for f, feature in enumerate(weights_features[8:]):
            kernel_weights = exp_weights[feature + '_weights'].values[0]
            if feature == 'omissions':
                n_frames_to_clip = int(
                    kernels['omissions']['length'] * frame_rate) + 1
                kernel_weights = kernel_weights[:n_frames_to_clip]
            t_array = get_t_array_for_kernel(kernels, feature, frame_rate)
            ax[i + f].plot(t_array, kernel_weights, color=color)
            ax[i + f].set_ylabel('')
            ax[i + f].set_title(feature)
            ax[i + f].set_xlabel('time (s)')
            ax[i + f].get_shared_y_axes().join(ax[i + f], ax[ax_to_share])

    # dropout score heatmap: rows are dropout features, columns experience levels.
    # numeric_only=True drops any non-numeric columns (pandas<2 behaviour)
    # instead of raising under pandas>=2
    i = dropout_ax
    cell_dropouts = cell_dropouts.groupby(['experience_level']).mean(numeric_only=True)
    if 'ophys_experiment_id' in cell_dropouts.keys():
        cell_dropouts = cell_dropouts.drop(columns='ophys_experiment_id')
    if 'cell_specimen_id' in cell_dropouts.keys():
        cell_dropouts = cell_dropouts.drop(columns='cell_specimen_id')
    cell_dropouts = cell_dropouts[dropout_features]  # order dropouts properly
    dropouts = cell_dropouts.T
    # diverging colormap only when some scores are negative
    if (dropouts < 0).values.any():
        vmin = -1
        cmap = 'RdBu'
    else:
        vmin = 0
        cmap = 'Blues'
    ax[i] = sns.heatmap(dropouts,
                        cmap=cmap,
                        vmin=vmin,
                        vmax=1,
                        ax=ax[i],
                        cbar=False)
    ax[i].set_yticklabels(dropouts.index.values, rotation=0, fontsize=14)
    ax[i].set_xticklabels(dropouts.columns.values, rotation=90, fontsize=14)
    ax[i].set_ylim(0, dropouts.shape[0])
    ax[i].set_xlabel('')

    # container-level metadata taken from the last dataset loaded in the loop
    metadata_string = utils.get_container_metadata_string(dataset.metadata)

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.6, wspace=0.7)
    fig.suptitle(str(cell_specimen_id) + '_' + metadata_string,
                 x=0.53,
                 y=1.02,
                 horizontalalignment='center',
                 fontsize=16)

    if save_dir:
        print('saving plot for', cell_specimen_id)
        utils.save_figure(
            fig, figsize, save_dir, folder,
            str(cell_specimen_id) + '_' + metadata_string + '_' + data_type)
        print('saved')
Пример #28
0
def plot_behavior_and_pop_avg_mesoscope(ophys_session_id,
                                        xlim_seconds=None,
                                        save_figure=True):
    """
    Plot behavior and population-average activity for every imaging plane of a mesoscope session.

    Panels (shared x-axis):
        0: behavior model weights (skipped with a message if unavailable)
        1: behavior timeseries
        2: population-average trace of each plane, labeled by targeted structure and depth

    Behavior panels and stimulus shading come from the earliest-acquired
    experiment of the session; the figure title and saved filename use that
    same experiment's metadata string.

    :param ophys_session_id: ophys_session_id whose experiments to plot
    :param xlim_seconds: optional [start, stop] in seconds; also selects the
                         filename suffix ('early' if stop < 2000, else 'late')
    :param save_figure: if True, save under <qc_plots_dir>/timeseries_plots
    """
    if xlim_seconds is None:
        suffix = ''
    elif xlim_seconds[1] < 2000:
        suffix = 'early'
    else:
        suffix = 'late'

    experiments_table = loading.get_filtered_ophys_experiment_table()
    experiment_ids = experiments_table[
        experiments_table.ophys_session_id == ophys_session_id].sort_values(
            by='date_of_acquisition').index.values

    first_dataset = loading.get_ophys_dataset(experiment_ids[0])

    colors = sns.color_palette()
    figsize = (15, 8)
    fig, ax = plt.subplots(3, 1, figsize=figsize, sharex=True)
    try:
        ax[0] = plot_behavior_model_weights(first_dataset,
                                            xlim_seconds=xlim_seconds,
                                            plot_stimuli=True,
                                            ax=ax[0])
    except Exception:
        print('no behavior model output for', first_dataset.ophys_experiment_id)
    ax[1] = plot_behavior(first_dataset.ophys_experiment_id,
                          xlim_seconds=xlim_seconds,
                          plot_stimuli=True,
                          ax=ax[1])

    for plane, experiment_id in enumerate(experiment_ids):
        # reuse the already-loaded first dataset; load the other planes on demand
        if plane == 0:
            dataset = first_dataset
        else:
            dataset = loading.get_ophys_dataset(experiment_id)
        label = dataset.metadata['targeted_structure'] + ', ' + str(
            dataset.metadata['imaging_depth'])
        ax[2] = plot_traces(dataset,
                            include_cell_traces=False,
                            plot_stimuli=(plane == 0),  # shade stimuli only once
                            xlim_seconds=xlim_seconds,
                            label=label,
                            color=colors[plane],
                            ax=ax[2])
    ax[2].legend(fontsize='x-small', loc='upper left')
    plt.subplots_adjust(wspace=0, hspace=0.1)
    # title after the experiment whose behavior is shown, not the last plane loaded
    ax[0].set_title(first_dataset.metadata_string)

    if save_figure:
        save_dir = os.path.abspath(
            os.path.join(loading.get_qc_plots_dir(), 'timeseries_plots'))
        utils.save_figure(fig, figsize, save_dir,
                          'behavior_traces_population_average_mesoscope',
                          first_dataset.metadata_string + '_' + suffix)
def plot_across_session_responses(ophys_container_id, cell_specimen_id, use_events=False, save_figure=True):
    """
    Generates plots characterizing single cell activity in response to stimulus, omissions, and changes.
    Compares across all sessions in a container for each cell, including the ROI mask across days.
    Useful to validate cell matching as well as examine changes in activity profiles over days.

    One column per experiment in the container (sorted by ophys_experiment_id); rows:
        0: ROI mask zoom on the max projection
        1: mean response to each image on non-change flashes
        2: mean response to the preferred image, split by running (mean running speed > 2)
        3: mean omission response
        4: mean go-trial response to the preferred change image, split by hit / miss
    Columns for experiments that do not contain the cell are left blank; errors for
    one experiment are printed and do not stop the others.

    :param ophys_container_id: container whose release experiments are plotted
    :param cell_specimen_id: cell to plot
    :param use_events: if True, rows 1, 3 and 4 use events instead of dF/F
                       (row 2 is always computed from dF/F)
    :param save_figure: if True, save to utils.get_single_cell_plots_dir() and close the figure
    """
    experiments_table = data_loading.get_filtered_ophys_experiment_table(release_data_only=True)
    container_expts = experiments_table[experiments_table.ophys_container_id == ophys_container_id]
    expts = np.sort(container_expts.index.values)
    if use_events:
        ylabel = 'response'
        suffix = '_events'
    else:
        ylabel = 'dF/F'
        suffix = ''

    n = len(expts)
    figsize = (25, 20)
    fig, ax = plt.subplots(6, n, figsize=figsize)
    ax = ax.ravel()

    # loop-invariant palettes
    image_colors = sns.color_palette('hls', 8) + [(0.5, 0.5, 0.5)]
    run_colors = [sns.color_palette()[3], sns.color_palette()[2]]
    hit_colors = [sns.color_palette()[2], sns.color_palette()[3]]

    # last successfully loaded dataset; its metadata names the saved figure
    dataset = None
    print('ophys_container_id:', ophys_container_id)
    for i, ophys_experiment_id in enumerate(expts):
        print('ophys_experiment_id:', ophys_experiment_id)
        # flat indices of this experiment's column in each row (ax is raveled row-major)
        roi_ax, image_ax, run_ax, omission_ax, change_ax = (i + n * row for row in range(5))
        try:
            dataset = data_loading.get_ophys_dataset(ophys_experiment_id, include_invalid_rois=False)
            if cell_specimen_id not in dataset.dff_traces.index:
                continue

            analysis = ResponseAnalysis(dataset, use_events=use_events, use_extended_stimulus_presentations=False)
            sdf = ut.get_mean_df(analysis.get_response_df(df_name='stimulus_response_df'), analysis=analysis,
                                 conditions=['cell_specimen_id', 'is_change', 'image_name'], flashes=True, omitted=False,
                                 get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=True)
            odf = ut.get_mean_df(analysis.get_response_df(df_name='omission_response_df'), analysis=analysis,
                                 conditions=['cell_specimen_id'], flashes=False, omitted=True,
                                 get_reliability=False, get_pref_stim=False, exclude_omitted_from_pref_stim=False)
            tdf = ut.get_mean_df(analysis.get_response_df(df_name='trials_response_df'), analysis=analysis,
                                 conditions=['cell_specimen_id', 'go', 'hit', 'change_image_name'], flashes=False, omitted=False,
                                 get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=True)

            session_label = container_expts.loc[ophys_experiment_id].session_type[6:]

            # row 0: ROI mask
            ct = dataset.cell_specimen_table.copy()
            cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
            ax[roi_ax] = sf.plot_cell_zoom(dataset.roi_masks, dataset.max_projection, cell_roi_id,
                                           spacex=20, spacey=20, show_mask=True, ax=ax[roi_ax])
            ax[roi_ax].set_title(session_label)

            # row 1: response to each image on non-change flashes
            window = rp.get_default_stimulus_response_params()["window_around_timepoint_seconds"]
            cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == False)]
            for c, image_name in enumerate(np.sort(cell_data.image_name.unique())):
                ax[image_ax] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.image_name == image_name],
                                                               frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                               legend_label=image_name, color=image_colors[c], interval_sec=0.5,
                                                               xlims=window, ax=ax[image_ax])
            ax[image_ax] = sf.plot_flashes_on_trace(ax[image_ax], analysis, window=window, trial_type=None, omitted=False, alpha=0.15, facecolor='gray')
            ax[image_ax].set_title(session_label + '\n image response')

            # row 2: preferred-image response split by running; this analysis is
            # built with use_events=False, so the traces are always dF/F
            analysis = ResponseAnalysis(dataset, use_events=False, use_extended_stimulus_presentations=True)
            tmp = analysis.get_response_df(df_name='stimulus_response_df')
            tmp['running'] = tmp.mean_running_speed > 2
            sdf = ut.get_mean_df(tmp, analysis=analysis,
                                 conditions=['cell_specimen_id', 'is_change', 'image_name', 'running'], flashes=True, omitted=False,
                                 get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=False)

            cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == False) & (sdf.pref_stim == True)]
            for c, running in enumerate(np.sort(cell_data.running.unique())):
                if len(cell_data[cell_data.running == running]) > 0:
                    ax[run_ax] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.running == running],
                                                                 frame_rate=analysis.ophys_frame_rate, ylabel='dF/F',
                                                                 legend_label=running, color=run_colors[c], interval_sec=0.5,
                                                                 xlims=window, ax=ax[run_ax])
            ax[run_ax].legend(fontsize='xx-small', title='running', title_fontsize='xx-small')
            ax[run_ax] = sf.plot_flashes_on_trace(ax[run_ax], analysis, window=window, trial_type=None, omitted=False, alpha=0.15, facecolor='gray')
            ax[run_ax].set_title(session_label + '\n image response')

            # row 3: omission response (single condition, so no legend label)
            window = rp.get_default_omission_response_params()["window_around_timepoint_seconds"]
            cell_data = odf[(odf.cell_specimen_id == cell_specimen_id)]
            ax[omission_ax] = sf.plot_mean_trace_from_mean_df(cell_data,
                                                              frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                              legend_label=None, color='gray', interval_sec=1,
                                                              xlims=window, ax=ax[omission_ax])
            ax[omission_ax] = sf.plot_flashes_on_trace(ax[omission_ax], analysis, window=window, trial_type=None, omitted=True, alpha=0.15, facecolor='gray')
            ax[omission_ax].set_title(session_label + '\n omission response')

            # row 4: go-trial response to the preferred change image, hit vs miss
            window = rp.get_default_trial_response_params()["window_around_timepoint_seconds"]
            cell_data = tdf[(tdf.cell_specimen_id == cell_specimen_id) & (tdf.go == True) & (tdf.pref_stim == True)]
            for c, hit in enumerate([True, False]):
                if len(cell_data[cell_data.hit == hit]) > 0:
                    ax[change_ax] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.hit == hit],
                                                                    frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                    legend_label=hit, color=hit_colors[c], interval_sec=1,
                                                                    xlims=window, ax=ax[change_ax])
            ax[change_ax].legend(fontsize='xx-small', title='hit', title_fontsize='xx-small')
            ax[change_ax] = sf.plot_flashes_on_trace(ax[change_ax], analysis, window=window, trial_type='go', omitted=False, alpha=0.15, facecolor='gray')
            ax[change_ax].set_title(session_label + '\n change response')

            fig.tight_layout()
            fig.suptitle(str(cell_specimen_id) + '_' + dataset.metadata_string, x=0.5, y=1.01,
                         horizontalalignment='center')
        except Exception as e:
            print('problem for cell_specimen_id:', cell_specimen_id, ', ophys_experiment_id:', ophys_experiment_id)
            print(e)
    if save_figure:
        # if no dataset could be loaded, name the figure by cell id only
        metadata_string = dataset.metadata_string if dataset is not None else ''
        save_dir = utils.get_single_cell_plots_dir()
        utils.save_figure(fig, figsize, save_dir, 'across_session_responses', str(
            cell_specimen_id) + '_' + metadata_string + '_across_session_responses' + suffix)
        plt.close()
Пример #30
0
def plot_mean_response_by_epoch(df, metric='mean_response', horizontal=True, ymin=0, ylabel='mean response', estimator=np.mean,
                                save_dir=None, folder='epochs', suffix='', ax=None):
    """
    Plots the mean metric value across 10 minute epochs within a session, one panel per cell type,
    with points colored by experience level.

    :param df: dataframe of cell activity with one row per cell_specimen_id / ophys_experiment_id
                must include columns 'cell_type', 'experience_level', 'epoch', and a column for the metric provided (ex: 'mean_response')
                (epoch is assumed to be 0-indexed; the label shown is epoch + 1)
    :param metric: metric value to average over epochs; must be a column of df
    :param horizontal: if True, panels are laid out in one row, otherwise stacked in one column with a shared x axis.
                       Layout only applies when ax is None; the x label choice applies either way.
    :param ymin: lower y-axis limit applied to every panel; None leaves the limit to matplotlib
    :param ylabel: y-axis label for every panel
    :param estimator: aggregation function passed to seaborn.pointplot for each epoch / experience level
    :param save_dir: top level directory to save figure to; the figure is only saved if this is provided
    :param folder: folder within save_dir to save figure to; will create folder if it doesnt exist
    :param suffix: string to append at end of saved filename
    :param ax: optional indexable collection of axes, one per cell type. If None, a new 3-panel figure is created.
    :return: the axes that were plotted on
    """
    # get rid of short 7th epoch (just a few mins at end of session);
    # copy so the column added below is not written into a view of the caller's dataframe
    df = df[df.epoch != 6].copy()

    # add experience epoch column in case it doesnt already exist, e.g. 'Familiar epoch 3'
    if 'experience_epoch' not in df.columns:
        df['experience_epoch'] = [experience_level + ' epoch ' + str(int(epoch) + 1)
                                  for experience_level, epoch in zip(df.experience_level, df.epoch)]

    experience_epoch = np.sort(df.experience_epoch.unique())
    # tick labels show only the trailing epoch number of each 'experience_epoch' category
    xticks = [label.split(' ')[-1] for label in experience_epoch]

    cell_types = np.sort(df.cell_type.unique())[::-1]
    experience_levels = np.sort(df.experience_level.unique())

    palette = utils.get_experience_level_colors()
    if ax is None:
        format_fig = True
        if horizontal:
            figsize = (13, 3.5)
            fig, ax = plt.subplots(1, 3, figsize=figsize, sharex=False, sharey=True)
        else:
            figsize = (5, 10)
            fig, ax = plt.subplots(3, 1, figsize=figsize, sharex=True, sharey=True)
    else:
        format_fig = False
        # axes were supplied by the caller; recover their figure and size so saving still works
        fig = ax[0].get_figure()
        figsize = tuple(fig.get_size_inches())

    for i, cell_type in enumerate(cell_types):
        try:
            data = df[df.cell_type == cell_type]
            ax[i] = sns.pointplot(data=data, x='experience_epoch', y=metric, hue='experience_level', hue_order=experience_levels,
                                  order=experience_epoch, palette=palette, ax=ax[i], estimator=estimator)
            if ymin is not None:
                ax[i].set_ylim(ymin=ymin)
            ax[i].set_title('')
            ax[i].set_ylabel(ylabel)
            ax[i].get_legend().remove()
            ax[i].set_xticklabels(xticks, fontsize=13)
            if horizontal:
                ax[i].set_xlabel('10 min epoch within session', fontsize=14)
            else:
                # stacked panels share the x axis; only the bottom panel is labeled (below)
                ax[i].set_xlabel('')
        except Exception as e:
            # best effort: a failure for one cell type should not prevent the other panels from being drawn
            print(e)

    # label the last panel that exists; guards against no cell types or more cell types than axes
    n_panels = min(len(cell_types), len(ax))
    if n_panels > 0:
        ax[n_panels - 1].set_xlabel('10 min epoch within session', fontsize=14)
    if format_fig:
        fig.tight_layout()
    if save_dir:
        fig_title = metric + '_epochs' + suffix
        utils.save_figure(fig, figsize, save_dir, folder, fig_title)
    return ax