コード例 #1
0
def get_timestamps_for_response_df_type(cache, experiment_id, df_name):
    """
    Return the trace timestamps stored in the first row of a response dataframe.

    Parameters
    ----------
    cache : object
        Any object exposing ``get_behavior_ophys_experiment(experiment_id)``.
    experiment_id : int
        Ophys experiment to load from ``cache``.
    df_name : str
        Name forwarded to ``ResponseAnalysis.get_response_df``
        (e.g. 'stimulus_response_df').

    Returns
    -------
    The ``trace_timestamps`` value of the first row. Its length is printed as
    a side effect.
    """
    experiment = cache.get_behavior_ophys_experiment(experiment_id)
    response_df = ResponseAnalysis(experiment).get_response_df(df_name=df_name)
    first_row_timestamps = response_df.trace_timestamps.values[0]
    print(len(first_row_timestamps))
    return first_row_timestamps
コード例 #2
0
def summarize_responses(experiment_id):
    """
    Summarize mean responses of valid cells, per engagement state, for each event type.

    For each of 'omission', 'stimulus' and 'trials', the corresponding
    ``<type>_response_df`` of a tidy ``ResponseAnalysis`` is processed as follows:
    it is restricted to valid ROIs, de-duplicated on whichever identifying columns
    it has, and averaged (``mean_response`` and ``mean_baseline``) per
    (cell_specimen_id, engagement_state). Non-list metadata values of the dataset
    are broadcast onto every row.

    Parameters
    ----------
    experiment_id : int
        Ophys experiment id passed to ``loading.get_ophys_dataset``.

    Returns
    -------
    pandas.DataFrame
        The three summaries concatenated, with an ``event_type`` column naming
        the source response type.
    """
    dataset = loading.get_ophys_dataset(experiment_id)
    analysis = ResponseAnalysis(dataset,
                                overwrite_analysis_files=False,
                                dataframe_format='tidy',
                                use_extended_stimulus_presentations=True)

    # referenced as @valid_cells inside DataFrame.query below
    valid_cells = list(
        dataset.cell_specimen_table.query('valid_roi==True').index.values)
    id_columns = ['cell_specimen_id', 'stimulus_presentations_id', 'trials_id']
    summaries = []
    for df_type in ['omission', 'stimulus', 'trials']:
        # Fetch the response dataframe once per type; the attribute may be
        # loaded or computed on every access.
        response_df = getattr(analysis, '{}_response_df'.format(df_type))
        dedup_columns = [c for c in id_columns if c in response_df]
        summary = (response_df
                   .query('cell_specimen_id in @valid_cells')
                   .drop_duplicates(dedup_columns)
                   .groupby(['cell_specimen_id', 'engagement_state'])[
                       ['mean_response', 'mean_baseline']]
                   .mean()
                   .reset_index())
        # Broadcast scalar metadata onto every row. List values are skipped,
        # presumably because pandas would treat them as per-row values.
        for key, val in analysis.dataset.metadata.items():
            if not isinstance(val, list):
                summary[key] = val
        summary['event_type'] = df_type
        summaries.append(summary)
    return pd.concat(summaries)
コード例 #3
0
def plot_lick_triggered_average_for_container(container_id, save_figure=True):
    """
    Plot the population lick-triggered average of every experiment in a container.

    All experiments share one axis. Each trace is colored and labelled by its
    session number. Experiments whose lick-triggered response dataframe has 5
    or fewer cells are skipped.

    Parameters
    ----------
    container_id : int
        Container whose experiments (taken from
        ``loading.get_filtered_ophys_experiment_table``) are plotted.
    save_figure : bool
        If True, save the figure to ``loading.get_container_plots_dir()``.

    Raises
    ------
    ValueError
        If the container has no experiments in the filtered table. The figure
        title is built from the metadata of the last loaded dataset, so there
        is nothing to label.
    """
    experiments = loading.get_filtered_ophys_experiment_table()
    container_data = experiments[experiments.container_id == container_id]
    if len(container_data) == 0:
        raise ValueError(
            'no experiments found for container_id {}'.format(container_id))

    figsize = (6, 4)
    fig, ax = plt.subplots(figsize=figsize)
    # loop-invariant: one color per session number (indexed by session_number - 1)
    colors = utils.get_colors_for_session_numbers()

    for session_number in container_data.session_number.unique():
        experiment_ids = container_data[container_data.session_number ==
                                        session_number].index.values
        for experiment_id in experiment_ids:
            dataset = loading.get_ophys_dataset(experiment_id)
            analysis = ResponseAnalysis(
                dataset,
                use_events=True,
                use_extended_stimulus_presentations=False)
            ldf = analysis.get_response_df(
                df_name='lick_triggered_response_df')
            # only plot a population average when enough cells contribute
            if len(ldf.cell_specimen_id.unique()) > 5:
                ax = plot_lick_triggered_average(
                    ldf,
                    color=colors[session_number - 1],
                    ylabel='pop. avg. \nevent magnitude',
                    legend_label=session_number,
                    ax=ax)
    ax.legend(loc='upper left', fontsize='x-small')
    fig.tight_layout()

    # title / filename from the metadata of the last dataset loaded above
    m = dataset.metadata
    title = '_'.join([
        str(m['mouse_id']),
        m['full_genotype'].split('/')[0],
        m['targeted_structure'],
        str(m['imaging_depth']),
        m['equipment_name'],
        str(m['experiment_container_id']),
    ])
    fig.suptitle(title, x=0.53, y=1.02, fontsize=16)
    if save_figure:
        save_dir = loading.get_container_plots_dir()
        utils.save_figure(fig, figsize, save_dir, 'lick_triggered_average',
                          title)
コード例 #4
0
def generate_save_plots(experiment_id, split_by):
    """
    Save a cell response summary plot for every valid ROI of one experiment.

    Parameters
    ----------
    experiment_id : int
        Ophys experiment id passed to ``loading.get_ophys_dataset``.
    split_by : str
        Forwarded unchanged to ``sf.make_cell_response_summary_plot``.
    """
    dataset = loading.get_ophys_dataset(experiment_id)
    analysis_options = dict(overwrite_analysis_files=False,
                            dataframe_format='tidy',
                            use_extended_stimulus_presentations=True)
    analysis = ResponseAnalysis(dataset, **analysis_options)

    valid_roi_ids = dataset.cell_specimen_table.query('valid_roi==True').index.values
    for csid in valid_roi_ids:
        # plots are written to disk (save=True) and not displayed (show=False)
        sf.make_cell_response_summary_plot(
            analysis, csid, split_by,
            save=True, show=False, errorbar_bootstrap_iterations=1000)
コード例 #5
0
def create_analysis_files(experiment_id, cache_dir, overwrite_analysis_files=True,
                          use_events=False):
    """
    Build (or load) the trials, omission and stimulus response dataframes of an experiment.

    If the dataset has no analysis directory yet, a level-1 -> level-2
    conversion is attempted first. This step is best effort: a failure is
    reported and processing continues.

    Parameters
    ----------
    experiment_id : int
        Ophys experiment id passed to ``loading.get_ophys_dataset``.
    cache_dir : str
        Forwarded to ``convert_level_1_to_level_2``.
    overwrite_analysis_files : bool
        Forwarded to ``ResponseAnalysis``.
    use_events : bool
        Forwarded to ``ResponseAnalysis``. Previously hard-coded to False,
        which remains the default.
    """
    print(experiment_id)
    dataset = loading.get_ophys_dataset(experiment_id)
    # truthiness also covers a missing (None) analysis_dir, where len() would raise
    if not dataset.analysis_dir:
        try:
            convert_level_1_to_level_2(experiment_id, cache_dir, plot_roi_validation=False)
        except Exception as e:
            # best effort: report the failure (without swallowing
            # KeyboardInterrupt/SystemExit) and carry on
            print('could not convert', experiment_id)
            print(e)
    analysis = ResponseAnalysis(dataset, use_events=use_events, overwrite_analysis_files=overwrite_analysis_files)

    # Accessing each attribute is presumably what creates / loads the
    # corresponding response dataframe; the values themselves are not needed here.
    _ = analysis.trials_response_df
    _ = analysis.omission_response_df
    _ = analysis.stimulus_response_df
コード例 #6
0
len(all_experiment_ids)

# Create a dictionary with all of the data you want to use (based on the conditions above),
# keyed by the experiment id as a string
all_data = {}
for exp_id in all_experiment_ids:
    all_data[str(exp_id)] = VisualBehaviorOphysDataset(exp_id, cache_dir=drive_path)
len(all_data)  # to confirm the length of all_data is equal to that of all_experiment_ids


# Select a specific experiment and cell
expID = 639438856  # your experiment of choice
cellID = 19        # your cell of choice
# look the dataset up from expID so that changing expID actually changes the selection
temp0 = all_data[str(expID)]

analysis = ResponseAnalysis(temp0)  # response analysis for the selected experiment
data = analysis.get_flash_response_df()  # flash responses for all cells

# select cell
data_cell = data[data.cell == cellID].reset_index(drop=True)
data_cell


# add information about image repeats
len(data_cell)
addCol = np.zeros(len(data_cell))
addCol[0] = 1

counter = 1
コード例 #7
0
                                    image_name].index.values:
            if stimulus_table.iloc[index]['repeat'] == 1:
                block += 1
            stimulus_table.loc[index, 'image_block'] = int(block)
    return stimulus_table


def add_image_block_to_flash_response_df(flash_response_df, stimulus_table):
    """
    Attach an ``image_block`` column to a flash response dataframe.

    Block labels are computed by ``add_image_block_to_stimulus_table``. They
    are joined onto the flashes through ``flash_number`` using pandas'
    default (inner) merge.

    NOTE(review): ``add_image_block_to_stimulus_table`` writes into the table
    with ``.loc``. Confirm whether callers must pass a copy of
    ``stimulus_table`` to avoid modifying it.
    """
    labelled_table = add_image_block_to_stimulus_table(stimulus_table)
    block_lookup = labelled_table[['flash_number', 'image_block']]
    return flash_response_df.merge(block_lookup, on='flash_number')


if __name__ == '__main__':
    from visual_behavior.ophys.dataset.visual_behavior_ophys_dataset import VisualBehaviorOphysDataset
    from visual_behavior.ophys.response_analysis.response_analysis import ResponseAnalysis

    # Example run on a single pilot experiment.
    # Point cache_dir at the AWS mount or at a local hard drive as appropriate.
    cache_dir = r'\\allen\programs\braintv\workgroups\nc-ophys\visual_behavior\visual_behavior_pilot_analysis'
    experiment_id = 672185644
    dataset = VisualBehaviorOphysDataset(experiment_id, cache_dir=cache_dir)
    analysis = ResponseAnalysis(dataset)

    stimulus_table = analysis.dataset.stimulus_table.copy()
    flash_response_df = analysis.flash_response_df.copy()
    # apply the repeat-number annotation, then the image-block annotation
    for annotate in (add_repeat_number_to_flash_response_df,
                     add_image_block_to_flash_response_df):
        flash_response_df = annotate(flash_response_df, stimulus_table)
コード例 #8
0
ファイル: data_explorer.py プロジェクト: yonibrowning/oBehave
def data_explorer(expt_id, begin, end,
                  drive_path='/data/dynamic-brain-workshop/visual_behavior',
                  event_drive_path='/data/dynamic-brain-workshop/visual_behavior_events'):
    '''
    Implementation of a tool for exploring a visual behavior dataset.
    Plots pretty much all of the info from a given experiment in an easy to read format:
    - per-cell dF/F traces, each offset vertically by its cell index, with detected events as black dots
    - running speed (scaled by 0.05 and shifted to -8)
    - rewards (stars at y=-1) and licks (red vertical lines between -2 and -5)
    - stimulus presentation periods as shaded gray spans

    Inputs:
    expt_id: experiment number to look up.
    begin: where to start plotting (lower x-axis limit, in the units of the timestamps).
    end: where to stop plotting (upper x-axis limit).
    drive_path: cache directory of the visual behavior dataset (defaults to the AWS workshop mount).
    event_drive_path: directory containing '<expt_id>_events.npz' files whose 'ev' entry holds
        one event trace per cell (defaults to the AWS workshop mount).

    Created by Deepa Ramamurthy, August 2018
    '''
    # import visual behavior dataset class from the visual_behavior package
    from visual_behavior.ophys.dataset.visual_behavior_ophys_dataset import VisualBehaviorOphysDataset

    dataset = VisualBehaviorOphysDataset(expt_id, cache_dir=drive_path)

    # Load detected events. The context manager closes the .npz archive once
    # the array has been read into memory.
    event_file = os.path.join(event_drive_path, str(expt_id) + '_events.npz')
    with np.load(event_file) as event_data:
        event_array = event_data['ev']

    # Figure setup
    fig, ax = plt.subplots(1, 1, figsize=(16, 14))

    # Make raster plot: one row per cell, offset by the cell index
    for i, cell_events in enumerate(event_array):
        ax.plot(dataset.timestamps_ophys, dataset.dff_traces[i] + i)
        ax.plot(dataset.timestamps_ophys, cell_events + i, '.', color='k')

    # running speed, scaled down and shifted below the traces
    ax.plot(dataset.timestamps_stimulus,
            (dataset.running_speed.running_speed.values * 0.05) - 8)

    ax.set_xlim(begin, end)

    # plot rewards
    reward_times = dataset.rewards.time.values
    ax.plot(
        reward_times,
        np.repeat(-1, repeats=len(reward_times)),
        marker='*',
        linestyle='None',
        label='rewards',
    )

    # plot licks: use the lick times only (licks.values would include every column)
    ax.vlines(dataset.licks.time.values, -2, -5, label='licks', color='r')

    # shade each stimulus presentation; iterate rows directly so a non-range
    # index cannot misalign label-based and positional lookups
    for row in dataset.stimulus_table.itertuples(index=False):
        ax.axvspan(xmin=row.start_time,
                   xmax=row.end_time,
                   facecolor='gray',
                   alpha=0.3)

    ax.set_ylabel('Cells')
    ax.set_xlabel('time')
    # hide the y tick labels (rows are cell offsets, not data values)
    ax.set_yticklabels([], visible=False)

    plt.show()
コード例 #9
0
def plot_across_session_responses(ophys_container_id, cell_specimen_id, use_events=False, save_figure=True):
    """
    Generates plots characterizing single cell activity in response to stimulus, omissions, and changes.
    Compares across all sessions in a container for each cell, including the ROI mask across days.
    Useful to validate cell matching as well as examine changes in activity profiles over days.

    One column per experiment in the container, sorted by experiment id. Rows:
        0 - ROI mask zoomed on the max projection
        1 - mean response to each image on non-change flashes
        2 - preferred-image response split by running (mean_running_speed > 2)
        3 - mean omission response
        4 - change response on go trials for the preferred image, split by hit / miss
    A sixth row of axes is allocated but not drawn on.

    Columns for experiments in which the cell is absent stay blank. An error
    while processing one experiment is printed and the rest of that column is
    skipped.

    Parameters
    ----------
    ophys_container_id : int
    cell_specimen_id : int
    use_events : bool
        Use events instead of dF/F for rows 1, 3 and 4. This also sets the
        y-label and the saved filename suffix.
    save_figure : bool
        Save the figure to ``utils.get_single_cell_plots_dir()`` and close it.

    Raises
    ------
    ValueError
        If the container has no experiments in the release experiment table.
    """
    experiments_table = data_loading.get_filtered_ophys_experiment_table(release_data_only=True)
    container_expts = experiments_table[experiments_table.ophys_container_id == ophys_container_id]
    expts = np.sort(container_expts.index.values)
    if use_events:
        ylabel = 'response'
        suffix = '_events'
    else:
        ylabel = 'dF/F'
        suffix = ''

    n = len(expts)
    if n == 0:
        # plt.subplots would otherwise fail with a less specific ValueError
        raise ValueError('no experiments found for ophys_container_id {}'.format(ophys_container_id))
    figsize = (25, 20)
    fig, ax = plt.subplots(6, n, figsize=figsize)
    ax = ax.ravel()  # row r of experiment column i is ax[i + n * r]

    # loop-invariant color palettes
    image_colors = sns.color_palette('hls', 8) + [(0.5, 0.5, 0.5)]
    run_colors = [sns.color_palette()[3], sns.color_palette()[2]]
    hit_colors = [sns.color_palette()[2], sns.color_palette()[3]]

    dataset = None  # most recently loaded dataset; its metadata names the saved figure
    print('ophys_container_id:', ophys_container_id)
    for i, ophys_experiment_id in enumerate(expts):
        print('ophys_experiment_id:', ophys_experiment_id)
        try:

            dataset = data_loading.get_ophys_dataset(ophys_experiment_id, include_invalid_rois=False)
            if cell_specimen_id in dataset.dff_traces.index:
                analysis = ResponseAnalysis(dataset, use_events=use_events, use_extended_stimulus_presentations=False)
                sdf = ut.get_mean_df(analysis.get_response_df(df_name='stimulus_response_df'), analysis=analysis,
                                     conditions=['cell_specimen_id', 'is_change', 'image_name'], flashes=True, omitted=False,
                                     get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=True)
                odf = ut.get_mean_df(analysis.get_response_df(df_name='omission_response_df'), analysis=analysis,
                                     conditions=['cell_specimen_id'], flashes=False, omitted=True,
                                     get_reliability=False, get_pref_stim=False, exclude_omitted_from_pref_stim=False)
                tdf = ut.get_mean_df(analysis.get_response_df(df_name='trials_response_df'), analysis=analysis,
                                     conditions=['cell_specimen_id', 'go', 'hit', 'change_image_name'], flashes=False, omitted=False,
                                     get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=True)

                # Row 0: ROI mask on the max projection
                ct = dataset.cell_specimen_table.copy()
                cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
                ax[i] = sf.plot_cell_zoom(dataset.roi_masks, dataset.max_projection, cell_roi_id,
                                          spacex=20, spacey=20, show_mask=True, ax=ax[i])
                session_label = container_expts.loc[ophys_experiment_id].session_type[6:]
                ax[i].set_title(session_label)

                # Row 1: mean response to each image on non-change flashes
                window = rp.get_default_stimulus_response_params()["window_around_timepoint_seconds"]
                cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == False)]
                for c, image_name in enumerate(np.sort(cell_data.image_name.unique())):
                    ax[i + n] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.image_name == image_name],
                                                                frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                legend_label=image_name, color=image_colors[c], interval_sec=0.5,
                                                                xlims=window, ax=ax[i + n])
                ax[i + n] = sf.plot_flashes_on_trace(ax[i + n], analysis, window=window, trial_type=None, omitted=False, alpha=0.15, facecolor='gray')
                ax[i + n].set_title(session_label + '\n image response')

                # Row 2: preferred-image response split by running state.
                # NOTE(review): this analysis always uses use_events=False, even when
                # use_events=True (the y-label still follows use_events) -- confirm intended.
                analysis = ResponseAnalysis(dataset, use_events=False, use_extended_stimulus_presentations=True)
                tmp = analysis.get_response_df(df_name='stimulus_response_df')
                tmp['running'] = [True if run_speed > 2 else False for run_speed in tmp.mean_running_speed.values]
                sdf = ut.get_mean_df(tmp, analysis=analysis,
                                     conditions=['cell_specimen_id', 'is_change', 'image_name', 'running'], flashes=True, omitted=False,
                                     get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=False)

                cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == False) & (sdf.pref_stim == True)]
                for c, running in enumerate(np.sort(cell_data.running.unique())):
                    if len(cell_data[cell_data.running == running]) > 0:
                        ax[i + (n * 2)] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.running == running],
                                                                          frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                          legend_label=running, color=run_colors[c], interval_sec=0.5,
                                                                          xlims=window, ax=ax[i + (n * 2)])
                ax[i + (n * 2)].legend(fontsize='xx-small', title='running', title_fontsize='xx-small')
                ax[i + (n * 2)] = sf.plot_flashes_on_trace(ax[i + (n * 2)], analysis, window=window, trial_type=None, omitted=False, alpha=0.15, facecolor='gray')
                ax[i + (n * 2)].set_title(session_label + '\n image response')

                # Row 3: omission response. Its label must not reuse image_name from
                # row 1's loop, which is stale or unbound when that loop did not run.
                window = rp.get_default_omission_response_params()["window_around_timepoint_seconds"]
                cell_data = odf[(odf.cell_specimen_id == cell_specimen_id)]
                ax[i + (n * 3)] = sf.plot_mean_trace_from_mean_df(cell_data,
                                                                  frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                  legend_label=None, color='gray', interval_sec=1,
                                                                  xlims=window, ax=ax[i + (n * 3)])
                ax[i + (n * 3)] = sf.plot_flashes_on_trace(ax[i + (n * 3)], analysis, window=window, trial_type=None, omitted=True, alpha=0.15, facecolor='gray')
                ax[i + (n * 3)].set_title(session_label + '\n omission response')

                # Row 4: change response on go trials for the preferred image, hit vs miss
                window = rp.get_default_trial_response_params()["window_around_timepoint_seconds"]
                cell_data = tdf[(tdf.cell_specimen_id == cell_specimen_id) & (tdf.go == True) & (tdf.pref_stim == True)]
                for c, hit in enumerate([True, False]):
                    if len(cell_data[cell_data.hit == hit]) > 0:
                        ax[i + (n * 4)] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.hit == hit],
                                                                          frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                          legend_label=hit, color=hit_colors[c], interval_sec=1,
                                                                          xlims=window, ax=ax[i + (n * 4)])
                ax[i + (n * 4)].legend(fontsize='xx-small', title='hit', title_fontsize='xx-small')
                ax[i + (n * 4)] = sf.plot_flashes_on_trace(ax[i + (n * 4)], analysis, window=window, trial_type='go', omitted=False, alpha=0.15, facecolor='gray')
                ax[i + (n * 4)].set_title(session_label + '\n change response')

                fig.tight_layout()
                fig.suptitle(str(cell_specimen_id) + '_' + dataset.metadata_string, x=0.5, y=1.01,
                             horizontalalignment='center')
        except Exception as e:
            print('problem for cell_specimen_id:', cell_specimen_id, ', ophys_experiment_id:', ophys_experiment_id)
            print(e)
    if save_figure:
        if dataset is None:
            # no dataset could be loaded, so there is no metadata to name the file
            print('no dataset loaded for ophys_container_id', ophys_container_id, '- figure not saved')
        else:
            save_dir = utils.get_single_cell_plots_dir()
            utils.save_figure(fig, figsize, save_dir, 'across_session_responses', str(
                cell_specimen_id) + '_' + dataset.metadata_string + '_across_session_responses' + suffix)
        plt.close()
def get_multi_session_mean_df(experiment_ids,
                              analysis_cache_dir,
                              df_name,
                              conditions=None,
                              flashes=False,
                              use_events=False,
                              omitted=False,
                              get_reliability=False,
                              get_pref_stim=True,
                              exclude_omitted_from_pref_stim=True,
                              use_sdk_dataset=True):
    """
    Average a response dataframe across experiments and cache the result as HDF5.

    For each experiment, the response dataframe ``df_name`` is loaded and
    averaged with ``ut.get_mean_df`` over ``conditions``. Dataset metadata is
    then attached. The per-experiment results are concatenated and written to
    ``<analysis_cache_dir>/multi_session_summary_dfs/<filename>``. The filename
    is built by ``_mean_df_filename``.

    Parameters
    ----------
    experiment_ids : iterable of int
    analysis_cache_dir : str
        Root directory for the written file. Also the cache dir of the legacy
        ``VisualBehaviorOphysDataset`` when ``use_sdk_dataset`` is False.
    df_name : str
        Response dataframe name passed to ``ResponseAnalysis.get_response_df``.
    conditions : list of str, optional
        Group-by columns for ``ut.get_mean_df``. Defaults to
        ``['cell_specimen_id', 'change_image_name', 'behavioral_response_type']``.
        If it contains 'running', a boolean ``running`` column
        (``window_running_speed > 5``) is added first.
    flashes, omitted, get_reliability, get_pref_stim, exclude_omitted_from_pref_stim :
        Forwarded to ``ut.get_mean_df``.
    use_events : bool
        Forwarded to ``ResponseAnalysis``. Also adds an '_events' filename suffix.
    use_sdk_dataset : bool
        Load via ``loading.get_ophys_dataset`` (True) or the legacy
        ``VisualBehaviorOphysDataset`` (False).

    Returns
    -------
    pandas.DataFrame
        The concatenated mean dataframe (empty if ``experiment_ids`` is empty).

    Raises
    ------
    ValueError
        If ``conditions`` is empty (no output filename can be built).
    """
    if conditions is None:
        # fresh list per call instead of a shared mutable default
        conditions = ['cell_specimen_id', 'change_image_name', 'behavioral_response_type']
    suffix = '_events' if use_events else ''
    # validate before the expensive per-experiment work
    filename = _mean_df_filename(df_name, conditions, suffix)

    experiments_table = loading.get_filtered_ophys_experiment_table()
    mean_dfs = []
    for experiment_id in experiment_ids:
        print(experiment_id)
        if use_sdk_dataset:
            dataset = loading.get_ophys_dataset(experiment_id)
        else:
            dataset = VisualBehaviorOphysDataset(experiment_id,
                                                 analysis_cache_dir)
        analysis = ResponseAnalysis(dataset,
                                    use_events=use_events,
                                    overwrite_analysis_files=False,
                                    use_extended_stimulus_presentations=True)
        df = analysis.get_response_df(df_name)
        df['ophys_experiment_id'] = dataset.ophys_experiment_id
        df['project_code'] = experiments_table.loc[experiment_id].project_code
        df['session_type'] = experiments_table.loc[experiment_id].session_type
        if 'running' in conditions:
            df['running'] = df.window_running_speed.values > 5
        mdf = ut.get_mean_df(
            df,
            analysis,
            conditions=conditions,
            get_pref_stim=get_pref_stim,
            flashes=flashes,
            omitted=omitted,
            get_reliability=get_reliability,
            exclude_omitted_from_pref_stim=exclude_omitted_from_pref_stim)
        mdf['ophys_experiment_id'] = dataset.ophys_experiment_id

        # Work on a copy so the dataset's own metadata is not modified. If the
        # dataset object were reused (e.g. cached), the in-place version would
        # re-index an already-truncated value.
        metadata_dict = dict(dataset.metadata)
        for key in ('reporter_line', 'driver_line'):
            value = metadata_dict[key]
            # keep the first entry of a multi-valued field; a plain string is
            # already a single value (indexing it would keep only one character)
            if not isinstance(value, str):
                metadata_dict[key] = value[0]
        metadata = pd.DataFrame(metadata_dict, index=[experiment_id])
        mdf = ut.add_metadata_to_mean_df(mdf, metadata)
        mean_dfs.append(mdf)

    # concatenate once; concatenating inside the loop is quadratic
    mega_mdf = pd.concat(mean_dfs) if mean_dfs else pd.DataFrame()
    if 'level_0' in mega_mdf.keys():
        mega_mdf = mega_mdf.drop(columns='level_0')
    if 'index' in mega_mdf.keys():
        mega_mdf = mega_mdf.drop(columns='index')

    mega_mdf_write_dir = os.path.join(analysis_cache_dir,
                                      'multi_session_summary_dfs')
    os.makedirs(mega_mdf_write_dir, exist_ok=True)

    print('saving multi session mean df to ', filename)
    mega_mdf.to_hdf(os.path.join(mega_mdf_write_dir, filename), key='df')
    print('saved')
    return mega_mdf


def _mean_df_filename(df_name, conditions, suffix=''):
    """
    Build the HDF5 filename for a cached multi-session mean dataframe.

    The name is 'mean_<df_name>_<conditions joined by "_"><suffix>.h5'. The
    first condition (normally 'cell_specimen_id') is left out of the name
    unless it is the only condition.

    Raises
    ------
    ValueError
        If ``conditions`` is empty.
    """
    if len(conditions) == 0:
        raise ValueError('conditions must contain at least one column name')
    name_parts = conditions[1:] if len(conditions) > 1 else conditions[:1]
    return 'mean_' + df_name + '_' + '_'.join(name_parts) + suffix + '.h5'
コード例 #11
0
def plot_matched_roi_and_trace(ophys_container_id, cell_specimen_id, limit_to_last_familiar_second_novel=True,
                               use_events=False, filter_events=False, save_figure=True, save_dir=None):
    """
    Generates plots characterizing single cell activity in response to stimulus, omissions, and changes.
    Compares across all sessions in a container for each cell, including the ROI mask across days.
    Useful to validate cell matching as well as examine changes in activity profiles over days.

    One column per experiment. Rows:
        0 - ROI mask zoomed on the max projection. If the cell is not matched
            in a session, that session's max projection is shown around the
            ROI from the most recent earlier session that had the cell, with
            no mask.
        1 - mean change response (is_change == True) with SEM

    NOTE(review): experiments are sorted by experience_level and then re-sorted
    by experiment id (np.sort), so the columns follow experiment id order.
    Confirm which order is intended.

    Parameters
    ----------
    ophys_container_id : int
    cell_specimen_id : int
    limit_to_last_familiar_second_novel : bool
        Keep only one session per experience level, and only containers that
        have all experience levels.
    use_events, filter_events : bool
        Forwarded to ``ResponseAnalysis``. They also choose the y-label and
        the filename suffix.
    save_figure : bool
        Save the figure and close it.
    save_dir : str, optional
        Output directory. Defaults to the platform-paper cell_matching
        directory on the Allen network share.

    Raises
    ------
    ValueError
        If the container has no experiments after filtering.
    """
    if save_dir is None:
        save_dir = r'//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/platform_paper_plots/cell_matching'

    experiments_table = loading.get_platform_paper_experiment_table()
    if limit_to_last_familiar_second_novel:  # this ensures only one session per experience level
        experiments_table = utilities.limit_to_last_familiar_second_novel_active(experiments_table)
        experiments_table = utilities.limit_to_containers_with_all_experience_levels(experiments_table)

    container_expts = experiments_table[experiments_table.ophys_container_id == ophys_container_id]
    container_expts = container_expts.sort_values(by=['experience_level'])
    expts = np.sort(container_expts.index.values)

    if use_events:
        if filter_events:
            suffix = 'filtered_events'
        else:
            suffix = 'events'
        ylabel = 'response'
    else:
        suffix = 'dff'
        ylabel = 'dF/F'

    n = len(expts)
    if n == 0:
        # plt.subplots would otherwise fail with a less specific ValueError
        raise ValueError('no experiments found for ophys_container_id {}'.format(ophys_container_id))
    if limit_to_last_familiar_second_novel:
        figsize = (9, 6)
        folder = 'matched_cells_exp_levels'
    else:
        figsize = (20, 6)
        folder = 'matched_cells_all_sessions'
    fig, ax = plt.subplots(2, n, figsize=figsize, sharey='row')
    ax = ax.ravel()  # row 0 is ax[i], row 1 is ax[i + n]

    # State carried across iterations: the most recent session containing the cell
    # supplies the ROI position for later sessions where it is missing.
    roi_masks = None
    cell_roi_id = None
    metadata_string = None  # from the last successfully processed dataset; names the saved file
    print('ophys_container_id:', ophys_container_id)
    for i, ophys_experiment_id in enumerate(expts):
        print('ophys_experiment_id:', ophys_experiment_id)
        try:
            dataset = loading.get_ophys_dataset(ophys_experiment_id, get_extended_stimulus_presentations=False)
            if cell_specimen_id in dataset.dff_traces.index:

                ct = dataset.cell_specimen_table.copy()
                cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
                roi_masks = dataset.roi_masks.copy()  # keep for sessions where the ROI is missing
                ax[i] = sf.plot_cell_zoom(dataset.roi_masks, dataset.max_projection, cell_roi_id,
                                          spacex=50, spacey=50, show_mask=True, ax=ax[i])
                ax[i].set_title(container_expts.loc[ophys_experiment_id].experience_level)

                analysis = ResponseAnalysis(dataset, use_events=use_events, filter_events=filter_events,
                                            use_extended_stimulus_presentations=False)
                sdf = analysis.get_response_df(df_name='stimulus_response_df')
                cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == True)]

                window = rp.get_default_stimulus_response_params()["window_around_timepoint_seconds"]
                ax[i + n] = utils.plot_mean_trace(cell_data.trace.values, cell_data.trace_timestamps.values[0],
                                                  ylabel=ylabel, legend_label=None, color='gray', interval_sec=0.5,
                                                  xlim_seconds=window, plot_sem=True, ax=ax[i + n])

                ax[i + n] = utils.plot_flashes_on_trace(ax[i + n], cell_data.trace_timestamps.values[0], change=True, omitted=False,
                                                        alpha=0.15, facecolor='gray')
                ax[i + n].set_title('')
                if i != 0:
                    ax[i + n].set_ylabel('')
            else:
                # Cell not matched in this session: show this session's max projection
                # around the ROI position from the most recent session that had the cell.
                if roi_masks is None:
                    print('cell_specimen_id', cell_specimen_id, 'not found in any earlier session;',
                          'cannot locate it for ophys_experiment_id', ophys_experiment_id)
                else:
                    ax[i] = sf.plot_cell_zoom(roi_masks, dataset.max_projection, cell_roi_id,
                                              spacex=50, spacey=50, show_mask=False, ax=ax[i])
                ax[i].set_title(container_expts.loc[ophys_experiment_id].experience_level)

            metadata_string = utils.get_metadata_string(dataset.metadata)

            fig.tight_layout()
            fig.suptitle(str(cell_specimen_id) + '_' + metadata_string, x=0.53, y=1.02,
                         horizontalalignment='center', fontsize=16)
        except Exception as e:
            print('problem for cell_specimen_id:', cell_specimen_id, ', ophys_experiment_id:', ophys_experiment_id)
            print(e)
    if save_figure:
        if metadata_string is None:
            # every experiment failed, so there is no metadata to name the file
            print('no metadata available for ophys_container_id', ophys_container_id, '- figure not saved')
        else:
            utils.save_figure(fig, figsize, save_dir, folder, str(cell_specimen_id) + '_' + metadata_string + '_' + suffix)
        plt.close()
def _plot_licks_and_rewards(ax, dataset):
    """Overlay lick times ('.') and reward times ('o') at y = -10 on ``ax``.

    The markers are drawn at a fixed negative y value, presumably so they sit
    below the running-speed trace they are overlaid on.
    """
    lick_times = dataset.licks.time.values
    ax.plot(lick_times, np.repeat(-10, repeats=len(lick_times)), '.', label='licks')
    reward_times = dataset.rewards.time.values
    ax.plot(reward_times,
            np.repeat(-10, repeats=len(reward_times)),
            'o',
            label='rewards')


def _plot_response_probability_by_image(ax, trials, images, trial_type, title):
    """Plot the mean ``response`` per change image for one ``trial_type`` on ``ax``.

    One point is drawn per entry of ``images`` (x = its position in ``images``).
    """
    for i, image in enumerate(images):
        selected_trials = trials[(trials.change_image_name == image)
                                 & (trials.trial_type == trial_type)]
        ax.plot(i, selected_trials.response.mean(), 'o')
    ax.set_xticks(np.arange(0, len(images), 1))
    ax.set_xticklabels(images, rotation=90)
    ax.set_ylabel('response probability')
    ax.set_xlabel('change image')
    ax.set_title(title)
    ax.set_ylim(0, 1)


def _make_lick_raster(trials, ax):
    """Draw a lick raster on ``ax``, one row per trial, aligned to ``change_time``.

    Each row is shaded with the trial's ``trial_type_color``. Black ticks mark the
    trial start and each lick, a gray tick marks the change (t = 0), and a blue
    dot marks the first reward. The last trial's ``response_window`` is shaded
    gray. Returns ``ax``.
    """
    trial_data = None
    for trial in trials.trial.values:
        # NOTE(review): the `trial` value doubles as a positional index and a
        # row position; assumes trials.trial runs 0..n-1 in row order — confirm.
        trial_data = trials.iloc[trial]
        change_time = trial_data.change_time
        trial_start = trial_data.start_time - change_time
        lick_times = [t - change_time for t in trial_data.lick_times]
        reward_times = [t - change_time for t in trial_data.reward_times]
        # axhspan's xmin/xmax are axes fractions, so -200..200 spans the full width.
        ax.axhspan(trial,
                   trial + 1,
                   -200,
                   200,
                   color=trial_data.trial_type_color,
                   alpha=.5)
        if len(reward_times) > 0:
            ax.plot(reward_times[0],
                    trial + 0.5,
                    '.',
                    color='b',
                    label='reward',
                    markersize=6)
        ax.vlines(trial_start, trial, trial + 1, color='black', linewidth=1)
        ax.vlines(lick_times, trial, trial + 1, color='k', linewidth=1)
        ax.vlines(0, trial, trial + 1, color=[.5, .5, .5], linewidth=1)
    if trial_data is not None:
        # The response window is taken from the last trial only.
        ax.axvspan(trial_data.response_window[0],
                   trial_data.response_window[1],
                   facecolor='gray',
                   alpha=.4,
                   edgecolor='none')
    ax.grid(False)
    ax.set_ylim(0, len(trials))
    ax.set_xlim([-1, 4])
    ax.set_ylabel('trials')
    ax.set_xlabel('time (sec)')
    ax.set_title('lick raster')
    ax.invert_yaxis()
    return ax


def _compute_lifetime_sparseness(image_responses):
    """Return the lifetime sparseness of a 1-D array of per-image mean responses.

    sparseness = (1 - (1/N) * (sum r)^2 / sum(r^2)) / (1 - 1/N), with N images
    (after Vinje & Gallant, 2000; Froudarakis et al., 2014).
    """
    N = float(len(image_responses))
    return ((1 - (1 / N) * ((np.power(image_responses.sum(axis=0), 2)) /
                            (np.power(image_responses, 2).sum(axis=0)))) /
            (1 - (1 / N)))


def _get_responsive_cells(df, p_value_threshold=0.005, min_fraction_responsive=0.1):
    """Return cells in ``df`` for which the fraction of rows with
    ``p_value < p_value_threshold`` exceeds ``min_fraction_responsive``.

    Cells are returned in order of first appearance in ``df.cell``.
    """
    responsive_cells = []
    for cell in df.cell.unique():
        cell_data = df[df.cell == cell]
        n_responsive = (cell_data.p_value < p_value_threshold).sum()
        if n_responsive / float(len(cell_data)) > min_fraction_responsive:
            responsive_cells.append(cell)
    return responsive_cells


def _get_mean_lifetime_sparseness(df, cells):
    """Mean lifetime sparseness over ``cells``.

    For each cell, the per-image input is the mean ``mean_response`` for each
    sorted ``change_image_name`` in ``df``.
    """
    images = np.sort(df.change_image_name.unique())
    lifetime_sparseness_values = []
    for cell in cells:
        cell_data = df[df.cell == cell]
        image_responses = [
            np.mean(cell_data[cell_data.change_image_name == change_image_name].mean_response.values)
            for change_image_name in images
        ]
        lifetime_sparseness_values.append(
            _compute_lifetime_sparseness(np.asarray(image_responses)))
    return np.mean(np.asarray(lifetime_sparseness_values))


def plot_notebook_figures(experiment_id, save_dir):
    """Build and save a 6x3 grid of summary panels for one ophys experiment.

    Panels, in order: max projection; dF/F of an example cell; its ROI mask;
    running speed; running speed with licks/rewards (600-660 s); the same with
    stimulus flashes shaded (600-620 s); an example stimulus image; response
    probability per change image for go, catch, and all trial types; lick
    raster; the single-trial trace with the largest mean response; three mean
    traces for that cell/image; and two image tuning curves. Fraction of
    responsive cells and mean lifetime sparseness are printed.

    Parameters
    ----------
    experiment_id
        Passed to ``VisualBehaviorOphysDataset``. Also used (via ``str``) as the
        figure title and saved file name.
    save_dir : str
        Used as the dataset ``cache_dir``. Also the root directory passed to
        ``sf.save_figure`` under the 'summary_figures' folder.
    """
    print(experiment_id)
    figsize = (20, 15)
    fig, ax = plt.subplots(6, 3, figsize=figsize)
    ax = ax.ravel()
    x = 0

    # save_dir doubles as the dataset cache directory.
    cache_dir = save_dir

    from visual_behavior.ophys.dataset.visual_behavior_ophys_dataset import VisualBehaviorOphysDataset
    dataset = VisualBehaviorOphysDataset(experiment_id, cache_dir=cache_dir)

    # Max projection image.
    ax[x].imshow(dataset.max_projection, cmap='gray')
    ax[x].axis('off')
    x += 1

    # dF/F trace of an example cell.
    cell_index = 8
    ax[x].plot(dataset.ophys_timestamps, dataset.dff_traces[cell_index])
    ax[x].set_xlabel('seconds')
    ax[x].set_ylabel('dF/F')
    x += 1

    # ROI mask of the same example cell.
    cell_specimen_id = dataset.get_cell_specimen_id_for_cell_index(cell_index)
    ax[x].imshow(dataset.roi_mask_dict[str(cell_specimen_id)])
    # grid() expects a bool; the legacy 'off' string is not supported by newer matplotlib.
    ax[x].grid(False)
    x += 1

    # Full running speed trace.
    running_speed = dataset.running_speed.running_speed.values
    ax[x].plot(dataset.stimulus_timestamps, running_speed)
    ax[x].set_xlabel('time (sec)')
    ax[x].set_ylabel('running speed (cm/s)')
    x += 1

    # 60 s zoom of running speed with licks and rewards.
    ax[x].plot(dataset.stimulus_timestamps, running_speed)
    ax[x].set_xlim(600, 660)
    _plot_licks_and_rewards(ax[x], dataset)
    ax[x].set_xlabel('time (sec)')
    ax[x].set_ylabel('running speed (cm/s)')
    ax[x].legend(loc=9, bbox_to_anchor=(1.2, 1))
    x += 1

    # 20 s zoom of running speed with licks, rewards and stimulus flashes shaded.
    flash_window = (600, 620)
    ax[x].plot(dataset.stimulus_timestamps, running_speed)
    ax[x].set_xlim(*flash_window)
    _plot_licks_and_rewards(ax[x], dataset)
    stimulus_table = dataset.stimulus_table
    start_times = stimulus_table.start_time.values
    end_times = stimulus_table.end_time.values
    # Shade each flash straight from its row (re-filtering the table per flash
    # was O(n^2)), and only the flashes that overlap the visible window.
    visible = (end_times >= flash_window[0]) & (start_times <= flash_window[1])
    for start_time, end_time in zip(start_times[visible], end_times[visible]):
        ax[x].axvspan(xmin=start_time, xmax=end_time, facecolor='gray', alpha=0.3)
    ax[x].set_xlabel('time (sec)')
    ax[x].set_ylabel('running speed (cm/s)')
    ax[x].legend(loc=9, bbox_to_anchor=(1.2, 1))
    x += 1

    # Example stimulus image.
    stimulus_metadata = dataset.stimulus_metadata
    stimulus_template = dataset.stimulus_template
    image_index = 2
    ax[x].imshow(stimulus_template[image_index], cmap='gray')
    image_name = stimulus_metadata[stimulus_metadata.image_index ==
                                   image_index].image_name.values[0]
    ax[x].set_title(image_name)
    x += 1

    # Response probability per change image, for go and catch trials.
    trials = dataset.trials
    images = trials.change_image_name.unique()
    _plot_response_probability_by_image(ax[x], trials, images, 'go', 'go trials')
    x += 1
    _plot_response_probability_by_image(ax[x], trials, images, 'catch', 'catch trials')
    x += 1

    # Response probability per change image, colored by trial type.
    colors = sns.color_palette()
    trial_types = trials.trial_type.unique()
    for i, image in enumerate(images):
        for t, trial_type in enumerate(trial_types):
            selected_trials = trials[(trials.change_image_name == image)
                                     & (trials.trial_type == trial_type)]
            ax[x].plot(i, selected_trials.response.mean(), 'o', color=colors[t])
    ax[x].set_ylim(0, 1)
    ax[x].set_xticks(np.arange(0, len(images), 1))
    ax[x].set_xticklabels(images, rotation=90)
    ax[x].set_ylabel('response probability')
    ax[x].set_xlabel('change image')
    ax[x].set_title('response probability by trial type & image')
    # The first len(trial_types) points (image 0) are drawn in trial_types order,
    # so label them from trial_types instead of assuming ['go', 'catch'].
    ax[x].legend(list(trial_types))
    x += 1

    _make_lick_raster(dataset.trials, ax=ax[x])
    x += 1

    from visual_behavior.ophys.response_analysis.response_analysis import ResponseAnalysis
    analysis = ResponseAnalysis(dataset)
    tdf = analysis.trial_response_df

    # Trial with the largest mean response; its cell / image / trial type drive
    # the trace panels below.
    largest_response = tdf[tdf.mean_response == tdf.mean_response.max()]
    cell = largest_response.cell.values[0]
    image_name = largest_response.change_image_name.values[0]
    trial_type = largest_response.trial_type.values[0]

    import visual_behavior.visualization.ophys.summary_figures as sf

    frame_rate = analysis.ophys_frame_rate
    ax[x] = sf.plot_single_trial_trace(largest_response.trace.values[0],
                                       frame_rate,
                                       ylabel='dF/F',
                                       legend_label=None,
                                       color='k',
                                       interval_sec=1,
                                       xlims=[-4, 4],
                                       ax=ax[x])
    x += 1

    # Mean traces for this cell and image:
    # (trial type filter, legend label, trial_type passed to plot_flashes_on_trace).
    # NOTE(review): the first panel filters on the largest response's trial type
    # but is labelled and shaded as 'go' — confirm that is intended.
    mean_trace_panels = [
        (trial_type, 'go', 'go'),
        ('go', 'go', 'go'),
        ('catch', None, None),
    ]
    for panel_trial_type, legend_label, flash_trial_type in mean_trace_panels:
        traces = np.asarray(tdf[(tdf.cell == cell) & (tdf.trial_type == panel_trial_type) &
                                (tdf.change_image_name == image_name)].trace)
        ax[x] = sf.plot_mean_trace(traces,
                                   frame_rate,
                                   ylabel='dF/F',
                                   legend_label=legend_label,
                                   color='k',
                                   interval_sec=1,
                                   xlims=[-4, 4],
                                   ax=ax[x])
        ax[x] = sf.plot_flashes_on_trace(ax[x],
                                         analysis,
                                         trial_type=flash_trial_type,
                                         omitted=False,
                                         alpha=0.4)
        x += 1

    from scipy.stats import sem

    # Image tuning curve over all flashes for the selected cell.
    fdf = analysis.flash_response_df
    flash_images = np.sort(fdf.image_name.unique())
    for i, flash_image in enumerate(flash_images):
        responses = fdf[(fdf.cell == cell)
                        & (fdf.image_name == flash_image)].mean_response.values
        mean_response = np.mean(responses)
        ax[x].plot(i, mean_response, 'o', color='k')
        ax[x].errorbar(i, mean_response, yerr=sem(responses), color='k')
    ax[x].set_xticks(np.arange(0, len(flash_images), 1))
    ax[x].set_xticklabels(flash_images, rotation=90)
    ax[x].set_ylabel('mean dF/F')
    ax[x].set_xlabel('image')
    ax[x].set_title('image tuning curve - all flashes')
    x += 1

    # Image tuning curve over go trials for the selected cell.
    go_df = tdf[tdf.trial_type == 'go']
    go_images = np.sort(go_df.change_image_name.unique())
    image_responses = []
    sem_responses = []
    for change_image_name in go_images:
        responses = go_df[(go_df.change_image_name == change_image_name)
                          & (go_df.cell == cell)].mean_response.values
        image_responses.append(np.mean(responses))
        sem_responses.append(sem(responses))
    x_vals = np.arange(0, len(go_images), 1)
    ax[x].plot(x_vals, np.asarray(image_responses), 'o', color='k')
    ax[x].errorbar(x_vals, np.asarray(image_responses), yerr=np.asarray(sem_responses), color='k')
    ax[x].set_xticks(x_vals)
    ax[x].set_xticklabels(go_images, rotation=90)
    ax[x].set_ylabel('mean dF/F')
    ax[x].set_xlabel('image')
    ax[x].set_title('image tuning curve - go trials')
    x += 1

    # Population summaries (printed, not plotted).
    go_responsive_cells = _get_responsive_cells(go_df)
    print(
        'fraction responsive cells = ',
        len(go_responsive_cells) / float(len(tdf.cell.unique())))
    print('mean lifetime sparseness for go trials:',
          _get_mean_lifetime_sparseness(go_df, go_responsive_cells))

    catch_df = tdf[tdf.trial_type == 'catch']
    catch_responsive_cells = _get_responsive_cells(catch_df)
    print('mean lifetime sparseness for catch trials:',
          _get_mean_lifetime_sparseness(catch_df, catch_responsive_cells))

    # Use the figure handle rather than pyplot's implicit "current figure",
    # which helper calls may have changed.
    fig.suptitle(str(experiment_id))
    fig.tight_layout()
    sf.save_figure(fig, figsize, save_dir, 'summary_figures',
                   str(experiment_id))
    print(x)
    plt.close(fig)