def plot_average_flash_response_example_cells(analysis, save_figures=False, save_dir=None, folder=None, ax=None):
    """Plot mean flash responses of active cells: one row per cell, one column per image.

    Cells are the active cells returned by ``ut.get_active_cell_indices`` on the dF/F traces, in
    random order. Each panel shows the cell's mean trace for one image, with flashes shaded in that
    image's color. All panels in a row share y-limits (1.2 x the cell's largest mean trace value).

    :param analysis: ResponseAnalysis-like object providing stimulus_response_df, dataset,
                     ophys_frame_rate and flash_window
    :param save_figures: if True, save the figure to analysis.dataset.analysis_dir under
                         'example_traces_all_flashes' (and additionally to save_dir/folder if save_dir is set)
    :param save_dir: optional extra directory to save into
    :param folder: sub-folder used together with save_dir
    :param ax: optional flat sequence of (n_cells * n_images) axes; created if None
    """
    import visual_behavior.ophys.response_analysis.utilities as ut
    fdf = analysis.stimulus_response_df
    last_flash = fdf.flash_number.unique()[-1]  # sometimes last flash is truncated
    fdf = fdf[fdf.flash_number != last_flash]

    conditions = ['cell_specimen_id', 'image_name']
    mdf = ut.get_mean_df(fdf, analysis, conditions=conditions, flashes=True)

    active_cell_indices = ut.get_active_cell_indices(analysis.dataset.dff_traces_array)
    random_order = np.arange(0, len(active_cell_indices), 1)
    np.random.shuffle(random_order)
    active_cell_indices = active_cell_indices[random_order]
    cell_specimen_ids = [analysis.dataset.get_cell_specimen_id_for_cell_index(cell_index) for cell_index in
                         active_cell_indices]

    image_names = np.sort(mdf.image_name.unique())
    last_row = len(cell_specimen_ids) - 1

    fig = None
    figsize = None
    if ax is None:
        figsize = (12, 10)
        # squeeze=False keeps a 2D array even for a single cell or a single image, so ravel always works
        fig, ax = plt.subplots(len(cell_specimen_ids), len(image_names), figsize=figsize, sharex=True,
                               squeeze=False)
        ax = ax.ravel()

    i = 0
    for c, cell_specimen_id in enumerate(cell_specimen_ids):
        cell_data = mdf[(mdf.cell_specimen_id == cell_specimen_id)]
        maxs = [np.amax(trace) for trace in cell_data.mean_trace.values]
        ymax = np.amax(maxs) * 1.2
        for m, image_name in enumerate(image_names):
            cdf = cell_data[(cell_data.image_name == image_name)]
            color = ut.get_color_for_image_name(image_names, image_name)
            ax[i] = sf.plot_mean_trace_from_mean_df(cdf, analysis.ophys_frame_rate,
                                                    color=sns.color_palette()[0], interval_sec=0.5,
                                                    xlims=analysis.flash_window, ax=ax[i])
            ax[i] = sf.plot_flashes_on_trace(ax[i], analysis, flashes=True, facecolor=color, alpha=0.3)
            # short thick vertical line left of the trace; presumably a 0.1-unit scale bar since the axis is hidden
            ax[i].vlines(x=-0.05, ymin=0, ymax=0.1, linewidth=3)
            ax[i].axis('off')
            ax[i].set_ylim(-0.05, ymax)
            if m == 0:
                ax[i].set_ylabel('x')
            if c == 0:
                ax[i].set_title(image_name)
            # label only the bottom row (previously compared against len(), which is never reached)
            if c == last_row:
                ax[i].set_xlabel('time (s)')
            i += 1

    if save_figures:
        if fig is None:
            # axes were supplied by the caller: save the figure that owns them
            fig = np.ravel(ax)[0].figure
            figsize = tuple(fig.get_size_inches())
        if save_dir:
            sf.save_figure(fig, figsize, save_dir, folder, analysis.dataset.analysis_folder)
        sf.save_figure(fig, figsize, analysis.dataset.analysis_dir, 'example_traces_all_flashes',
                       analysis.dataset.analysis_folder)
def reorder_traces(original_traces, analysis):
    """Return the rows of ``original_traces`` regrouped by preferred change image.

    Mean responses per ('cell', 'change_image_name') are computed from analysis.trials_response_df.
    Images are visited in sorted order. For each image, the cells whose preferred stimulus is that
    image are appended in order of decreasing mean_response. The result stacks
    ``original_traces[cell, :]`` for each cell in that order.

    :param original_traces: array indexed as [cell index, ...] (presumably cells x time — confirm)
    :param analysis: object exposing trials_response_df, also passed to ut.get_mean_df
    :return: numpy array of the reordered rows
    """
    mean_df = ut.get_mean_df(analysis.trials_response_df, analysis, conditions=['cell', 'change_image_name'])

    ordered_cells = []
    for image in np.sort(mean_df.change_image_name.unique()):
        is_preferred = (mean_df.change_image_name == image) & (mean_df.pref_stim == True)  # noqa: E712
        preferred = mean_df[is_preferred]
        descending = np.argsort(preferred.mean_response.values)[::-1]
        ordered_cells.extend(preferred.cell.values[descending])

    return np.asarray([original_traces[cell, :] for cell in ordered_cells])
def plot_across_session_responses(ophys_container_id, cell_specimen_id, use_events=False, save_figure=True):
    """
    Generates plots characterizing single cell activity in response to stimulus, omissions, and changes.
    Compares across all sessions in a container for each cell, including the ROI mask across days.
    Useful to validate cell matching as well as examine changes in activity profiles over days.

    The figure has one column per released experiment in the container (sorted by ophys_experiment_id)
    and these rows:
        0 - zoom of the cell's ROI mask on the max projection
        1 - mean response to each image on non-change flashes
        2 - mean response to the preferred image, split by running (mean_running_speed > 2)
        3 - mean omission response
        4 - mean response on go trials with the preferred change image, split by hit / miss
    A sixth row is allocated but left empty. Columns for experiments that fail to load, or that
    do not contain the cell, stay empty. Errors are printed and the loop continues.

    :param ophys_container_id: container whose experiments are compared
    :param cell_specimen_id: cell to plot in every experiment
    :param use_events: if True, use events instead of dF/F for all response panels
    :param save_figure: if True, save to utils.get_single_cell_plots_dir() and close the figure
    """
    experiments_table = data_loading.get_filtered_ophys_experiment_table(release_data_only=True)
    container_expts = experiments_table[experiments_table.ophys_container_id == ophys_container_id]
    expts = np.sort(container_expts.index.values)
    if use_events:
        ylabel = 'response'
        suffix = '_events'
    else:
        ylabel = 'dF/F'
        suffix = ''

    n = len(expts)
    figsize = (25, 20)
    fig, ax = plt.subplots(6, n, figsize=figsize)
    ax = ax.ravel()
    print('ophys_container_id:', ophys_container_id)

    dataset = None  # most recently loaded dataset; its metadata_string names the saved file
    title = None  # set once the cell has been plotted for at least one experiment
    for i, ophys_experiment_id in enumerate(expts):
        print('ophys_experiment_id:', ophys_experiment_id)
        # flat axis indices of this experiment's column, one per figure row
        ax_roi, ax_images, ax_running, ax_omission, ax_change = [i + n * row for row in range(5)]
        try:
            dataset = data_loading.get_ophys_dataset(ophys_experiment_id, include_invalid_rois=False)
            if cell_specimen_id not in dataset.dff_traces.index:
                continue

            analysis = ResponseAnalysis(dataset, use_events=use_events, use_extended_stimulus_presentations=False)
            sdf = ut.get_mean_df(analysis.get_response_df(df_name='stimulus_response_df'), analysis=analysis,
                                 conditions=['cell_specimen_id', 'is_change', 'image_name'], flashes=True, omitted=False,
                                 get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=True)
            odf = ut.get_mean_df(analysis.get_response_df(df_name='omission_response_df'), analysis=analysis,
                                 conditions=['cell_specimen_id'], flashes=False, omitted=True,
                                 get_reliability=False, get_pref_stim=False, exclude_omitted_from_pref_stim=False)
            tdf = ut.get_mean_df(analysis.get_response_df(df_name='trials_response_df'), analysis=analysis,
                                 conditions=['cell_specimen_id', 'go', 'hit', 'change_image_name'], flashes=False, omitted=False,
                                 get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=True)

            # row 0: ROI mask zoom, to check cell matching across sessions
            ct = dataset.cell_specimen_table.copy()
            cell_roi_id = ct.loc[cell_specimen_id].cell_roi_id
            ax[ax_roi] = sf.plot_cell_zoom(dataset.roi_masks, dataset.max_projection, cell_roi_id,
                                           spacex=20, spacey=20, show_mask=True, ax=ax[ax_roi])
            session_label = container_expts.loc[ophys_experiment_id].session_type[6:]
            ax[ax_roi].set_title(session_label)

            colors = sns.color_palette('hls', 8) + [(0.5, 0.5, 0.5)]

            # row 1: mean response to each image on non-change flashes
            window = rp.get_default_stimulus_response_params()["window_around_timepoint_seconds"]
            cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == False)]  # noqa: E712
            for c, image_name in enumerate(np.sort(cell_data.image_name.unique())):
                ax[ax_images] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.image_name == image_name],
                                                                frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                legend_label=image_name, color=colors[c], interval_sec=0.5,
                                                                xlims=window, ax=ax[ax_images])
            ax[ax_images] = sf.plot_flashes_on_trace(ax[ax_images], analysis, window=window, trial_type=None,
                                                     omitted=False, alpha=0.15, facecolor='gray')
            ax[ax_images].set_title(session_label + '\n image response')

            # row 2: preferred image response split by running. Uses a second ResponseAnalysis with extended
            # stimulus presentations (presumably for mean_running_speed). It honours use_events so the
            # data matches ylabel (it was previously hard-coded to dF/F).
            analysis = ResponseAnalysis(dataset, use_events=use_events, use_extended_stimulus_presentations=True)
            tmp = analysis.get_response_df(df_name='stimulus_response_df')
            tmp['running'] = [True if run_speed > 2 else False for run_speed in tmp.mean_running_speed.values]
            sdf = ut.get_mean_df(tmp, analysis=analysis,
                                 conditions=['cell_specimen_id', 'is_change', 'image_name', 'running'], flashes=True, omitted=False,
                                 get_reliability=False, get_pref_stim=True, exclude_omitted_from_pref_stim=False)

            cell_data = sdf[(sdf.cell_specimen_id == cell_specimen_id) & (sdf.is_change == False) & (sdf.pref_stim == True)]  # noqa: E712
            run_colors = [sns.color_palette()[3], sns.color_palette()[2]]
            for c, running in enumerate(np.sort(cell_data.running.unique())):
                if len(cell_data[cell_data.running == running]) > 0:
                    ax[ax_running] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.running == running],
                                                                     frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                     legend_label=running, color=run_colors[c], interval_sec=0.5,
                                                                     xlims=window, ax=ax[ax_running])
            ax[ax_running].legend(fontsize='xx-small', title='running', title_fontsize='xx-small')
            ax[ax_running] = sf.plot_flashes_on_trace(ax[ax_running], analysis, window=window, trial_type=None,
                                                      omitted=False, alpha=0.15, facecolor='gray')
            ax[ax_running].set_title(session_label + '\n image response')

            # row 3: omission response (label no longer borrows the leaked image_name loop variable)
            window = rp.get_default_omission_response_params()["window_around_timepoint_seconds"]
            cell_data = odf[(odf.cell_specimen_id == cell_specimen_id)]
            ax[ax_omission] = sf.plot_mean_trace_from_mean_df(cell_data,
                                                              frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                              legend_label='omission', color='gray', interval_sec=1,
                                                              xlims=window, ax=ax[ax_omission])
            ax[ax_omission] = sf.plot_flashes_on_trace(ax[ax_omission], analysis, window=window, trial_type=None,
                                                       omitted=True, alpha=0.15, facecolor='gray')
            ax[ax_omission].set_title(session_label + '\n omission response')

            # row 4: go trials with the preferred change image, hit vs miss
            window = rp.get_default_trial_response_params()["window_around_timepoint_seconds"]
            cell_data = tdf[(tdf.cell_specimen_id == cell_specimen_id) & (tdf.go == True) & (tdf.pref_stim == True)]  # noqa: E712
            hit_colors = [sns.color_palette()[2], sns.color_palette()[3]]
            for c, hit in enumerate([True, False]):
                if len(cell_data[cell_data.hit == hit]) > 0:
                    ax[ax_change] = sf.plot_mean_trace_from_mean_df(cell_data[cell_data.hit == hit],
                                                                    frame_rate=analysis.ophys_frame_rate, ylabel=ylabel,
                                                                    legend_label=hit, color=hit_colors[c], interval_sec=1,
                                                                    xlims=window, ax=ax[ax_change])
            ax[ax_change].legend(fontsize='xx-small', title='hit', title_fontsize='xx-small')
            ax[ax_change] = sf.plot_flashes_on_trace(ax[ax_change], analysis, window=window, trial_type='go',
                                                     omitted=False, alpha=0.15, facecolor='gray')
            ax[ax_change].set_title(session_label + '\n change response')

            title = str(cell_specimen_id) + '_' + dataset.metadata_string
        except Exception as e:
            # best effort: one bad experiment should not prevent plotting the others
            print('problem for cell_specimen_id:', cell_specimen_id, ', ophys_experiment_id:', ophys_experiment_id)
            print(e)

    if title is not None:
        fig.tight_layout()
        fig.suptitle(title, x=0.5, y=1.01, horizontalalignment='center')
    if save_figure:
        if dataset is None:
            print('no dataset could be loaded for ophys_container_id:', ophys_container_id, ', figure not saved')
        else:
            save_dir = utils.get_single_cell_plots_dir()
            utils.save_figure(fig, figsize, save_dir, 'across_session_responses', str(
                cell_specimen_id) + '_' + dataset.metadata_string + '_across_session_responses' + suffix)
        plt.close(fig)
def get_multi_session_df(project_code,
                         session_number,
                         conditions,
                         data_type,
                         event_type,
                         time_window=None,
                         interpolate=True,
                         output_sampling_rate=30,
                         response_window_duration=0.5,
                         use_extended_stimulus_presentations=False,
                         overwrite=False):
    """
    Create a mean response dataframe across all experiments of one project and session number, and save it.

    Experiments come from the platform analysis cache's ophys experiment table, excluding
    Ai94(TITL-GCaMP6s). For each experiment:
      - the stimulus_response_df for `data_type` / `event_type` is loaded;
      - condition columns are added;
      - ut.get_mean_df averages responses over `conditions`.
    Results are concatenated and written to an .h5 file (key 'df') named by
    loading.get_file_name_for_multi_session_df. Experiments that raise are printed and skipped.

    :param project_code: project_code of the experiments to include
    :param session_number: session_number of the experiments to include (compared as float)
    :param conditions: columns to group by in ut.get_mean_df. Special entries:
                       'running_state' adds a Boolean 'running' column (mean_running_speed > 2);
                       'pupil_state' drops rows without mean_pupil_area and adds 'large_pupil'
                       (above the median, only when more than 100 rows remain);
                       'pre_change' drops rows where pre_change is null.
                       The preferred stimulus is only computed if 'image_name' or 'change_image_name' is included.
    :param data_type: passed to loading.get_stimulus_response_df and used in the file name
    :param event_type: passed to loading.get_stimulus_response_df. For 'omissions', only omissions with
                       time_from_last_change > 3 are kept
    :param time_window: window around each event in seconds; defaults to [-3, 3.1]
    :param interpolate: passed to the response df loader and used to choose the output directory
    :param output_sampling_rate: passed to the response df loader and used to choose the output directory
    :param response_window_duration: response window in seconds. Used only when the loaded response df
                                     has no 'response_window_duration' column
    :param use_extended_stimulus_presentations: passed to loading.get_ophys_dataset
    :param overwrite: if False and the output file already exists, nothing is regenerated
    :return: the multi session mean df, or None if an existing file was kept
    """
    if time_window is None:  # avoid a shared mutable default
        time_window = [-3, 3.1]

    # cant get preferred stimulus if images are not in the set of conditions
    get_pref_stim = ('image_name' in conditions) or ('change_image_name' in conditions)
    print('get_pref_stim', get_pref_stim)

    cache_dir = loading.get_platform_analysis_cache_dir()
    cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir=cache_dir)
    print(cache_dir)
    experiments_table = cache.get_ophys_experiment_table()
    # dont include Ai94 experiments because they makes things too slow
    experiments_table = experiments_table[(experiments_table.reporter_line != 'Ai94(TITL-GCaMP6s)')]

    session_number = float(session_number)
    experiments = experiments_table[
        (experiments_table.project_code == project_code)
        & (experiments_table.session_number == session_number)].copy()
    print('session_types:', experiments.session_type.unique(),
          ' - there should only be one session_type per session_number')
    session_type = experiments.session_type.unique()[0]

    filename = loading.get_file_name_for_multi_session_df(
        data_type, event_type, project_code, session_type, conditions)
    mega_mdf_write_dir = loading.get_multi_session_df_dir(
        interpolate=interpolate,
        output_sampling_rate=output_sampling_rate,
        event_type=event_type)
    filepath = os.path.join(mega_mdf_write_dir, filename)

    if not overwrite and os.path.exists(filepath):  # keep the existing file
        print('multi_session_df exists for', filepath)
        print('not regenerating')
        print('multi_session_df not created')
        return None

    print('creating multi session mean df for', filename)
    mdfs = []
    for experiment_id in experiments.index.unique():
        try:
            print(experiment_id)
            dataset = loading.get_ophys_dataset(
                experiment_id,
                get_extended_stimulus_presentations=use_extended_stimulus_presentations)
            df = loading.get_stimulus_response_df(
                dataset,
                data_type=data_type,
                event_type=event_type,
                time_window=time_window,
                interpolate=interpolate,
                output_sampling_rate=output_sampling_rate,
                load_from_file=True)
            # use response_window_duration from the stim response df if that column exists.
            # Kept per experiment so one experiment's value does not leak into the next.
            if 'response_window_duration' in df.keys():
                window_duration = df.response_window_duration.values[0]
            else:
                window_duration = response_window_duration
            df['ophys_experiment_id'] = experiment_id
            # if using omissions, only include omissions where time from last change is more than 3 seconds
            if event_type == 'omissions':
                df = df[df.time_from_last_change > 3].copy()
            # modify columns for specific conditions
            if 'passive' in dataset.metadata['session_type']:
                df['lick_on_next_flash'] = False
                df['engaged'] = False
                df['engagement_state'] = 'disengaged'
            if 'running_state' in conditions:
                # NOTE(review): this creates a column named 'running', not 'running_state'; grouping on
                # 'running_state' relies on the loaded df already having that column — confirm
                df['running'] = [
                    True if mean_running_speed > 2 else False
                    for mean_running_speed in df.mean_running_speed.values
                ]
            if 'pupil_state' in conditions:  # Boolean 'large_pupil' column: mean_pupil_area above its median
                if 'mean_pupil_area' in df.keys():
                    df = df[df.mean_pupil_area.notnull()].copy()
                    if len(df) > 100:
                        median_pupil_area = df.mean_pupil_area.median()
                        df['large_pupil'] = [
                            True if mean_pupil_area > median_pupil_area else False
                            for mean_pupil_area in df.mean_pupil_area.values
                        ]
            if 'pre_change' in conditions:
                df = df[df.pre_change.notnull()].copy()
            # frame rate of this experiment's response df. Kept separate from output_sampling_rate so it
            # is not passed to the next experiment's loader call.
            frame_rate = df.frame_rate.unique()[0]

            mdf = ut.get_mean_df(
                df,
                conditions=conditions,
                frame_rate=frame_rate,
                window_around_timepoint_seconds=time_window,
                response_window_duration_seconds=window_duration,
                get_pref_stim=get_pref_stim,
                exclude_omitted_from_pref_stim=True)
            if 'correlation_values' in mdf.keys():
                mdf = mdf.drop(columns=['correlation_values'])
            mdf['ophys_experiment_id'] = experiment_id
            print('mean df created for', experiment_id)
            mdfs.append(mdf)
        except Exception as e:  # best effort: skip experiments that fail
            print(e)
            print('problem for', experiment_id)

    # concatenate once instead of growing a DataFrame inside the loop
    mega_mdf = pd.concat(mdfs) if mdfs else pd.DataFrame()
    if 'level_0' in mega_mdf.keys():
        mega_mdf = mega_mdf.drop(columns='level_0')
    if 'index' in mega_mdf.keys():
        mega_mdf = mega_mdf.drop(columns='index')

    # if file of the same name exists, delete & overwrite to prevent files from getting huge
    if os.path.exists(filepath):
        os.remove(filepath)
    print('saving multi session mean df as ', filename)
    mega_mdf.to_hdf(filepath, key='df')
    print('saved to', mega_mdf_write_dir)

    return mega_mdf
def plot_experiment_summary_figure(analysis, save_dir=None):
    """Plot a one-page summary of an experiment from a ResponseAnalysis object.

    Panels:
      - metadata table;
      - cell masks;
      - sorted traces heatmap;
      - running speed;
      - hit / false alarm rates;
      - lick raster;
      - mean trace heatmaps per behavioral response type (best effort);
      - mean image response heatmap.

    NOTE(review): if the later ``def plot_experiment_summary_figure(experiment_id, save_figure=True)``
    sits in the same module, it rebinds this name and this version is not reachable by name.

    :param analysis: ResponseAnalysis-like object (uses analysis.dataset, use_events, ophys_frame_rate,
                     trials_response_df, trial_window)
    :param save_dir: if given, the figure is saved under save_dir/'experiment_summary_figures'
    """
    use_events = analysis.use_events
    if use_events:
        traces = analysis.dataset.events_array.copy()
        suffix = '_events'
    else:
        traces = analysis.dataset.dff_traces_array.copy()
        suffix = ''

    interval_seconds = 600
    ophys_frame_rate = int(analysis.ophys_frame_rate)

    figsize = [2 * 11, 2 * 8.5]
    fig = plt.figure(figsize=figsize, facecolor='white')

    # metadata table
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.8, 0.95), yspan=(0, .3))
    table_data = format_table_data(analysis.dataset)
    xtable = ax.table(cellText=table_data.values, cellLoc='left', rowLoc='left', loc='center', fontsize=12)
    xtable.scale(1.5, 3)
    ax.axis('off')

    # cell masks; every cell gets the same metric value (-1)
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .22), yspan=(0, .27))
    metrics = np.full(len(analysis.dataset.cell_indices), -1.0)
    cell_list = analysis.dataset.cell_indices
    plot_metrics_mask(analysis.dataset, metrics, cell_list, 'cell masks', max_image=True, cmap='hls', ax=ax, save=False,
                      colorbar=False)
    ax.set_title(analysis.dataset.experiment_id)
    ax.axis('off')

    upper_limit, time_interval, _frame_interval = get_upper_limit_and_intervals(traces,
                                                                                analysis.dataset.ophys_timestamps,
                                                                                analysis.ophys_frame_rate)
    # session duration in seconds, shared by the time-series panels
    total_seconds = upper_limit / ophys_frame_rate
    xmax_seconds = np.uint64(total_seconds)
    session_ticks = np.arange(interval_seconds, total_seconds, interval_seconds)

    # sorted traces heatmap: x axis is in frames, ticks are labelled in seconds
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.22, 0.9), yspan=(0, .3))
    ax = plot_sorted_traces_heatmap(analysis.dataset, analysis, ax=ax, use_events=use_events)
    ax.set_xticks(np.arange(0, upper_limit, interval_seconds * ophys_frame_rate))
    ax.set_xticklabels(np.arange(0, total_seconds, interval_seconds))
    ax.set_xlabel('time (seconds)')
    ax.set_title(analysis.dataset.analysis_folder)

    # running speed
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.22, 0.8), yspan=(.26, .41))
    ax = plot_run_speed(analysis.dataset.running_speed.running_speed, analysis.dataset.stimulus_timestamps, ax=ax,
                        label=True)
    ax.set_xlim(time_interval[0], xmax_seconds)
    ax.set_xticks(session_ticks)
    ax.set_xlabel('time (seconds)')

    # hit / false alarm rates
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.22, 0.8), yspan=(.37, .52))
    ax = plot_hit_false_alarm_rates(analysis.dataset.trials, ax=ax)
    ax.set_xlim(time_interval[0], xmax_seconds)
    ax.set_xticks(session_ticks)
    ax.legend(loc='upper right', ncol=2, borderaxespad=0.)
    ax.set_xlabel('time (seconds)')

    # lick raster
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .22), yspan=(.25, .8))
    ax = plot_lick_raster(analysis.dataset.trials, ax=ax, save_dir=None)

    # mean trace heatmaps per behavioral response type
    ax = placeAxesOnGrid(fig, dim=(1, 4), xspan=(.2, .8), yspan=(.5, .8), wspace=0.35)
    try:
        mdf = ut.get_mean_df(analysis.trials_response_df, analysis,
                             conditions=['cell_specimen_id', 'change_image_name', 'behavioral_response_type'])
        ax = plot_mean_trace_heatmap(mdf, condition='behavioral_response_type',
                                     condition_values=['HIT', 'MISS', 'CR', 'FA'], ax=ax, save_dir=None,
                                     use_events=use_events, window=analysis.trial_window)
    except Exception as e:  # best effort: the rest of the figure is still useful without this panel
        print('could not plot behavioral response type heatmaps:', e)

    # mean image response heatmap
    ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.78, 0.97), yspan=(.3, .8))
    mdf = ut.get_mean_df(analysis.trials_response_df, analysis, conditions=['cell_specimen_id', 'change_image_name'])
    ax = plot_mean_image_response_heatmap(mdf, title=None, ax=ax, save_dir=None, use_events=use_events)

    fig.tight_layout()

    if save_dir:
        save_figure(fig, figsize, save_dir, 'experiment_summary_figures',
                    str(analysis.dataset.experiment_id) + '_experiment_summary' + suffix)
def _get_multi_session_mean_df_filename(df_name, conditions, suffix):
    """Return the .h5 file name for a multi session mean df.

    The first condition (e.g. 'cell_specimen_id') is left out of the name unless it is the only one.
    Raises ValueError for empty `conditions`.
    """
    if len(conditions) == 0:
        raise ValueError('conditions must contain at least one column name')
    name_parts = conditions if len(conditions) == 1 else conditions[1:]
    return 'mean_' + df_name + '_' + '_'.join(name_parts) + suffix + '.h5'


def get_multi_session_mean_df(experiment_ids,
                              analysis_cache_dir,
                              df_name,
                              conditions=None,
                              flashes=False,
                              use_events=False,
                              omitted=False,
                              get_reliability=False,
                              get_pref_stim=True,
                              exclude_omitted_from_pref_stim=True,
                              use_sdk_dataset=True):
    """Compute mean responses for each experiment, concatenate them, and save as an .h5 file.

    For each experiment, the response df `df_name` is loaded from a ResponseAnalysis (extended
    stimulus presentations) and annotated with experiment id, project_code and session_type.
    It is averaged over `conditions` with ut.get_mean_df and merged with the dataset metadata.
    The result is saved to analysis_cache_dir/multi_session_summary_dfs.

    :param experiment_ids: ophys_experiment_ids to include
    :param analysis_cache_dir: cache dir for non-SDK datasets and root of the output directory
    :param df_name: response df to average (passed to analysis.get_response_df)
    :param conditions: grouping columns; defaults to
                       ['cell_specimen_id', 'change_image_name', 'behavioral_response_type'].
                       If 'running' is included, it is computed as window_running_speed > 5.
    :param flashes, omitted, get_reliability, get_pref_stim, exclude_omitted_from_pref_stim:
                       passed to ut.get_mean_df
    :param use_events: use events instead of dF/F; adds '_events' to the file name
    :param use_sdk_dataset: load via loading.get_ophys_dataset instead of VisualBehaviorOphysDataset
    :return: the concatenated mean df
    :raises ValueError: if `conditions` is empty (checked before any data is loaded)
    """
    if conditions is None:  # avoid a shared mutable default
        conditions = ['cell_specimen_id', 'change_image_name', 'behavioral_response_type']
    suffix = '_events' if use_events else ''
    # build the file name first so invalid conditions fail before the expensive loop
    # (previously any length other than 1-5 left `filename` unbound at the very end)
    filename = _get_multi_session_mean_df_filename(df_name, conditions, suffix)

    experiments_table = loading.get_filtered_ophys_experiment_table()
    mdfs = []
    for experiment_id in experiment_ids:
        print(experiment_id)
        if use_sdk_dataset:
            dataset = loading.get_ophys_dataset(experiment_id)
        else:
            dataset = VisualBehaviorOphysDataset(experiment_id,
                                                 analysis_cache_dir)
        analysis = ResponseAnalysis(dataset,
                                    use_events=use_events,
                                    overwrite_analysis_files=False,
                                    use_extended_stimulus_presentations=True)
        df = analysis.get_response_df(df_name)
        df['ophys_experiment_id'] = dataset.ophys_experiment_id
        df['project_code'] = experiments_table.loc[experiment_id].project_code
        df['session_type'] = experiments_table.loc[experiment_id].session_type
        if 'running' in conditions:
            df['running'] = [
                True if window_running_speed > 5 else False
                for window_running_speed in df.window_running_speed.values
            ]
        mdf = ut.get_mean_df(
            df,
            analysis,
            conditions=conditions,
            get_pref_stim=get_pref_stim,
            flashes=flashes,
            omitted=omitted,
            get_reliability=get_reliability,
            exclude_omitted_from_pref_stim=exclude_omitted_from_pref_stim)
        mdf['ophys_experiment_id'] = dataset.ophys_experiment_id
        # Work on a copy so dataset.metadata is not mutated. Take the first entry only when
        # reporter/driver line is list-like; indexing a plain string would keep just its first character.
        metadata_dict = dict(dataset.metadata)
        for key in ('reporter_line', 'driver_line'):
            if isinstance(metadata_dict[key], (list, tuple, np.ndarray)):
                metadata_dict[key] = metadata_dict[key][0]
        metadata = pd.DataFrame(metadata_dict, index=[experiment_id])
        mdf = ut.add_metadata_to_mean_df(mdf, metadata)
        mdfs.append(mdf)

    # concatenate once instead of growing a DataFrame inside the loop
    mega_mdf = pd.concat(mdfs) if mdfs else pd.DataFrame()
    if 'level_0' in mega_mdf.keys():
        mega_mdf = mega_mdf.drop(columns='level_0')
    if 'index' in mega_mdf.keys():
        mega_mdf = mega_mdf.drop(columns='index')

    mega_mdf_write_dir = os.path.join(analysis_cache_dir, 'multi_session_summary_dfs')
    os.makedirs(mega_mdf_write_dir, exist_ok=True)

    print('saving multi session mean df to ', filename)
    mega_mdf.to_hdf(os.path.join(mega_mdf_write_dir, filename), key='df')
    print('saved')
    return mega_mdf
def plot_experiment_summary_figure(experiment_id, save_figure=True):
    """Plot the QC summary figure for one ophys experiment.

    Loads the dataset, builds a ResponseAnalysis on events, and fills the make_fig_ax() panels with:
      - max projection, segmentation mask outlines and valid/invalid ROI overlay;
      - motion correction and population average;
      - average images and decrosstalk masks (best effort);
      - behavior timeseries;
      - cell SNR distribution and traces heatmap;
      - population-average change and omission responses;
      - run speed around changes and omissions;
      - example high/low SNR traces.

    NOTE(review): if the earlier ``plot_experiment_summary_figure(analysis, save_dir=None)`` sits in the
    same module, this definition rebinds that name.

    :param experiment_id: ophys_experiment_id to plot
    :param save_figure: if True, save to loading.get_experiment_plots_dir() under 'experiment_summary_figure'
    """
    population_xlim = (-2.5, 2.8)  # seconds around the event for all population-average panels

    dataset = loading.get_ophys_dataset(experiment_id)
    analysis = ra.ResponseAnalysis(dataset, use_events=True)

    fig, ax, figsize = make_fig_ax()
    ax['0_0'] = ep.plot_max_intensity_projection_for_experiment(experiment_id,
                                                                ax=ax['0_0'])
    ax['0_0'].set_title('max projection')
    ax['0_1'] = ep.plot_valid_segmentation_mask_outlines_per_cell_for_experiment(
        experiment_id, ax=ax['0_1'])
    ax['0_2'] = ep.plot_valid_and_invalid_segmentation_mask_overlay_per_cell_for_experiment(
        experiment_id, ax=ax['0_2'])
    ax['0_2'].set_title('red = valid ROIs, blue = invalid ROIs')
    ax['0_3:'] = ep.plot_motion_correction_and_population_average(
        experiment_id, ax=ax['0_3:'])

    # NOTE(review): '1_0' and '1_1' show the same average image — confirm whether '1_1' was meant to differ
    ax['1_0'] = ep.plot_average_image_for_experiment(experiment_id,
                                                     ax=ax['1_0'])
    ax['1_1'] = ep.plot_average_image_for_experiment(experiment_id,
                                                     ax=ax['1_1'])
    try:
        ax['1_2'] = ep.plot_remaining_decrosstalk_masks_for_experiment(
            experiment_id, ax=ax['1_2'])
    except Exception:  # not every experiment has decrosstalk outputs; don't swallow KeyboardInterrupt
        print('no decrosstalk for experiment', experiment_id)
    ax['1_3:'] = ep.plot_behavior_timeseries_for_experiment(experiment_id,
                                                            ax=ax['1_3:'])

    ax['2_2'] = ep.plot_cell_snr_distribution_for_experiment(experiment_id,
                                                             ax=ax['2_2'])
    ax['2_3:'] = ep.plot_traces_heatmap_for_experiment(experiment_id,
                                                       ax=ax['2_3:'])

    # population average response to changes, split by go
    df_name = 'trials_response_df'
    df = analysis.get_response_df(df_name)
    mean_df = ut.get_mean_df(df,
                             analysis=analysis,
                             conditions=['cell_specimen_id', 'go'],
                             flashes=False,
                             omitted=False,
                             get_reliability=False,
                             get_pref_stim=False,
                             exclude_omitted_from_pref_stim=True)
    ax['3_0'] = ep.plot_population_average_for_experiment(experiment_id,
                                                          df,
                                                          mean_df,
                                                          df_name,
                                                          color=None,
                                                          label=None,
                                                          ax=ax['3_0'])
    ax['3_0'].set_xlim(*population_xlim)

    # population average response to omissions
    df_name = 'omission_response_df'
    df = analysis.get_response_df(df_name)
    mean_df = ut.get_mean_df(df,
                             analysis=analysis,
                             conditions=['cell_specimen_id'],
                             flashes=False,
                             omitted=True,
                             get_reliability=False,
                             get_pref_stim=False,
                             exclude_omitted_from_pref_stim=False)
    ax['3_1'] = ep.plot_population_average_for_experiment(experiment_id,
                                                          df,
                                                          mean_df,
                                                          df_name,
                                                          color=None,
                                                          label=None,
                                                          ax=ax['3_1'])
    ax['3_1'].set_xlim(*population_xlim)

    # Run speed panels are plotted straight from the response df (trace_type='trace').
    # The mean dfs previously computed here were never used, so they are no longer built.
    for df_name, panel in [('trials_run_speed_df', '4_0'), ('omission_run_speed_df', '4_1')]:
        df = analysis.get_response_df(df_name)
        df['condition'] = True
        ax[panel] = ep.plot_population_average_for_experiment(
            experiment_id,
            df,
            df,
            df_name,
            trace_type='trace',
            color=sns.color_palette()[4],
            label=None,
            ax=ax[panel])
        ax[panel].set_ylabel('run speed (cm/s)')
        ax[panel].set_xlim(*population_xlim)

    xlim_seconds = [int(10 * 60), int(15 * 60)]
    ax['3_3:'] = ep.plot_high_low_snr_trace_examples(experiment_id,
                                                     xlim_seconds=xlim_seconds,
                                                     ax=ax['3_3:'])

    ax['4_3:'] = ep.plot_behavior_timeseries_for_experiment(
        experiment_id, xlim_seconds=xlim_seconds, ax=ax['4_3:'])

    fig.tight_layout()
    title = dataset.metadata_string
    plt.suptitle(title, x=0.5, y=.91, fontsize=20)
    if save_figure:
        save_dir = loading.get_experiment_plots_dir()
        utils.save_figure(fig, figsize, save_dir, 'experiment_summary_figure',
                          title)