def plot_first_response(dest_dir_appdx='../',
                        which_paradigms=['DAC1', 'DAC2']):
    """
    Plot a comparison between the first responses in the different channels.

    For every mouse in `const.ALL_MICE` and every paradigm of that mouse that
    is listed in `which_paradigms`, a two-row figure is drawn:

    * top row:    `Peak_neg_ts` taken from the LFP summary
    * bottom row: `AvgTimetoFirstSpike` taken from the MUA summary

    Only the last 10 rows of each summary are plotted (presumably the
    thalamic channels - verify against `fetch()`), plus the value of the
    row labelled `G` as a reference. Channels whose time stamps are closer
    than 5 to their predecessor are drawn with alternating / growing
    vertical offsets so that the labels do not overlap. In the MUA row, a
    value of exactly 0 is treated as "no spike" and listed in the title.

    These plots are used in the panel for defining thalamic channels, see
    `thalamic_mapping_panel()`.

    Parameters
    ----------
    dest_dir_appdx : str
        Sub directory appended to `const.P["outputPath"]`; the plots are
        saved as
        `P["outputPath"]/dest_dir_appdx/<mouse>_<parad>_<stim>_first_ts.png`.
    which_paradigms : list of str
        Paradigms to plot, defaults to [`DAC1`, `DAC2`]. To plot all
        paradigms pass `const.ALL_PARADIGMS`. The list is only read, never
        modified.

    Raises
    ------
    SystemExit
        With exit code 1 if a mouse is missing in `const.PARAD_ORDER`.
    """
    data = fetch(paradigms=which_paradigms,
                 collapse_ctx_chnls=True,
                 collapse_th_chnls=False,
                 drop_not_assigned_chnls=False)

    for mouse in const.ALL_MICE:
        if mouse not in const.PARAD_ORDER:
            print(
                f'`{mouse}` not found in constant PARAD_ORDER (MUA_constants.py). Check.'
            )
            # same effect as exit(1), but does not rely on the `site` module
            raise SystemExit(1)
        for parad in const.PARAD_ORDER[mouse]:
            if parad not in which_paradigms:
                continue

            # stimulus type whose summary is used for this paradigm
            if parad not in ['MS', 'DOC1', 'DOC2']:
                peak_stim = 'Deviant'
            else:
                peak_stim = 'C1' if parad == 'MS' else 'Standard'
            fig, axes = plt.subplots(nrows=2, figsize=(10, 3.4))
            fig.subplots_adjust(right=.975,
                                top=.94,
                                left=.02,
                                bottom=.13,
                                hspace=.15)

            for which_ax, ax in enumerate(axes):
                # get the data (lfp or mua time stamp)
                if which_ax == 0:
                    lfp_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                          lfp_summary=True,
                                          drop_labels=True)[0]
                    first_ts = lfp_summ.loc[:, 'Peak_neg_ts'].iloc[
                        -10:].sort_values()
                    first_ts_G = lfp_summ.loc['G', 'Peak_neg_ts']
                elif which_ax == 1:
                    mua_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                          mua_summary=True,
                                          drop_labels=True)[0]
                    first_ts = mua_summ.loc[:, 'AvgTimetoFirstSpike'].iloc[
                        -10:].sort_values()
                    first_ts_G = mua_summ.loc['G', 'AvgTimetoFirstSpike']

                # general axis cosmetics
                ax.tick_params(left=False, labelleft=False)
                if which_ax == 0:
                    ax.tick_params(bottom=False, labelbottom=False)
                for sp in ax.spines.values():
                    sp.set_visible(False)
                ax.patch.set_facecolor('grey')
                ax.patch.set_alpha(.16)
                ax.hlines((0), -10, 200, color='grey')

                # title, including channels without a usable time stamp
                tit = 'Peak negative LFP time stamp' if which_ax == 0 else 'Avg first spike time stamp'
                nan_chnls = [
                    chnl + 1 for chnl in first_ts.index[first_ts.isna()]
                ]
                if nan_chnls:
                    tit = f'{tit}      (NA channels: {nan_chnls})'
                if which_ax == 1:
                    # a time stamp of exactly 0 means that no spike was found;
                    # blank it so that it is not drawn at t=0
                    zero_chnls = (first_ts.index[first_ts == 0]).values
                    first_ts.loc[zero_chnls] = np.nan
                    tit = f'{tit}      (no spike channels: {list(zero_chnls+1)})'
                ax.set_title(tit, loc='left', pad=2, size=9)

                ax.set_ylim((-1.5, 1.5))
                ax.set_xlim((-10, 200))
                xts = (0, 4, 8, 12, 16, 20, 30, 40, 60, 80, 100, 150, 200)
                ax.set_xticks(xts)
                ax.xaxis.grid(True, which='major')
                if which_ax == 1:
                    ax.set_xlabel('[ms] post stimulus', labelpad=2, size=9)

                # compute the y position of each channel label: time stamps
                # that lie closer than 5 to the previous (sorted) one flip
                # sides and move further out if that height is already used
                step = .4
                ycoords = []
                close_coords = []
                for pos in range(len(first_ts)):
                    if pos and first_ts.iloc[pos] - first_ts.iloc[pos - 1] < 5:
                        step *= -1
                        if step in close_coords:
                            step = step + .4 if step > 0 else step - .4
                        if len(close_coords) == 0:
                            close_coords.append(step * -1)
                        close_coords.append(step)
                    else:
                        step = .4
                        close_coords = []
                    ycoords.append(step)

                # stems from the zero line up/down to the label position
                for y, x in zip(ycoords, first_ts):
                    ax.vlines(x,
                              0,
                              y,
                              linewidth=1,
                              color=const.GENERAL_CMAP['Th'],
                              alpha=.7)
                # channel labels (index + 1); Series.items() replaces the
                # iteritems() method that was removed in pandas 2.0
                for y, (chnl, x) in zip(ycoords, first_ts.items()):
                    ax.annotate(chnl + 1, (x, y),
                                va='center',
                                ha='center',
                                zorder=20)

                # reference marker for the row labelled `G`
                ax.vlines(first_ts_G,
                          -.4,
                          .4,
                          linewidth=1,
                          color=const.GENERAL_CMAP['G'],
                          alpha=.7)
                ax.annotate('G', (first_ts_G, 0),
                            va='center',
                            ha='center',
                            zorder=20,
                            size=12)
            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/{mouse}_{parad}_{peak_stim}_first_ts.png'
            fig.savefig(f)
            # release the figure, otherwise one figure per mouse/paradigm
            # stays registered in pyplot for the lifetime of the process
            plt.close(fig)
def thalamic_mapping_panel(dest_dir_appdx='../',
                           which_paradigms=['DAC1', 'DAC2'],
                           anatomy_dir=None,
                           ts_plots_dir=None):
    """
    Create a panel for mapping the thalamic channels, similar in style to the
    one produced in cortical_mapping_panel().

    The panel shows the LFP responses of the last 10 channels (one axis per
    channel), the LFP of the first 4 rows (the collapsed cortical regions)
    in one axis, a firing rate heatmap of channels 23-32 and the current
    region assignment of these channels. If the first `SG` channel of the
    cortical map differs from the one in the preceding column of the
    mapping file, the assignment plot is annotated with that shift.

    Parameters
    ----------
    dest_dir_appdx : str
        Sub directory appended to `const.P["outputPath"]`; plots are saved as
        `P["outputPath"]/dest_dir_appdx/<mouse>_<paradigm index>_<parad>.png`.
    which_paradigms : list of str
        Paradigms being plotted, by default DAC1, DAC2. Only read.
    anatomy_dir : str, optional
        Directory with files of the form `<mouse>.png`, eg. `mGE82.png`,
        showing the injection slice.
    ts_plots_dir : str, optional
        Output directory of `plot_first_response()`; files named
        `<mouse>_<parad>_<stim>_first_ts.png` are read from there.

    Note that both `anatomy_dir` and `ts_plots_dir` need to be passed for
    either of them to be drawn, even if you are not interested in one of
    them. In that case the saved panel is overwritten with a composite image
    (panel upper left, anatomy upper right, time stamp plot below the
    anatomy).

    The function expects the cortical mapping to be done. The same file as
    the one used for cortical mapping is referenced:
    P["outputPath"]/../chnls_map.csv. Subsequent scripts expect `VPM` as the
    label of this thalamic mapping.

    Raises
    ------
    SystemExit
        With exit code 1 if a mouse is missing in `const.PARAD_ORDER`.
    """
    def make_plot(lfp, frates, lfp_summ, mua_summ, mouse, parad, stim_t):
        """Draw the panel for one mouse/paradigm and return the figure.

        `lfp_summ` and `mua_summ` are currently not used for drawing; they
        are kept in the signature so that time stamp markers can be added
        to the LFP axes without changing the caller. `mapping` is read from
        the enclosing scope.
        """
        # last 10 channels / first 4 rows, last 50 samples dropped, scaled by
        # 1e6 (labelled [uV] below)
        thal_lfp = lfp.iloc[-10:, :-50] * 1_000_000
        ctx_lfp = lfp.iloc[:4, :-50] * 1_000_000
        frates = frates.loc[range(22, 32), :]
        x_time = lfp.columns[:-50].values.astype(float)

        # init figure
        fig = plt.figure(figsize=(14, 10.2))
        gs = fig.add_gridspec(
            7,
            3,
            width_ratios=[.5, .35, .15],
            height_ratios=[.25, .025, .15, .15, .15, .15, .15],
            hspace=0,
            wspace=.12,
            right=.975,
            top=.95,
            left=.02,
            bottom=.05)
        lfp_ax_ctx = fig.add_subplot(gs[0, 0])
        frate_ax = fig.add_subplot(gs[0, 1])
        assig_ax = fig.add_subplot(gs[0, 2])
        lfp_axl = [fig.add_subplot(gs[i, 0]) for i in range(2, 7)]
        lfp_axr = [fig.add_subplot(gs[i, 1:]) for i in range(2, 7)]
        all_axs = lfp_axl + lfp_axr + [lfp_ax_ctx, frate_ax, assig_ax]
        lfp_axs = lfp_axl + lfp_axr + [
            lfp_ax_ctx
        ]  # 0,1,2,3,4 left plots, 5,6,7,8,9 right plots, 10 ctx plot
        # clean axes
        for ax in all_axs:
            ax.tick_params(left=False,
                           bottom=False,
                           labelleft=False,
                           labelbottom=False)

        # iterate bottom plots block and top left (all lfp's), setup axis
        for which_ax, ax in enumerate(lfp_axs):
            # general
            for sp in ax.spines.values():
                sp.set_visible(False)
            ax.patch.set_facecolor('grey')
            ax.patch.set_alpha(.16)
            ax.hlines(0, -5, 20, color='grey')

            # x axis
            xleft, xright = -0.05, 0.2
            ax.xaxis.grid(True, which='major')
            ax.set_xlim((xleft, xright))
            xticks = (xleft, 0, 0.05, 0.1, 0.15, xright)
            ax.set_xticks(xticks)
            # bottom plots
            if which_ax in (4, 9):
                ax.tick_params(labelbottom=True)
                ax.set_xticklabels(xticks, size=8)
                ax.set_xlabel('[s]', size=8)

            # y axis
            if which_ax != 10:
                ybot, ytop = -25, 25
            else:  # ctx plot
                ybot, ytop = -180, 90
            ax.yaxis.grid(True, which='major')
            ax.set_ylim((ybot, ytop))
            # ticks every 15 units on both sides of 0
            yts = np.concatenate(
                [np.arange(0, ybot, -15)[1:],
                 np.arange(0, ytop, 15)])
            ax.set_yticks(yts)
            # right plots
            if which_ax in (5, 6, 7, 8, 9):
                ax.tick_params(labelright=True)
                ax.set_yticklabels((-15, 0, 15), size=8)
            # ctx plots
            elif which_ax == 10:
                ax.tick_params(labelright=True)
                ax.set_yticklabels(
                    [yt if yt in (-150, -75, 0, 60) else None for yt in yts],
                    size=8)
            # ctx plot and middle, right plot
            if which_ax in (7, 10):
                ax.annotate('[uV]', (.263, 0), size=8, annotation_clip=False)
            if which_ax == 10:
                ax.set_title(f'{mouse}-{parad}-{stim_t}')

        # draw lfp's, one thalamic channel per axis
        for (chnl, dat), ax in zip(thal_lfp.iterrows(), lfp_axl + lfp_axr):
            ax.set_ylabel(chnl + 1,
                          fontsize=10,
                          rotation=0,
                          ha='right',
                          va='center',
                          labelpad=2)
            ax.plot(x_time,
                    dat.values,
                    clip_on=False,
                    color=const.REGION_CMAP['Th'])
        # all cortical regions share the top left axis
        for region, region_lfp in ctx_lfp.iterrows():
            lfp_ax_ctx.plot(x_time,
                            region_lfp,
                            label=region,
                            color=const.REGION_CMAP[region],
                            clip_on=False)
        lfp_ax_ctx.legend()

        # setup the assignment plot
        assig_ax.tick_params(left=True, labelleft=True, pad=9.5)
        assig_ax.set_xlim((0, 1))
        assig_ax.set_ylim((32.5, 22.5))
        assig_ax.set_yticks(np.arange(23, 33))

        # somewhat cryptic... this checks if the cortical map in the previous
        # paradigm differs from the current one. If True, the current plot
        # gets annotated with the change in cortical mapping.
        # DataFrame.items() replaces iteritems(), removed in pandas 2.0.
        first_chnls = pd.Series({
            mouse_parad: mapping.index[this_map == 'SG'][0]
            for mouse_parad, this_map in mapping.items()
        })
        mouse_parad = f'{mouse}-{parad}'
        mouse_first_chnls = first_chnls[[
            entr for entr in first_chnls.index if mouse in entr
        ]]

        current_first_chnl = mouse_first_chnls.loc[mouse_parad]
        if parad != const.PARAD_ORDER[mouse][0]:  # if not the first paradigm
            # explicit positional access to the preceding entry; integer
            # keys on a string-labelled Series via [] are deprecated
            before_first_chnl = mouse_first_chnls.iloc[
                mouse_first_chnls.index.get_loc(mouse_parad) - 1]
            moved_by = before_first_chnl - current_first_chnl
            if moved_by:
                up_dwn = 'up' if moved_by > 0 else 'down'
                ann = f'Cortical map moved\n{up_dwn} by {abs(moved_by)}!'
                assig_ax.annotate(ann, (.1, 22),
                                  annotation_clip=False,
                                  fontsize=13)

        th_mapping = mapping.iloc[-10:, :].loc[:, [mouse_parad]]
        cols = [const.REGION_CMAP[region[0]] for region in th_mapping.values]

        # define how each layer is colored, label
        assig_ax.barh(np.arange(23, 33),
                      1,
                      height=1,
                      edgecolor='k',
                      color=cols,
                      alpha=.6)
        for i in range(len(th_mapping)):
            assig_ax.annotate(th_mapping.values[i][0], (.2, i + 23 + .2),
                              fontsize=11)

        frate_ax.tick_params(top=True, labeltop=True, right=True)
        frate_ax.xaxis.set_label_position('top')
        frate_ax.set_xticks((-.05, 0, .1, .2))
        frate_ax.set_xticklabels((-.05, 0, .1, .2), size=8)
        # NOTE(review): the ticks/extent of this axis (-.05 .. .2) match the
        # LFP axes that are labelled '[s]'; the '[ms]' label looks wrong -
        # confirm the unit before changing the label.
        frate_ax.set_xlabel('[ms]', size=8)

        frate_ax.set_yticks(np.arange(23, 33))
        frate_ax.set_ylim((22.5, 32.5))

        frate_ax.imshow(frates,
                        cmap='gnuplot',
                        aspect='auto',
                        extent=[-.0525, 0.2025, 22.5, 32.5],
                        vmin=0,
                        vmax=15)
        frate_ax.vlines([-.002], 22.5, 33.5, color='w', alpha=.6)
        return fig

    data = fetch(paradigms=which_paradigms,
                 collapse_ctx_chnls=True,
                 collapse_th_chnls=False,
                 drop_not_assigned_chnls=False)
    # get the current assignment
    mapping = pd.read_csv(f'{const.P["outputPath"]}/../chnls_map.csv',
                          index_col=0).reset_index(drop=True)

    for mouse in const.ALL_MICE:
        if mouse not in const.PARAD_ORDER:
            print(
                f'`{mouse}` not found in constant PARAD_ORDER (MUA_constants.py). Check.'
            )
            # same effect as exit(1), but does not rely on the `site` module
            raise SystemExit(1)
        for i, parad in enumerate(const.PARAD_ORDER[mouse]):
            if parad not in which_paradigms:
                continue

            # stimulus type that is plotted for this paradigm
            if parad not in ['MS', 'DOC1', 'DOC2']:
                peak_stim = 'Deviant'
            else:
                peak_stim = 'C1' if parad == 'MS' else 'Standard'

            frates = slice_data(data, [mouse], [parad], [peak_stim],
                                firingrate=True,
                                frate_noise_subtraction='paradigm_wise',
                                drop_labels=True)[0]

            lfp = slice_data(data, [mouse], [parad], [peak_stim],
                             lfp=True,
                             drop_labels=True)[0]
            lfp_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                  lfp_summary=True,
                                  drop_labels=True)[0]
            mua_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                  mua_summary=True,
                                  drop_labels=True)[0]

            fig = make_plot(lfp, frates, lfp_summ, mua_summ, mouse, parad,
                            peak_stim)
            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/{mouse}_{i+1}_{parad}.png'
            fig.savefig(f)
            # release the figure, otherwise one large figure per
            # mouse/paradigm stays registered in pyplot
            plt.close(fig)

            if anatomy_dir and ts_plots_dir:
                anat_file = f'{anatomy_dir}/{mouse}.png'
                ts_file = f'{ts_plots_dir}/{mouse}_{parad}_{peak_stim}_first_ts.png'
                # the source images are closed again before the composite
                # overwrites the panel file
                with Image.open(f) as plot, \
                        Image.open(anat_file) as anat, \
                        Image.open(ts_file) as ts_plot:
                    final_img = Image.new(
                        'RGB',
                        (plot.width + anat.width,
                         anat.height + ts_plot.height),
                        color='white')
                    final_img.paste(plot, (0, 0))  # upper left
                    final_img.paste(anat, (plot.width, 0))  # upper right
                    final_img.paste(ts_plot,
                                    (plot.width, anat.height))  # lower right
                final_img.save(f)
示例#3
0
def onset_offset_response(plots_dest_dir_appdx='',
                          csv_dest_dir_appdx='',
                          generate_plots=True,
                          single_channels=True,
                          draw_gmm_fit=True):
    """Start of the onset-offset analysis pipeline: bin early spike times.

    Iterates all mice, all paradigms and all stimulus types, fetches the
    dataframe with the negative spike time stamps and histograms them per
    channel (or region). The 200 bin counts of each channel/region serve as
    the input features for the classifier.

    Binning as implemented: the time stamps are histogrammed into 2000 bins
    over (-50, 200), i.e. 0.125 wide bins, and bins 400-600 are kept, which
    corresponds to the interval 0-25. Time stamps of exactly 0 are ignored.
    If a recording does not have 200 rows (trials), the counts are rescaled
    to 200 rows and truncated to integers.
    NOTE(review): earlier documentation described this as "0-20ms, 0.1ms
    bins"; that would require 2500 bins and the slice 500-700. The code is
    left unchanged because downstream features depend on it - confirm which
    one is intended.

    Parameters
    ----------
    plots_dest_dir_appdx : str
        Plots are saved at
        P["outputPath"]/plots_dest_dir_appdx/onset_offset_*_*.png.
    csv_dest_dir_appdx : str
        The bin counts are saved at
        P["outputPath"]/csv_dest_dir_appdx/onset_offset_spikebins_*.csv.
        If this is falsy (None or the default empty string) no CSV is
        written.
    generate_plots : bool
        Draw the histograms as one heatmap row per channel/region.
    single_channels : bool
        When True (default), the 32 channels are not collapsed into the
        previously defined region mapping. The classifier is designed to
        classify single channels, not collapsed regions. Setting this to
        False is implemented for plotting region histograms.
    draw_gmm_fit : bool
        Fit a Bayesian Gaussian Mixture Model with 10 components to the
        time stamps between 0 and 20 and draw its density on the heatmap.
        Only done if more than 15 bins hold more than one spike.

    Returns
    -------
    pandas.DataFrame
        One row per `<mouse>-<paradigm>-<stimtype>-<channel/region>`, 200
        columns holding the bin counts. Channels 1-32 that are missing in
        the spike data are included as all-zero rows.
    """
    # get all the available data from the output dir
    if single_channels:
        data = fetch()
        which_region = 'channels'
        plt_spacers = {
            'hspace': 0,
            'right': .97,
            'top': .96,
            'left': .1,
            'bottom': .07
        }
    else:
        data = fetch(collapse_ctx_chnls=True,
                     collapse_th_chnls=True,
                     drop_not_assigned_chnls=True)
        which_region = 'regions'
        plt_spacers = {
            'hspace': 0,
            'right': .97,
            'top': .85,
            'left': .1,
            'bottom': .22
        }

    # due to different base level actitvity, set heatmap vmax seperately
    vmaxs = {
        'mGE82': 20,
        'mGE83': 40,
        'mGE84': 40,
        'mGE85': 40,
    }

    # iter the usual dimensions
    nspikes = 0
    all_spike_bins = []
    for m_id in const.ALL_MICE:
        for parad in const.ALL_PARADIGMS:
            for stim_t in const.ALL_STIMTYPES:
                key = '-'.join([m_id, parad, stim_t])
                # not all combinations of paradigm/stimtype exist
                if key not in data.keys():
                    continue

                # get the negative spike time stamps (a MultiIndex DataFrame)
                spikes = slice_data(data,
                                    m_id,
                                    parad,
                                    stim_t,
                                    neg_spikes=True,
                                    drop_labels=True)[0].sort_index(axis=1)
                nspikes += np.count_nonzero(spikes)
                if not single_channels:
                    spikes = spikes.reindex(['SG', 'G', 'IG', 'dIG', 'VPM'],
                                            axis=1,
                                            level=0)
                # channels / regions present in this recording
                regions = spikes.columns.unique(0)

                if generate_plots:
                    # init plot with fitting size (height depends on n channels/ regions)
                    nregions = len(regions)
                    # squeeze=False always yields an array, also for a
                    # single row where plt.subplots would otherwise return a
                    # bare Axes that has neither .flatten() nor [0]
                    fig, axes = plt.subplots(nrows=nregions,
                                             figsize=(6, .4 * nregions),
                                             squeeze=False)
                    axes = axes.flatten()
                    fig.subplots_adjust(**plt_spacers)
                    for ax in axes:
                        ax.tick_params(bottom=False,
                                       left=False,
                                       labelbottom=False,
                                       labelleft=False)

                    axes[0].set_title(key)
                else:
                    # dummy, only needed to zip over the regions below
                    axes = range(len(regions))

                # add empty channels as well so that the total samples is always consistent
                no_spike_chnls = [
                    chnl for chnl in range(1, 33) if chnl not in regions
                ]
                for region in no_spike_chnls:
                    lbl = key + f'-{region:0>2d}' if single_channels else key + f'-{region}'
                    all_spike_bins.append(
                        pd.Series(np.zeros(200, dtype=int), name=lbl))

                for region, ax in zip(regions, axes):
                    # get the spike timestamp data sliced to the channel/ region
                    region_spikes = spikes[region].values.flatten()
                    # set 0 to nan -> ignored by np.histogram
                    region_spikes[region_spikes == 0] = np.nan
                    hist = np.histogram(region_spikes,
                                        bins=2000,
                                        range=(-50, 200))

                    # keep bins 400-600 of the 2000 bins from -50-200
                    # (bin edges 0 .. 25, see docstring)
                    start, stop = 400, 600
                    spike_bins = hist[0][np.newaxis, start:stop]

                    if spikes.shape[0] != 200:
                        # norm spike counts to 200 trials
                        spike_bins = spike_bins / spikes.shape[0]
                        spike_bins = (spike_bins * 200).astype(int)

                    lbl = key + f'-{region:0>2d}' if single_channels else key + f'-{region}'
                    all_spike_bins.append(pd.Series(spike_bins[0], name=lbl))

                    if not generate_plots:
                        continue

                    # draw the heatmap, setup axis
                    ax.imshow(spike_bins,
                              aspect='auto',
                              extent=(hist[1][start], hist[1][stop], 0, 1),
                              vmin=0,
                              vmax=vmaxs[m_id])
                    ax.set_ylabel(region, rotation=0, va='center', labelpad=20)
                    ax.set_ylim(0, .4)

                    xt = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
                    ax.set_xticks(xt)
                    ax.set_xlim(xt[0], xt[-1])

                    if draw_gmm_fit and (spike_bins > 1).sum() > 15:
                        model = bgmm(10,
                                     covariance_type='diag',
                                     random_state=1,
                                     mean_prior=(8, ),
                                     covariance_prior=(.1, ),
                                     degrees_of_freedom_prior=10)

                        # fit only on time stamps within (0, 20)
                        post_stim = np.logical_and(region_spikes > 0,
                                                   region_spikes < 20)
                        model.fit(region_spikes[post_stim, np.newaxis])

                        x = np.linspace(-50, 200, 2000).reshape(2000, 1)
                        logprob = model.score_samples(x)
                        pdf = np.exp(logprob)
                        ax.plot(x, pdf, '-w', linewidth=.8)

                    # the last plot gets a red stimulus indication bar
                    if region == regions[-1]:
                        ax.tick_params(bottom=True, labelbottom=True)
                        ax.set_xlabel('[ms]')
                        ax.hlines(0,
                                  0,
                                  8,
                                  clip_on=False,
                                  linewidth=6,
                                  color='r')
                        ax.annotate('Stimulus', (2.3, -.6),
                                    color='r',
                                    fontsize=15,
                                    annotation_clip=False)

                if generate_plots:
                    f = (f'{const.P["outputPath"]}/{plots_dest_dir_appdx}/'
                         f'onset_offset_{key}_{which_region}.png')
                    fig.savefig(f)
                    # close exactly this figure rather than whichever figure
                    # happens to be current
                    plt.close(fig)

    on_off_scores = pd.concat(all_spike_bins, axis=1).T
    if csv_dest_dir_appdx:
        f = f'{const.P["outputPath"]}/{csv_dest_dir_appdx}/onset_offset_spikebins_{which_region}.csv'
        on_off_scores.to_csv(f)
    print(f'\nSpikes found: {nspikes:,}')
    return on_off_scores
示例#4
0
def ssa_correlation(dest_dir_appdx,
                    which='O10',
                    start=5,
                    stop=20,
                    post_stim=False):
    """Plot the correlation of SSA indices between regions.

    Each mouse has two SIs, one for C1, one for C2. With 4 mice, we get a
    vector of 8 datapoints per region that is correlated between regions.
    If one of those 8 is NA, it is excluded from the correlation. More than
    4 values must be present for the correlation of two regions, otherwise
    the p-value of that pair is reported as `NaN`; regions with 4 or fewer
    values in total are blanked before the correlation matrix is computed.

    A heatmap is drawn showing Pearson's correlation coefficient, annotated
    with the p-value of a linear regression for each pair. To check if the
    correlation looks reasonable, for each ordered combination of regions a
    scatter plot with the SIs and the fitted regression line is saved.

    Parameters
    ----------
    dest_dir_appdx : str
        Plots are saved at P["outputPath"]/dest_dir_appdx/SSA_corr_*.
    which : str
        The paradigm, e.g. `O10`. If `MS` is passed, the paradigms `O25C1`,
        `O25C2` and `MS` are fetched and `compute_si()` is called with
        `MS=True`.
    start, stop : int
        Interval post stimulus passed on to `compute_si()`.
    post_stim : bool or number
        If truthy, late responses (hard fixed to late_start=100 and
        late_stop=200) are included as additional `<region>_lateSI` columns,
        extending the regions domain from G, IG, ... to G, IG, ...,
        G_lateSI, IG_lateSI, .... Often for late responses the cut off
        response constant should be lowered. Therefore, post_stim can also
        be a number that is subtracted from `const.SI_MIN_FRATE_5MS` while
        the late SIs are computed; the constant is restored afterwards.
    """
    if which == 'MS':
        paradigms = ['O25C1', 'O25C2', 'MS']
    else:
        paradigms = [which + 'C1', which + 'C2', 'MS']
    data = fetch(
        mouseids=['mGE82', 'mGE83', 'mGE84', 'mGE85'],
        paradigms=paradigms,
        stim_types=['Deviant', 'Predeviant', 'UniquePredeviant', 'C1', 'C2'],
        collapse_ctx_chnls=True,
        collapse_th_chnls=True,
        drop_not_assigned_chnls=True)

    SIs, frates, _ = compute_si(data, MS=which == 'MS', start=start, stop=stop)

    if post_stim:
        late_start = 100
        late_stop = 200
        # The threshold is a module level constant. Lower it only for the
        # late computation and always restore it; otherwise every call
        # would permanently lower it for all later analyses.
        original_min_frate = const.SI_MIN_FRATE_5MS
        print(const.SI_MIN_FRATE_5MS)
        try:
            if post_stim is not True:
                const.SI_MIN_FRATE_5MS -= post_stim
            print(const.SI_MIN_FRATE_5MS)
            SIs_post, frates_post, _ = compute_si(data,
                                                  MS=which == 'MS',
                                                  start=late_start,
                                                  stop=late_stop)
        finally:
            const.SI_MIN_FRATE_5MS = original_min_frate

        # mark the late columns: (region, whisker) -> (region_lateSI, whisker)
        SIs_post.columns = pd.MultiIndex.from_tuples([
            (region_whisker[0] + '_lateSI', region_whisker[1])
            for region_whisker in SIs_post.columns
        ])
        frates_post.columns = pd.MultiIndex.from_tuples([
            (region_whisker[0] + '_lateSI', region_whisker[1])
            for region_whisker in frates_post.columns
        ])
        order = SIs.columns.unique(0).tolist() + SIs_post.columns.unique(
            0).tolist()

        SIs = pd.concat([SIs, SIs_post], axis=1)
        frates = pd.concat([frates, frates_post], axis=1)

        SIs = SIs.stack(level=1).reindex(order, axis=1)
        frates = frates.stack(level=1).reindex(order, axis=1)
    else:
        SIs = SIs.stack(level=1).reindex(['VPM', 'G', 'SG', 'IG', 'dIG'],
                                         axis=1)
        frates = frates.stack(level=1).reindex(['VPM', 'G', 'SG', 'IG', 'dIG'],
                                               axis=1)

    # one scatter plot + regression per ordered pair of regions.
    # DataFrame.items() replaces iteritems(), removed in pandas 2.0.
    p_values = {}
    late = 'late' if post_stim else ''
    for comp_reg, comp_dat in SIs.items():
        for reg, region_dat in SIs.items():
            if reg == comp_reg:
                continue
            fig, ax = plt.subplots(figsize=(6, 6))

            for sp in ax.spines.values():
                sp.set_visible(False)
            ax.patch.set_facecolor('grey')
            ax.patch.set_alpha(.16)
            ax.hlines((0), -1, 1, color='black', linewidth=.5)
            ax.vlines((0), -1, 1, color='black', linewidth=.5)

            ax.set_xlim(-.75, 1.05)
            ax.set_ylim(-.3, 1.05)

            ax.set_xlabel('SSA index ' + const.REGIONS_EXT[comp_reg])
            ax.set_ylabel('SSA index ' + const.REGIONS_EXT[reg])

            ax.scatter(comp_dat, region_dat, s=5, color='k')
            # label every datapoint with its index entry
            for idx in frates[comp_reg].index:
                ax.annotate('-'.join(idx), (comp_dat[idx], region_dat[idx]),
                            size=7,
                            ha='right',
                            va='bottom' if 'C1' in idx else 'top')

            # only datapoints present in both regions enter the regression
            notna = comp_dat.notna().values & region_dat.notna()
            if notna.sum() <= 4:
                p_values[f'{comp_reg}-{reg}'] = 'NaN'
            else:
                r = ss.linregress(comp_dat[notna], region_dat[notna])
                ax.plot((-1, 0, 1), (r.intercept - r.slope, r.intercept,
                                     r.slope + r.intercept),
                        color=const.REGION_CMAP[reg],
                        label=f'{reg} p-value: {r.pvalue:.2f}')
                p_values[f'{comp_reg}-{reg}'] = r.pvalue
                ax.legend(loc='lower left')

            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/SSA_corr_{comp_reg}-{reg}_{which}_{start}_{stop}ms_{late}.{const.PLOT_FORMAT}'
            fig.savefig(f)
            # one figure per pair of regions: close it to avoid piling up
            # dozens of open figures
            plt.close(fig)

    print(SIs.to_string())
    # blank regions with too few values. Assigning through .loc modifies
    # SIs itself; assigning into the transposed frame only worked when the
    # transpose happened to be a view.
    SIs.loc[:, SIs.notna().sum() <= 4] = np.nan
    print(SIs.to_string())
    corr = SIs.corr()

    figsize = (8, 8) if not post_stim else (11, 11)
    fig, ax = plt.subplots(figsize=figsize)
    fig.subplots_adjust(left=.1, bottom=.1, top=.75, right=.75)
    im = ax.imshow(corr, aspect='auto', vmin=-1, vmax=1, cmap='RdBu_r')

    # write the p-value of each pair into its cell
    for row, reg in enumerate(SIs.columns):
        for col, reg_nd in enumerate(SIs.columns):
            if reg == reg_nd:
                continue
            pval = p_values[f'{reg}-{reg_nd}']
            pval = f'p={pval:.3f}' if not isinstance(pval, str) else 'NaN'
            ax.annotate(pval, (row - .35, col), fontsize=8)

    # colorbar and legend
    at = (
        0.77,
        .95,
        .2,
        .012,
    )
    cb = fig.colorbar(im, cax=fig.add_axes(at), orientation='horizontal')
    cb.set_label('Pearson\'s r', size=12)
    cb.ax.get_xaxis().set_label_position('top')

    ax.set_title(f'SSA index correlation {which} {start}-{stop}ms')
    ax.set_xticks(np.arange(SIs.shape[1]))
    ax.set_xticklabels(SIs.columns, fontsize=10, rotation=45)
    ax.set_yticks(np.arange(SIs.shape[1]))
    ax.set_yticklabels(SIs.columns,
                       fontsize=10,
                       rotation=45,
                       rotation_mode='anchor')

    # number of available values per region, listed next to the heatmap
    n_smples = [
        f'{reg} n={nsmples}' for reg, nsmples in SIs.notna().sum().items()
    ]
    ax.annotate('\n'.join(n_smples), (.77, .6),
                annotation_clip=False,
                xycoords='figure fraction')

    f = f'{const.P["outputPath"]}/{dest_dir_appdx}/SSA_corr_heatmap_{which}_{start}_{stop}ms_{late}.{const.PLOT_FORMAT}'
    fig.savefig(f)
示例#5
0
def oddball_si(dest_dir_appdx,
               which='O10',
               compare_with_MS=False,
               start=5,
               stop=20):
    """Plot the SSA index (SI) of a specific oddball paradigm.

    The paradigm is passed to `which`; it may be `O10`, `O25` or `O25U`. By
    default the SSA index is calculated by comparing the deviant response to
    the predeviant (standard) response. The response is defined as the average
    firing rate in a given time interval post stimulus. This interval is
    passed by `start` (default 5) and `stop` (default 20). The SSA index may
    also be calculated using the many standards (MS) paradigm: instead of
    comparing with the standard presentation of the whisker, the stimulation
    of the whisker within the MS paradigm is used. This is done by passing
    `compare_with_MS` as True. Details on the computation of SSA indices can
    be found in the docstring of the MUA_utility.py function `compute_si()`.

    Besides the main plot (`SSA_index_*`), a histogram of responses
    (`responses_hist_*`) is saved to check the general magnitude of
    responses. All plots are saved at P["outputPath"]/dest_dir_appdx/*.
    Nothing is returned.
    """
    # Filename/title tag; computed once and reused for both figures.
    versusMS = 'versusMS' if compare_with_MS else ''
    out_dir = f'{const.P["outputPath"]}/{dest_dir_appdx}'

    data = fetch(
        mouseids=['mGE82', 'mGE83', 'mGE84', 'mGE85'],
        paradigms=[which + 'C1', which + 'C2', 'MS'],
        stim_types=['Deviant', 'Predeviant', 'UniquePredeviant', 'C1', 'C2'],
        collapse_ctx_chnls=True,
        collapse_th_chnls=True,
        drop_not_assigned_chnls=True)

    SIs, _, SI_raw_values = compute_si(data,
                                       MS=compare_with_MS,
                                       start=start,
                                       stop=stop)
    SIs_mean = SIs.mean()

    # ---- Figure 1: histograms of the raw responses -------------------------
    fig, (ax_left, ax_right) = plt.subplots(ncols=2, figsize=(7, 5))
    # Left: per-mouse histograms restricted to the low range 0-10.
    for m_id in SI_raw_values.columns.unique(0):
        ax_left.hist(SI_raw_values[m_id].values.flatten(),
                     label=m_id,
                     alpha=.8,
                     bins=20,
                     range=(0, 10),
                     color=const.GENERAL_CMAP[m_id])
    # Right: all responses of all mice pooled.
    ax_right.hist(SI_raw_values.values.flatten(), bins=40)

    # Re-setting the current limits fixes the y-range before the cut-off
    # line (drawn from -5 to 50) is added below.
    ax_left.set_ylim(ax_left.get_ylim())
    ax_left.set_ylabel('counts')
    ax_left.set_xlabel(f'avg. firingrate {start}-{stop} ms')
    ax_right.set_xlabel(f'avg. firingrate {start}-{stop} ms')
    ax_left.set_title('low responses')
    ax_right.set_title('all responses')
    ax_left.legend()
    ax_left.vlines(const.SI_MIN_FRATE_5MS, -5, 50, linestyle='dashed')
    ax_left.annotate('cut off',
                     (const.SI_MIN_FRATE_5MS + .1, ax_left.get_ylim()[1] * .9))
    fig.suptitle(f'SSA {which} {versusMS} {start}-{stop} ms')

    f = f'{out_dir}/responses_hist_{which}_{versusMS}_{start}-{stop}ms.{const.PLOT_FORMAT}'
    fig.savefig(f)

    # ---- Figure 2: SSA index per mouse and on average ----------------------
    fig, ax = plt.subplots(figsize=(8, 6))
    fig.subplots_adjust(top=.75, right=.82, left=.2, bottom=.15)

    # Only the left spine stays visible.
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.spines['left'].set_visible(True)
    ax.patch.set_facecolor('grey')
    ax.patch.set_alpha(.16)
    ax.hlines(0, 0, 23, color='black', linewidth=.5)
    ax.set_title(f'SSA {which} {versusMS} {start}-{stop} ms', pad=65)

    # x positions come in C1/C2 pairs; `xt_mid` is the centre of each pair,
    # where the region name is annotated.
    xt = [1, 2, 6, 7, 11, 12, 16, 17, 21, 22]
    xt_mid = [1.5, 6.5, 11.5, 16.5, 21.5]
    ax.set_xlim((0, 23))
    ax.set_xticks(xt)
    ax.set_xticklabels(['C2' if i % 2 else 'C1' for i in range(len(xt))])

    ax.yaxis.grid(True, which='major')
    ax.set_ylim((-1.05, 1.05))
    ax.set_ylabel('SSA index (SI)')
    ax.set_yticks(np.arange(-1, 1.001, .25))

    # One scatter series per row of `SIs` (row label is used as mouse id).
    for m_id, mouse_si in SIs.iterrows():
        ax.scatter(xt,
                   mouse_si,
                   color=const.GENERAL_CMAP[m_id],
                   s=8,
                   alpha=.9,
                   label=m_id)

    regions = [const.REGIONS_EXT[reg] for reg in SIs_mean.index.unique(0)]
    for reg, x_m in zip(regions, xt_mid):
        ax.annotate(reg, (x_m, 1.05), rotation=30)

    ax.scatter(xt,
               SIs_mean,
               color='k',
               s=20,
               alpha=.7,
               marker='D',
               label='Average')
    ax.legend(bbox_to_anchor=(1.001, 1.001), loc='upper left')

    # Below the axis: number of NaN SIs per column, labelled as excluded mice.
    lbl = f'avg. firingr. in 5ms < {const.SI_MIN_FRATE_5MS}\n# excluded mice:'
    ax.annotate(lbl, (-7.3, -1.5), annotation_clip=False, fontsize=9)
    n_excluded = SIs.isna().sum().values
    for n, x_t in zip(n_excluded, xt):
        ax.annotate(n, (x_t - .3, -1.5), annotation_clip=False, fontsize=9)

    f = f'{out_dir}/SSA_index_{which}_{versusMS}_{start}-{stop}ms.{const.PLOT_FORMAT}'
    fig.savefig(f)
# Example #6
# 0
def firingrate_noise_timeline(dest_dir_appdx='../',
                              fname_postfix='',
                              subtr_noise=False):
    """Check background firing rate activity over the experimental timeline.

    Draws a 4 x 12 grid: one row per mouse (the first four entries of
    const.ALL_MICE), one heatmap column for each of the first 11 paradigms in
    const.PARAD_ORDER[mouse] showing the firing rates averaged over stimulus
    types, and a last column with a bar chart of the proportion of negative
    firing rates per paradigm plus their average.

    `subtr_noise` should either be False or `deviant_alone`, `paradigm_wise`
    (ie the method); it is forwarded to `slice_data()` as
    `frate_noise_subtraction` and appended to the figure title.

    The figure is saved at
    P["outputPath"]/dest_dir_appdx/firingrate_over_time<fname_postfix>.<fmt>.
    Nothing is returned.
    """
    import copy  # local import: only needed for the colormap copy below

    data = fetch()

    ratio = {
        'width_ratios': [.1] * 11 + [.28],
    }
    fig, axes = plt.subplots(4,
                             12,
                             sharex=False,
                             sharey=False,
                             figsize=(15, 13),
                             gridspec_kw=ratio)
    fig.subplots_adjust(hspace=.2,
                        wspace=.03,
                        right=.95,
                        top=.86,
                        left=.1,
                        bottom=.07)

    title = 'noise over experimental paradigm timeline\n(mean over stimulus types)'
    if subtr_noise:
        title += '\n\nNoise subtracted: ' + subtr_noise
    fig.suptitle(title, size=14)

    # Gamma-adjust a private copy of 'gnuplot' and hand it to imshow
    # explicitly. Mutating the registered colormap through
    # `plt.cm.get_cmap(...)` only worked while get_cmap returned the global
    # instance; newer Matplotlib returns a copy (gamma silently ignored) and
    # has removed `matplotlib.cm.get_cmap` altogether.
    cmap = copy.copy(plt.get_cmap('gnuplot'))
    cmap.set_gamma(.8)

    # Start with all ticks/labels off; selected axes re-enable them below.
    for ax in axes.flatten():
        ax.tick_params(bottom=False,
                       left=False,
                       labelbottom=False,
                       labelleft=False)

    # zip with range(4) limits the plot to the four grid rows.
    for mouse, i in zip(const.ALL_MICE, range(4)):
        mouse_dat = slice_data(data, [mouse],
                               firingrate=True,
                               frate_noise_subtraction=subtr_noise)
        axes[i, 0].set_ylabel(mouse + '\nchannels',
                              size=12,
                              rotation=0,
                              ha='right',
                              va='center')

        neg_frates = []
        for parad, j in zip(const.PARAD_ORDER[mouse], range(11)):
            # All firing rate frames of this mouse whose key mentions the
            # paradigm (substring match).
            parad_frates = [
                df for key, df in mouse_dat.items() if parad in key
            ]

            # NOTE(review): if no key matches, the divisions below raise
            # ZeroDivisionError - confirm every paradigm has data.
            neg_frates_counts = [(fr < 0).sum().sum() for fr in parad_frates]
            neg_frates.append(sum(neg_frates_counts) / len(neg_frates_counts))

            # Mean over the matched frames, truncated to integers.
            frates = (sum(parad_frates) / len(parad_frates)).astype(int)
            panel = axes[i, j]
            im = panel.imshow(frates,
                              cmap=cmap,
                              aspect='auto',
                              extent=[-52.5, 202.5, -.5, 31.5],
                              vmin=0,
                              vmax=500)

            panel.set_title(parad, size=7, pad=2)
            if 'DA' in parad and subtr_noise == 'deviant_alone':
                # Highlight panels whose paradigm name contains 'DA' when the
                # `deviant_alone` method is selected.
                panel.set_title('**' + parad + '**', size=9, pad=2)
                for spine in panel.spines.values():
                    spine.set_color('yellow')
                    spine.set_linewidth(1.5)

            if (i == 0) and (j == 0):
                panel.set_xlim((-52.5, 202.5))
                panel.set_xticks([-50, 0, 80, 160])

                # colorbar and legend (added once, for the whole figure)
                at = (
                    0.77,
                    .95,
                    .2,
                    .012,
                )
                cb = fig.colorbar(im,
                                  cax=fig.add_axes(at),
                                  orientation='horizontal')
                cb.set_label('Mean Firing Rate in 5ms frame', size=12)
                cb.ax.get_xaxis().set_label_position('top')

        # 1600 = 32 channels x 50 time bins, per the y-label text below.
        rel_neg_frates = np.array(neg_frates) / 1600
        bar_ax = axes[i, 11]
        bar_ax.bar(range(11), rel_neg_frates)
        bar_ax.bar([12], rel_neg_frates.mean())

        # x axis
        bar_ax.tick_params(labelbottom=True, rotation=35, labelsize=6.5)
        bar_ax.set_xticks(list(range(11)) + [12])
        bar_ax.set_xticklabels(const.PARAD_ORDER[mouse] + ('Average', ),
                               clip_on=False,
                               ha='right',
                               y=.04)

        # y axis
        bar_ax.tick_params(right=True, labelright=True)
        bar_ax.set_ylim((0, 1))
        yticks = np.arange(0, 1.1, .1)
        bar_ax.set_yticks(yticks)
        # Only the end points 0 and 1 get a tick label.
        bar_ax.set_yticklabels(
            [str(yt) if yt in (0, 1) else '' for yt in yticks])
        bar_ax.yaxis.set_label_position("right")
        bar_ax.yaxis.grid(True, alpha=.5)
        bar_ax.set_axisbelow(True)

        if i == 1:
            bar_ax.set_ylabel(
                'Proportion of negative firing rates (of 32 channels x 50 time bins)',
                size=10)

    # Save this figure explicitly rather than whatever pyplot considers the
    # current figure.
    fig.savefig(
        f'{const.P["outputPath"]}/{dest_dir_appdx}/firingrate_over_time{fname_postfix}.{const.PLOT_FORMAT}'
    )
# Example #7
# 0
    def plot_paradigm(parad):
        """Draw the mean firing rate heatmaps for `parad` and return the figure.

        The figure has 4 rows (one per mouse, the first four entries of
        const.ALL_MICE) and one column per fetched data item. Which items are
        fetched depends on `grouping`, read from the enclosing scope:

        * 'paradigm_wise': everything of paradigm `parad`.
        * 'whisker_wise': the Deviant of `parad` together with the
          (Unique)Predeviant - or, for DOC1/DOC2, the Standard - of the
          paradigm in which the other whisker is the deviant. 'MS' is
          fetched as is.
        * 'whisker_wise_reduced': for the whisker named by the last two
          characters of `parad`, a fixed sequence of panels across all
          paradigms, with hidden spacer columns in between.

        Also reads `chnls_to_regions`, `all_stimuli` and `subtr_noise` from
        the enclosing scope (not visible from this block).

        Raises ValueError for an unknown `grouping`.
        """
        import copy  # local import: only needed for the colormap copy below

        print(parad)
        # NOTE(review): `dev`/`std` stay unbound when `parad` contains
        # neither 'C1' nor 'C2'; only the 'MS' case of 'whisker_wise' and
        # 'paradigm_wise' avoid using them.
        if 'C1' in parad or 'C2' in parad:
            dev = parad[-2:]
            std = 'C2' if dev == 'C1' else 'C1'

        # Keyword flags for fetch(), all switched together.
        to_regions = [
            'collapse_ctx_chnls', 'collapse_th_chnls',
            'drop_not_assigned_chnls'
        ]
        to_regions = dict.fromkeys(to_regions, bool(chnls_to_regions))

        if grouping == 'paradigm_wise':
            data = fetch(paradigms=[parad])
        elif grouping == 'whisker_wise':
            if parad != 'MS':
                dev_data = fetch(paradigms=[parad],
                                 stim_types=['Deviant'],
                                 **to_regions)

                std_parad = parad.replace(dev, std)
                # Decide the stimulus types first so the standard data is
                # fetched exactly once (DOC1/DOC2 only use 'Standard').
                if std_parad in ['DOC1', 'DOC2']:
                    stim_types = ['Standard']
                else:
                    stim_types = ['Predeviant', 'UniquePredeviant']
                    if all_stimuli:
                        stim_types.extend(
                            ['Standard', 'Postdeviant', 'UniquePostdeviant'])
                std_data = fetch(paradigms=[std_parad],
                                 stim_types=stim_types,
                                 **to_regions)

                data = {**std_data, **dev_data}
            else:
                data = fetch(paradigms=[parad], **to_regions)

        elif grouping == 'whisker_wise_reduced':
            dev_parads = [
                this_parad for this_parad in const.ALL_PARADIGMS
                if dev in this_parad
            ]
            std_parads = [
                this_parad.replace(dev, std) for this_parad in dev_parads
            ]
            # Left-to-right panel order; matched as substrings of data keys.
            if dev == 'C1':
                order = (
                    'DAC1-Deviant',
                    'O10C2-Predeviant',
                    'O10C1-Deviant',
                    'O25C2-UniquePredeviant',
                    'O25C1-Deviant',
                    'MS-C1',
                    'DOC2-Standard',
                    'DOC1-Deviant',
                )
            else:
                order = ('DAC2-Deviant', 'O10C1-Predeviant', 'O10C2-Deviant',
                         'O25C1-UniquePredeviant', 'O25C2-Deviant', 'MS-C2',
                         'DOC1-Standard', 'DOC2-Deviant')

            dev_data = fetch(paradigms=dev_parads, stim_types=['Deviant'])
            std_data = fetch(paradigms=std_parads,
                             stim_types=['Predeviant', 'UniquePredeviant'])
            std_data.update(
                fetch(paradigms=['DO' + std], stim_types=['Standard']))
            ms_data = fetch(paradigms=['MS'], stim_types=[dev])
            data = {**std_data, **dev_data, **ms_data}
            data = OrderedDict({
                key: data[key]
                for ord_key in order for key in data.keys() if ord_key in key
            })
        else:
            # Previously this surfaced later as an UnboundLocalError on
            # `data`; fail early with an explicit message instead.
            raise ValueError(
                f'Unknown grouping `{grouping}`. Expected `paradigm_wise`, '
                '`whisker_wise` or `whisker_wise_reduced`.')

        if grouping != 'whisker_wise_reduced':
            if parad != 'MS' and not all_stimuli:
                args = {'ncols': 2}
                width = 7
            else:
                args = {'ncols': 4}
                width = 13

        else:
            # 8 data columns plus 4 narrow spacer columns (hidden below).
            args = {
                'ncols': 8 + 4,
                'gridspec_kw': {
                    'width_ratios':
                    [.1, .015, .1, .1, .015, .1, .1, .015, .1, .015, .1, .1],
                }
            }
            width = 20
        fig, axes = plt.subplots(4,
                                 **args,
                                 sharex=True,
                                 sharey=True,
                                 figsize=(width, 13))
        fig.subplots_adjust(hspace=.06,
                            wspace=.03,
                            right=.98,
                            top=.86,
                            left=.13,
                            bottom=.07)

        # Start with all ticks/labels off; selected axes re-enable them below.
        for ax in axes.flatten():
            ax.tick_params(bottom=False,
                           left=False,
                           labelbottom=False,
                           labelleft=False)

        if grouping != 'whisker_wise_reduced':
            title = const.PARAD_FULL[
                parad] + '- mean firing rates across 4 mice'
        else:
            title = parad + '- mean firing rates across 4 mice'

        if subtr_noise:
            title += '\nNOISE SUBTRACTED'
        fig.suptitle(title, size=14)

        # Gamma-adjust a private copy of 'gnuplot' and hand it to imshow
        # explicitly. Mutating the registered colormap through
        # `plt.cm.get_cmap(...)` only worked while get_cmap returned the
        # global instance; newer Matplotlib returns a copy (gamma silently
        # ignored) and has removed `matplotlib.cm.get_cmap` altogether.
        cmap = copy.copy(plt.get_cmap('gnuplot'))
        cmap.set_gamma(.8)

        print()
        print()
        print()
        # zip with range(4) limits the plot to the four grid rows.
        for mouse, i in zip(const.ALL_MICE, range(4)):
            mouse_dat = slice_data(data, [mouse],
                                   firingrate=True,
                                   frate_noise_subtraction=subtr_noise)
            if chnls_to_regions:
                # Label the y axis with the (reversed) row index of the first
                # firing rate frame, at 5 evenly spaced positions.
                axes[i, 0].tick_params(left=True, labelleft=True)
                lbls = list(mouse_dat.values())[0].index[::-1]
                axes[i, 0].set_yticks(np.arange(6.4 / 2, 32, 6.4))
                axes[i, 0].set_yticklabels(lbls)

            axes[i, 0].set_ylabel(mouse + '\nchannels',
                                  size=12,
                                  rotation=0,
                                  ha='right',
                                  va='center',
                                  x=-50)

            # `which_ax` is the column actually drawn into; it runs ahead of
            # the item count when spacer columns are skipped. The zip with
            # range(ncols) caps the number of items at the column count.
            which_ax = 0
            for (key, frates), _ in zip(mouse_dat.items(),
                                        range(args['ncols'])):

                im = axes[i, which_ax].imshow(frates,
                                              cmap=cmap,
                                              aspect='auto',
                                              extent=[-52.5, 202.5, -.5, 31.5],
                                              vmin=0,
                                              vmax=500)
                # Thin white line at x=0.
                axes[i, which_ax].vlines(0,
                                         -.5,
                                         31.5,
                                         color='w',
                                         alpha=.6,
                                         linewidth=1)
                if i == 0:
                    # Column titles go on the top row only. The stimulus type
                    # is the part of the key after its last '-'.
                    stim_t = key[key.rfind('-') + 1:]
                    col = 'k' if stim_t != 'Deviant' else const.COLORS[
                        'deep_red']
                    if grouping == 'paradigm_wise' or parad == 'MS':
                        axes[i, which_ax].set_title(stim_t, color=col)
                    elif grouping == 'whisker_wise':
                        axes[i, which_ax].set_title(f'{parad[-2:]} {stim_t}',
                                                    color=col)
                    elif grouping == 'whisker_wise_reduced':
                        # Paradigm = key part between first and last '-';
                        # the last 3 characters of its full name are dropped.
                        pard_full = const.PARAD_FULL[
                            key[key.find('-') + 1:key.rfind('-')]][:-3]
                        title = f'{dev} {stim_t}\n{pard_full}'
                        if 'MS' not in key:
                            axes[i, which_ax].set_title(title, color=col)
                        else:
                            axes[i, which_ax].set_title(stim_t +
                                                        '\nMany Standards',
                                                        color=col)

                elif i == 3:
                    # Bottom row carries the time axis.
                    axes[i, which_ax].tick_params(bottom=True,
                                                  labelbottom=True)
                    axes[i, which_ax].set_xlabel('ms')
                if (i == 0) and (which_ax == 0):
                    axes[i, which_ax].set_xlim((-52.5, 202.5))
                    axes[i, which_ax].set_xticks([-50, 0, 80, 160])

                    # colorbar and legend (added once, for the whole figure)
                    at = (
                        .58,
                        .9,
                        2.5 / width,
                        .012,
                    )
                    cb = fig.colorbar(im,
                                      cax=fig.add_axes(at),
                                      orientation='horizontal')
                    cb.set_label('Mean Firing Rate in 5ms frame', size=12)
                    cb.ax.get_xaxis().set_label_position('top')

                which_ax += 1
                # In the reduced layout columns 1, 4, 7 and 9 are the narrow
                # spacers: hide them and move on to the next data column.
                if grouping == 'whisker_wise_reduced' and which_ax in [
                        1, 4, 7, 9
                ]:
                    axes[i, which_ax].set_visible(False)
                    which_ax += 1
        return fig