예제 #1
0
def plot_behavior_events_trace(session,
                               xmin=360,
                               length=3,
                               ax=None,
                               save_dir=None):
    """Plot running speed with stimulus spans and behavior events in a time window.

    Parameters
    ----------
    session : object
        Session exposing ``running_speed`` (with ``timestamps`` and ``values``)
        and ``metadata['ophys_experiment_id']``.
    xmin : float
        Start of the plotted window, in seconds.
    length : float
        Window length in minutes; the window ends at ``xmin + 60 * length``.
    ax : matplotlib Axes, optional
        Axes to draw into. A new 15x4 inch figure is created when omitted.
    save_dir : str, optional
        If given, the figure is saved with ``ut.save_figure`` into the
        ``'behavior_events'`` folder, named ``<experiment_id>_<xmin>``.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn into.
    """
    xmax = xmin + 60 * length
    interval = 20
    created_fig = ax is None
    if created_fig:
        figsize = (15, 4)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # A caller-supplied axes still needs a figure handle (and its size)
        # for saving; previously this path raised NameError.
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())

    ax.plot(session.running_speed.timestamps,
            session.running_speed.values,
            color=sns.color_palette()[0])
    ax = add_stim_color_span(session, ax, xlim=[xmin, xmax])
    ax = plot_behavior_events(session, ax)
    ax = restrict_axes(xmin, xmax, interval, ax)
    ax.set_ylabel('running speed (cm/s)')
    ax.set_xlabel('time (sec)')

    if save_dir:
        fig.tight_layout()
        ut.save_figure(
            fig, figsize, save_dir, 'behavior_events',
            str(session.metadata['ophys_experiment_id']) + '_' + str(xmin))
        if created_fig:
            # Close only the figure this function created, never the caller's.
            plt.close(fig)
    return ax
예제 #2
0
def plot_max_proj_and_roi_masks(session, save_dir=None):
    """Show the max projection, the ROI mask image, and ROIs overlaid on the projection.

    Three side-by-side panels are drawn:
    1. ``session.max_projection`` in grayscale,
    2. ``session.segmentation_mask_image`` in grayscale,
    3. the max projection with every pixel where the mask data is > 0
       tinted semi-transparently.

    Parameters
    ----------
    session : object
        Session exposing ``max_projection``, ``segmentation_mask_image``
        (with a ``.data`` array) and ``metadata['ophys_experiment_id']``.
    save_dir : str, optional
        If given, the figure is saved with ``ut.save_figure`` into the
        ``'roi_masks'`` folder, named after the experiment id.

    Returns
    -------
    numpy.ndarray of matplotlib Axes
        The three panel axes.
    """
    figsize = (15, 5)
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    ax = ax.ravel()
    experiment_id = str(session.metadata['ophys_experiment_id'])

    max_proj = session.max_projection
    max_proj_vmax = np.amax(max_proj)

    ax[0].imshow(max_proj, cmap='gray', vmin=0, vmax=max_proj_vmax)
    ax[0].axis('off')
    ax[0].set_title('max intensity projection')

    ax[1].imshow(session.segmentation_mask_image, cmap='gray')
    ax[1].set_title('roi masks')
    ax[1].axis('off')

    ax[2].imshow(max_proj, cmap='gray', vmin=0, vmax=max_proj_vmax)
    ax[2].axis('off')
    ax[2].set_title(experiment_id)

    mask_data = session.segmentation_mask_image.data
    # NaN pixels are not drawn by imshow, so only ROI pixels get tinted.
    # The builtin ``float`` replaces ``np.float``, which NumPy 1.24 removed.
    mask = np.full(mask_data.shape, np.nan, dtype=float)
    mask[mask_data > 0] = 1
    ax[2].imshow(mask, cmap='hsv', alpha=0.4, vmin=0, vmax=1)

    if save_dir:
        ut.save_figure(fig, figsize, save_dir, 'roi_masks', experiment_id)
    return ax
예제 #3
0
def plot_mean_trace_heatmap(mean_df,
                            ax=None,
                            save_dir=None,
                            window=None,
                            interval_sec=2,
                            condition='',
                            suffix=''):
    """Heatmap of mean traces, one row per cell, sorted by descending mean response.

    There must be only one row per cell in the input df.
    For example, if it is a mean of the trial_response_df, select only trials
    where go=True before passing to this function.

    Parameters
    ----------
    mean_df : pandas.DataFrame
        Must contain ``pref_stim``, ``cell_specimen_id``, ``mean_response`` and
        ``mean_trace`` columns. Only rows with ``pref_stim == True`` are used.
    ax : matplotlib Axes, optional
        Axes to draw into. A new 3x6 inch figure is created when omitted.
    save_dir : str, optional
        If given, the figure is saved with ``ut.save_figure`` into the
        ``'experiment_summary'`` folder.
    window : list of two numbers, optional
        Trace window in seconds passed to ``ut.get_xticks_xticklabels``;
        defaults to ``[-4, 8]``.
    interval_sec : float
        Spacing of the x tick labels, in seconds.
    condition : str
        Optional tag appended (after ``'_'``) to the saved file name.
    suffix : str
        Optional suffix appended to the saved file name.

    Returns
    -------
    matplotlib Axes

    Raises
    ------
    ValueError
        If no rows of ``mean_df`` have ``pref_stim == True``.
    """
    if window is None:
        # Fresh list per call instead of a shared mutable default.
        window = [-4, 8]

    data = mean_df[mean_df.pref_stim == True].copy()  # noqa: E712 (elementwise pandas comparison)
    if data.empty:
        raise ValueError('mean_df has no rows with pref_stim == True; nothing to plot')

    if ax is None:
        figsize = (3, 6)
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # Needed for saving when the caller supplies the axes.
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())

    order = np.argsort(data.mean_response.values)[::-1]
    cells = data.cell_specimen_id.unique()[order]
    len_trace = len(data.mean_trace.values[0])
    # Rows default to NaN so a cell without a trace shows as blank.
    response_array = np.full((len(cells), len_trace), np.nan)
    for row, cell_specimen_id in enumerate(cells):
        cell_rows = data[data.cell_specimen_id == cell_specimen_id]
        if len(cell_rows) >= 1:
            response_array[row, :] = cell_rows.mean_trace.values[0]

    sns.heatmap(data=response_array,
                vmin=0,
                # nanpercentile so a single NaN sample does not make vmax NaN.
                vmax=np.nanpercentile(response_array, 99),
                ax=ax,
                cmap='magma',
                cbar=False)
    # Tick positions are derived from a trace of the plotted length (the last row,
    # i.e. the trace the original loop left in scope). 31. is passed as the
    # frame rate — presumably Hz; confirm against ut.get_xticks_xticklabels.
    reference_trace = response_array[-1]
    xticks, xticklabels = ut.get_xticks_xticklabels(reference_trace,
                                                    31.,
                                                    interval_sec=interval_sec,
                                                    window=window)
    ax.set_xticks(xticks)
    ax.set_xticklabels([int(x) for x in xticklabels])

    n_rows = response_array.shape[0]
    interval = 10 if n_rows < 50 else 50
    ax.set_yticks(np.arange(0, n_rows, interval))
    ax.set_yticklabels(np.arange(0, n_rows, interval))
    ax.set_xlabel('time after change (s)', fontsize=16)
    ax.set_ylabel('cells')

    if save_dir:
        fig.tight_layout()
        filename = 'mean_trace_heatmap'
        if condition:
            filename += '_' + condition
        ut.save_figure(fig, figsize, save_dir, 'experiment_summary',
                       filename + suffix)
    return ax
예제 #4
0
def plot_experiment_summary_figure(session, save_dir=None):
    """Build a multi-panel summary figure for one ophys experiment.

    Panels (positions are fractions of the figure):
    - max intensity projection,
    - transitions response heatmap for trials with ``reward_rate > 1``,
    - traces heatmap, titled with driver line / structure / depth / stage,
    - running speed over the session,
    - the first stimulus template, titled with its image name,
    - lick raster,
    - mean response per image heatmap (from ``flash_response_df``),
    - mean trace heatmap for go trials,
    - reward rate per trial,
    - a behavior segment from 620 to 640 (units as used by plot_behavior_segment),
    - mean dF/F trace across go trials with flash markers.

    Parameters
    ----------
    session : object
        Session exposing ``metadata``, ``task_parameters``, ``max_projection``,
        ``trials``, ``running_speed``, ``stimulus_templates``,
        ``stimulus_presentations``, ``flash_response_df`` and
        ``trial_response_df``.
    save_dir : str, optional
        If given, the figure is saved with ``ut.save_figure`` into the
        ``'experiment_summary'`` folder, named after the experiment id.

    Returns
    -------
    matplotlib Figure
    """
    import allensdk.brain_observatory.behavior.swdb.utilities as ut

    meta = session.metadata
    title = meta['driver_line'][0] + ', ' + meta['targeted_structure'] + ', ' + str(meta['imaging_depth']) + ', ' + \
            session.task_parameters['stage']

    figsize = [2 * 11, 2 * 8.5]
    fig = plt.figure(figsize=figsize, facecolor='white')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .2), yspan=(0, .2))
    ax.imshow(session.max_projection, cmap='gray')
    ax.set_title('max intensity projection')
    ax.axis('off')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0, .18), yspan=(.24, .4))
    trials = session.trials.copy()
    trials = trials[trials.reward_rate > 1]
    plot_transitions_response_heatmap(trials, ax=ax)

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.24, .86), yspan=(0, .26))
    ax = plot_traces_heatmap(session, ax=ax)
    ax.set_title(title)

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.28, .92), yspan=(.32, .44))
    ax.plot(session.running_speed.timestamps, session.running_speed.values)
    ax.set_xlabel('time (seconds)')
    ax.set_ylabel('running speed\n(cm/s)')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.86, 1.), yspan=(0, .2))
    image_index = 0
    ax.imshow(session.stimulus_templates[image_index, :, :], cmap='gray')
    st = session.stimulus_presentations.copy()
    image_name = st[st.image_index == image_index].image_name.values[0]
    ax.set_title(image_name)
    ax.axis('off')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .17), yspan=(.54, .99))
    ax = plot_lick_raster(session.trials, ax=ax)

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.24, .42), yspan=(.54, .99))
    fr = session.flash_response_df
    mdf = ut.get_mean_df(fr, conditions=['cell_specimen_id', 'image_name'])
    plot_mean_image_response_heatmap(mdf, title=None, ax=ax)

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.52, .68), yspan=(.54, .99))
    tr = session.trial_response_df.copy()
    mdf = ut.get_mean_df(tr[tr.go], conditions=['cell_specimen_id'])
    # Mean over go trials yields one row per cell; mark each as "preferred"
    # so plot_mean_trace_heatmap keeps all of them.
    mdf['pref_stim'] = True
    ax = plot_mean_trace_heatmap(mdf, ax=ax, window=[-4, 8], interval_sec=2)
    ax.set_title('mean trace for pref image')
    ax.set_ylabel('cells')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.76, .98), yspan=(.5, .62))
    ax.plot(session.trials.reward_rate)
    ax.set_ylabel('reward rate')
    ax.set_xlabel('trials')

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.76, 0.98), yspan=(.68, .8))
    plot_behavior_segment(session, xlims=[620, 640], ax=ax)

    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.76, .98), yspan=(.86, .99))
    traces = tr[tr.go == True].dff_trace.values  # noqa: E712 (elementwise pandas comparison)
    ax = ut.plot_mean_trace(traces, window=[-4, 8], ax=ax)
    ax = ut.plot_flashes_on_trace(ax, window=[-4, 8], go_trials_only=True)
    ax.set_xlabel('time after change (sec)')
    ax.set_ylabel('mean dF/F')

    fig.tight_layout()

    if save_dir:
        # `experiment_id` was never defined here; take it from the metadata.
        ut.save_figure(fig, figsize, save_dir, 'experiment_summary',
                       str(meta['ophys_experiment_id']))
    return fig
예제 #5
0
def plot_mean_image_response_heatmap(mean_df,
                                     title=None,
                                     ax=None,
                                     save_dir=None,
                                     suffix=''):
    """Heatmap of each cell's mean response to every image.

    Rows are grouped by the image a cell prefers (``pref_stim == True``),
    with images in sorted order. Within each group, rows are sorted by
    descending mean response to that image. Columns are the sorted images.

    Parameters
    ----------
    mean_df : pandas.DataFrame
        Needs ``cell_specimen_id``, ``mean_response``, ``pref_stim`` and either
        ``change_image_name`` or ``image_name`` (the former takes precedence).
    title : str, optional
        Axes title; defaults to ``'mean response by image'``.
    ax : matplotlib Axes, optional
        Axes to draw into. A new 4x7 inch figure is created when omitted.
    save_dir : str, optional
        If given, the figure is saved with ``ut.save_figure`` into the
        ``'experiment_summary'`` folder.
    suffix : str
        Optional suffix appended to the saved file name.

    Returns
    -------
    None
    """
    df = mean_df.copy()
    if 'change_image_name' in df.keys():
        image_key = 'change_image_name'
    else:
        image_key = 'image_name'
    images = np.sort(df[image_key].unique())

    cell_list = []
    for image in images:
        tmp = df[(df[image_key] == image) & (df.pref_stim == True)]  # noqa: E712
        order = np.argsort(tmp.mean_response.values)[::-1]
        cell_list.extend(tmp.cell_specimen_id.values[order])

    # Build one (cell, image) -> mean_response lookup instead of re-filtering the
    # whole frame for every matrix entry. keep='first' matches taking `.values[0]`
    # when a pair appears more than once. A missing pair becomes NaN (a blank
    # cell) instead of raising IndexError.
    first_rows = df.drop_duplicates(subset=['cell_specimen_id', image_key],
                                    keep='first')
    lookup = dict(zip(zip(first_rows.cell_specimen_id.values,
                          first_rows[image_key].values),
                      first_rows.mean_response.values))
    response_matrix = np.array(
        [[lookup.get((cell, image), np.nan) for image in images]
         for cell in cell_list],
        dtype=float).reshape(len(cell_list), len(images))

    if ax is None:
        figsize = (4, 7)
        fig, ax = plt.subplots(figsize=figsize)
    else:
        # Needed for saving when the caller supplies the axes.
        fig = ax.figure
        figsize = tuple(fig.get_size_inches())

    vmax = 0.3
    label = 'mean dF/F'
    ax = sns.heatmap(response_matrix,
                     cmap='magma',
                     linewidths=0,
                     linecolor='white',
                     square=False,
                     vmin=0,
                     vmax=vmax,
                     robust=True,
                     cbar_kws={
                         "drawedges": False,
                         "shrink": 1,
                         "label": label
                     },
                     ax=ax)

    if title is None:
        title = 'mean response by image'
    ax.set_title(title, va='bottom', ha='center')
    ax.set_xticklabels(images, rotation=90)
    ax.set_ylabel('cells')
    n_rows = response_matrix.shape[0]
    interval = 10 if n_rows < 50 else 50
    ax.set_yticks(np.arange(0, n_rows, interval))
    ax.set_yticklabels(np.arange(0, n_rows, interval))
    if save_dir:
        fig.tight_layout()
        ut.save_figure(fig, figsize, save_dir, 'experiment_summary',
                       'mean_image_response_heatmap' + suffix)
예제 #6
0
def plot_example_traces_and_behavior(session,
                                     xmin_seconds,
                                     length_mins,
                                     cell_label=False,
                                     save_dir=None):
    """Stack dF/F traces of active cells above a running-speed/behavior panel.

    One row is drawn per cell returned by ``ut.get_active_cell_indices``,
    followed by a bottom row with running speed, behavior events and
    stimulus spans. All rows share the window
    ``[xmin_seconds, xmin_seconds + 60 * length_mins + 1]``.

    Parameters
    ----------
    session : object
        Session exposing ``dff_traces.dff``, ``ophys_timestamps``,
        ``running_speed`` and ``metadata`` (``full_genotype``,
        ``ophys_experiment_id``).
    xmin_seconds : float
        Start of the window, in seconds.
    length_mins : float
        Window length, in minutes.
    cell_label : bool
        If True, label each trace row ``'cell <row index>'``; otherwise leave
        the row unlabeled.
    save_dir : str, optional
        If given, the figure is saved with ``ut.save_figure`` into the
        ``'example_traces'`` folder.

    Returns
    -------
    numpy.ndarray of matplotlib Axes
        One axes per cell row, plus the running-speed axes last.
    """
    traces = np.stack(session.dff_traces.dff.values)
    cell_indices = ut.get_active_cell_indices(traces)
    n_cells = len(cell_indices)
    genotype = session.metadata['full_genotype']

    interval_seconds = 10
    xmax_seconds = xmin_seconds + (length_mins * 60) + 1
    xlim = [xmin_seconds, xmax_seconds]

    figsize = (14, 10)
    # squeeze=False keeps `ax` an array even when there are no active cells
    # (a lone subplot would otherwise come back as a bare Axes).
    fig, ax = plt.subplots(n_cells + 1, 1, figsize=figsize, squeeze=False)
    ax = ax.ravel()

    for i, cell_index in enumerate(cell_indices):
        # Real booleans: newer matplotlib treats the string 'off' as truthy.
        ax[i].tick_params(reset=True,
                          which='both',
                          bottom=False,
                          top=False,
                          right=False,
                          left=False,
                          labeltop=False,
                          labelright=False,
                          labelleft=False,
                          labelbottom=False)
        ax[i] = plot_trace(session.ophys_timestamps,
                           traces[cell_index, :],
                           ax=ax[i],
                           title='',
                           ylabel=str(cell_index),
                           color=[.5, .5, .5])
        ax[i] = add_stim_color_span(session, ax=ax[i], xlim=xlim)
        ax[i] = restrict_axes(xmin_seconds,
                              xmax_seconds,
                              interval_seconds,
                              ax=ax[i])
        ax[i].set_xticks([])
        ax[i].set_xlabel('')
        ax[i].set_xlim(xlim)
        if cell_label:
            ax[i].set_ylabel('cell ' + str(i), fontsize=12)
        else:
            ax[i].set_ylabel('')
        ax[i].set_yticks([])
        sns.despine(ax=ax[i], left=True, bottom=True)
        # Genotype-specific fixed y-range plus a vertical scale bar at xmin.
        if 'Vip' in genotype:
            ax[i].vlines(x=xmin_seconds, ymin=0, ymax=2, linewidth=4)
            ax[i].set_ylim(bottom=-0.5, top=5)
        elif 'Slc' in genotype:
            ax[i].vlines(x=xmin_seconds, ymin=0, ymax=1, linewidth=4)
            ax[i].set_ylim(bottom=-0.5, top=3)
        ax[i].get_xaxis().set_ticks([])
        ax[i].get_yaxis().set_ticks([])

    # Bottom row: running speed with behavior events; the only row with x ticks.
    run = n_cells
    ax[run].tick_params(axis="x",
                        bottom=True,
                        top=False,
                        labelbottom=True,
                        labeltop=False)
    ax[run].plot(session.running_speed.timestamps,
                 session.running_speed.values,
                 color=sns.color_palette()[0])
    ax[run] = plot_behavior_events(session, ax=ax[run])
    ax[run] = add_stim_color_span(session, ax=ax[run], xlim=xlim)
    ax[run] = restrict_axes(xmin_seconds,
                            xmax_seconds,
                            interval_seconds,
                            ax=ax[run])
    ax[run].set_xlim(xlim)
    ax[run].set_ylabel('run speed\n(cm/s)', fontsize=12)
    sns.despine(ax=ax[run], left=True, bottom=True)
    ax[run].set_yticklabels('')
    xticks = np.arange(xmin_seconds, xmax_seconds, interval_seconds)
    ax[run].set_xticks(xticks)
    ax[run].set_xticklabels(xticks)
    ax[run].set_xlabel('time (seconds)')

    ax[0].set_title(
        str(session.metadata['ophys_experiment_id']) + '_' +
        genotype.split('-')[0])
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.subplots_adjust(bottom=0.2)

    if save_dir:
        ut.save_figure(
            fig, figsize, save_dir, 'example_traces',
            str(session.metadata['ophys_experiment_id']) + '_' + str(xlim[0]))
    return ax