def thalamic_mapping_panel(dest_dir_appdx='../',
                           which_paradigms=['DAC1', 'DAC2'],
                           anatomy_dir=None,
                           ts_plots_dir=None):
    """
    Create one thalamic-mapping panel per mouse and selected paradigm.

    The panel is similar in style to the one produced by
    cortical_mapping_panel(). It shows the LFP responses of the last 10
    (thalamic) channels, the firing-rate heatmap of channels 23-32, the
    current channel-to-region assignment of those channels and, top left, the
    LFP traces of the first 4 rows of the LFP data (the collapsed cortical
    regions). `which_paradigms` controls the paradigms being plotted
    (default: DAC1, DAC2).

    If both `anatomy_dir` and `ts_plots_dir` are passed, the saved panel is
    extended with the injection-slice image and the first-response plot.
    Both need to be passed, even if you are not interested in one of them,
    for either to be drawn. `anatomy_dir` should contain files of the form
    `<mouse>.png`, e.g. `mGE82.png`; `ts_plots_dir` is the output directory of
    plot_first_response().

    Plots are saved at P["outputPath"]/dest_dir_appdx as
    `<mouse>_<paradigm position>_<paradigm>.png`. The function expects the
    cortical mapping to be done already: it reads the same file used for the
    cortical mapping, P["outputPath"]/../chnls_map.csv. Subsequent scripts
    expect `VPM` as the label of this thalamic mapping.
    """
    import sys

    def make_plot(lfp, frates, lfp_summ, mua_summ, mouse, parad, stim_t):
        """Draw the panel for one mouse/paradigm/stimulus type; return the figure.

        `lfp_summ` and `mua_summ` are only referenced by the commented-out
        first-response markers below. The channel mapping is read from the
        enclosing `mapping` DataFrame.
        """
        # scaled by 1e6 -> plotted on the [uV] axes; the last 50 time columns
        # are dropped (NOTE(review): reason not visible here - confirm)
        thal_lfp = lfp.iloc[-10:, :-50] * 1_000_000
        ctx_lfp = lfp.iloc[:4, :-50] * 1_000_000
        frates = frates.loc[range(22, 32), :]
        x_time = lfp.columns[:-50].values.astype(float)

        # init figure
        fig = plt.figure(figsize=(14, 10.2))
        gs = fig.add_gridspec(
            7,
            3,
            width_ratios=[.5, .35, .15],
            height_ratios=[.25, .025, .15, .15, .15, .15, .15],
            hspace=0,
            wspace=.12,
            right=.975,
            top=.95,
            left=.02,
            bottom=.05)
        lfp_ax_ctx = fig.add_subplot(gs[0, 0])
        frate_ax = fig.add_subplot(gs[0, 1])
        assig_ax = fig.add_subplot(gs[0, 2])
        lfp_axl = [fig.add_subplot(gs[i, 0]) for i in range(2, 7)]
        lfp_axr = [fig.add_subplot(gs[i, 1:]) for i in range(2, 7)]
        all_axs = lfp_axl + lfp_axr + [lfp_ax_ctx, frate_ax, assig_ax]
        # 0,1,2,3,4 left plots, 5,6,7,8,9 right plots, 10 ctx plot
        lfp_axs = lfp_axl + lfp_axr + [lfp_ax_ctx]
        # clean axes
        for ax in all_axs:
            ax.tick_params(left=False,
                           bottom=False,
                           labelleft=False,
                           labelbottom=False)

        # iterate bottom plots block and top left (all lfp's), setup axis
        for which_ax, ax in enumerate(lfp_axs):
            # general
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.patch.set_facecolor('grey')
            ax.patch.set_alpha(.16)
            ax.hlines(0, -5, 20, color='grey')

            # x axis (seconds)
            xleft, xright = -0.05, 0.2
            ax.xaxis.grid(True, which='major')
            ax.set_xlim((xleft, xright))
            xticks = (xleft, 0, 0.05, 0.1, 0.15, xright)
            ax.set_xticks(xticks)
            # bottom plots
            if which_ax in (4, 9):
                ax.tick_params(labelbottom=True)
                ax.set_xticklabels(xticks, size=8)
                ax.set_xlabel('[s]', size=8)

            # y axis
            if which_ax != 10:
                ybot, ytop = -25, 25
            else:  # ctx plot
                ybot, ytop = -180, 90
            ax.yaxis.grid(True, which='major')
            ax.set_ylim((ybot, ytop))
            # ticks every 15 in both directions from 0 (0 only once)
            yts = np.concatenate(
                [np.arange(0, ybot, -15)[1:],
                 np.arange(0, ytop, 15)])
            ax.set_yticks(yts)
            # right plots
            if which_ax in (5, 6, 7, 8, 9):
                ax.tick_params(labelright=True)
                ax.set_yticklabels((-15, 0, 15), size=8)
            # ctx plots: only label a few ticks
            elif which_ax == 10:
                ax.tick_params(labelright=True)
                ax.set_yticklabels(
                    [yt if yt in (-150, -75, 0, 60) else None for yt in yts],
                    size=8)
            # ctx plot and middle, right plot
            if which_ax in (7, 10):
                ax.annotate('[uV]', (.263, 0), size=8, annotation_clip=False)
            if which_ax == 10:
                ax.set_title(f'{mouse}-{parad}-{stim_t}')

        # draw thalamic lfp's, one channel per axis (labels are 1-based)
        for (chnl, dat), ax in zip(thal_lfp.iterrows(), lfp_axl + lfp_axr):
            ax.set_ylabel(chnl + 1,
                          fontsize=10,
                          rotation=0,
                          ha='right',
                          va='center',
                          labelpad=2)
            ax.plot(x_time,
                    dat.values,
                    clip_on=False,
                    color=const.REGION_CMAP['Th'])
            # ax.vlines((lfp_summ.loc[chnl, 'Peak_neg_ts']/1000), ybot, ytop)
            # ax.vlines((mua_summ.loc[chnl, 'AvgTimetoFirstSpike']/1000), ybot, ytop, color='green')
        for region, region_lfp in ctx_lfp.iterrows():
            lfp_ax_ctx.plot(x_time,
                            region_lfp,
                            label=region,
                            color=const.REGION_CMAP[region],
                            clip_on=False)
        lfp_ax_ctx.legend()

        # setup the assignment plot
        assig_ax.tick_params(left=True, labelleft=True, pad=9.5)
        assig_ax.set_xlim((0, 1))
        assig_ax.set_ylim((32.5, 22.5))
        assig_ax.set_yticks(np.arange(23, 33))

        # Check whether the cortical map of the previous paradigm (previous
        # column of this mouse in chnls_map.csv) differs from the current one,
        # using the first 'SG' channel as reference. If so, annotate the
        # current plot with the shift.
        first_chnls = pd.Series({
            col: mapping.index[this_map == 'SG'][0]
            for col, this_map in mapping.items()
        })
        mouse_parad = f'{mouse}-{parad}'
        mouse_first_chnls = first_chnls[[
            entr for entr in first_chnls.index
            if entr.startswith(f'{mouse}-')
        ]]

        current_first_chnl = mouse_first_chnls.loc[mouse_parad]
        if parad != const.PARAD_ORDER[mouse][0]:  # if not the first paradigm
            before_first_chnl = mouse_first_chnls.iloc[
                mouse_first_chnls.index.get_loc(mouse_parad) - 1]
            moved_by = before_first_chnl - current_first_chnl
            if moved_by:
                up_dwn = 'up' if moved_by > 0 else 'down'
                ann = f'Cortical map moved\n{up_dwn} by {abs(moved_by)}!'
                assig_ax.annotate(ann, (.1, 22),
                                  annotation_clip=False,
                                  fontsize=13)

        # region labels of the last 10 (thalamic) channels
        th_labels = mapping[mouse_parad].iloc[-10:]
        cols = [const.REGION_CMAP[region] for region in th_labels]

        # define how each layer is colored, label
        assig_ax.barh(np.arange(23, 33),
                      1,
                      height=1,
                      edgecolor='k',
                      color=cols,
                      alpha=.6)
        for row, label in enumerate(th_labels):
            assig_ax.annotate(label, (.2, row + 23 + .2), fontsize=11)

        # firing-rate heatmap; ticks and extent are in seconds like the LFP
        # axes (extent padded by half a time bin on each side)
        frate_ax.tick_params(top=True, labeltop=True, right=True)
        frate_ax.xaxis.set_label_position('top')
        frate_ax.set_xticks((-.05, 0, .1, .2))
        frate_ax.set_xticklabels((-.05, 0, .1, .2), size=8)
        frate_ax.set_xlabel('[s]', size=8)

        frate_ax.set_yticks(np.arange(23, 33))
        frate_ax.set_ylim((22.5, 32.5))

        frate_ax.imshow(frates,
                        cmap='gnuplot',
                        aspect='auto',
                        extent=[-.0525, 0.2025, 22.5, 32.5],
                        vmin=0,
                        vmax=15)
        frate_ax.vlines([-.002], 22.5, 33.5, color='w', alpha=.6)
        return fig

    data = fetch(paradigms=which_paradigms,
                 collapse_ctx_chnls=True,
                 collapse_th_chnls=False,
                 drop_not_assigned_chnls=False)
    # get the current assignment (one column per `<mouse>-<paradigm>`)
    mapping = pd.read_csv(f'{const.P["outputPath"]}/../chnls_map.csv',
                          index_col=0).reset_index(drop=True)

    for mouse in const.ALL_MICE:
        if mouse not in const.PARAD_ORDER:
            print(
                f'`{mouse}` not found in constant PARAD_ORDER (MUA_constants.py). Check.'
            )
            sys.exit(1)
        for i, parad in enumerate(const.PARAD_ORDER[mouse]):
            if parad not in which_paradigms:
                continue

            if parad not in ['MS', 'DOC1', 'DOC2']:
                peak_stim = 'Deviant'
            else:
                peak_stim = 'C1' if parad == 'MS' else 'Standard'

            frates = slice_data(data, [mouse], [parad], [peak_stim],
                                firingrate=True,
                                frate_noise_subtraction='paradigm_wise',
                                drop_labels=True)[0]

            lfp = slice_data(data, [mouse], [parad], [peak_stim],
                             lfp=True,
                             drop_labels=True)[0]
            lfp_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                  lfp_summary=True,
                                  drop_labels=True)[0]
            mua_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                  mua_summary=True,
                                  drop_labels=True)[0]

            fig = make_plot(lfp, frates, lfp_summ, mua_summ, mouse, parad,
                            peak_stim)
            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/{mouse}_{i+1}_{parad}.png'
            fig.savefig(f)
            # free the figure, otherwise one stays open per mouse/paradigm
            plt.close(fig)

            if anatomy_dir and ts_plots_dir:
                anat_f = f'{anatomy_dir}/{mouse}.png'
                ts_f = f'{ts_plots_dir}/{mouse}_{parad}_{peak_stim}_first_ts.png'
                # close the source images before overwriting `f`
                with Image.open(f) as plot, Image.open(anat_f) as anat, \
                        Image.open(ts_f) as ts_plot:
                    final_img = Image.new(
                        'RGB',
                        (plot.width + anat.width, anat.height + ts_plot.height),
                        color='white')
                    final_img.paste(plot, (0, 0))  # upper left
                    final_img.paste(anat, (plot.width, 0))  # upper right
                    final_img.paste(ts_plot,
                                    (plot.width, anat.height))  # lower right
                final_img.save(f)
def plot_first_response(dest_dir_appdx='../',
                        which_paradigms=['DAC1', 'DAC2']):
    """
    Plot the time of first response of the thalamic channels.

    For each mouse and selected paradigm, a two-row figure compares the last
    10 channels: the top row uses the LFP summary (`Peak_neg_ts`), the bottom
    row the MUA summary (`AvgTimetoFirstSpike`). Each channel is drawn as a
    labelled line at its time stamp; the `G` region is drawn as reference.
    Channels with missing values are listed in the title; in the MUA row,
    channels with a time stamp of 0 are listed as "no spike channels" and not
    drawn. These plots are used in the panel produced by
    thalamic_mapping_panel() for defining thalamic channels.
    `which_paradigms` selects the paradigms to plot, defaults to
    [`DAC1`, `DAC2`]. To plot all paradigms, pass const.ALL_PARADIGMS. The
    plots are saved at
    P["outputPath"]/dest_dir_appdx/<mouse>_<paradigm>_<stimulus>_first_ts.png
    """
    import sys

    data = fetch(paradigms=which_paradigms,
                 collapse_ctx_chnls=True,
                 collapse_th_chnls=False,
                 drop_not_assigned_chnls=False)

    for mouse in const.ALL_MICE:
        if mouse not in const.PARAD_ORDER:
            print(
                f'`{mouse}` not found in constant PARAD_ORDER (MUA_constants.py). Check.'
            )
            sys.exit(1)
        for parad in const.PARAD_ORDER[mouse]:
            if parad not in which_paradigms:
                continue

            if parad not in ['MS', 'DOC1', 'DOC2']:
                peak_stim = 'Deviant'
            else:
                peak_stim = 'C1' if parad == 'MS' else 'Standard'
            fig, axes = plt.subplots(nrows=2, figsize=(10, 3.4))
            fig.subplots_adjust(right=.975,
                                top=.94,
                                left=.02,
                                bottom=.13,
                                hspace=.15)

            for which_ax, ax in enumerate(axes):
                # get the data (lfp or mua time stamp) of the last 10 channels
                if which_ax == 0:
                    lfp_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                          lfp_summary=True,
                                          drop_labels=True)[0]
                    first_ts = lfp_summ.loc[:, 'Peak_neg_ts'].iloc[
                        -10:].sort_values()
                    first_ts_G = lfp_summ.loc['G', 'Peak_neg_ts']
                else:
                    mua_summ = slice_data(data, [mouse], [parad], [peak_stim],
                                          mua_summary=True,
                                          drop_labels=True)[0]
                    first_ts = mua_summ.loc[:, 'AvgTimetoFirstSpike'].iloc[
                        -10:].sort_values()
                    first_ts_G = mua_summ.loc['G', 'AvgTimetoFirstSpike']

                ax.tick_params(left=False, labelleft=False)
                if which_ax == 0:
                    ax.tick_params(bottom=False, labelbottom=False)
                for spine in ax.spines.values():
                    spine.set_visible(False)
                ax.patch.set_facecolor('grey')
                ax.patch.set_alpha(.16)
                ax.hlines((0), -10, 200, color='grey')

                tit = 'Peak negative LFP time stamp' if which_ax == 0 else 'Avg first spike time stamp'
                nan_chnls = [
                    chnl + 1 for chnl in first_ts.index[first_ts.isna()]
                ]
                if nan_chnls:
                    tit = f'{tit}      (NA channels: {nan_chnls})'
                if which_ax == 1:
                    # a time to first spike of 0 means no spike: hide it
                    zero_chnls = first_ts.index[first_ts == 0]
                    first_ts.loc[zero_chnls] = np.nan
                    # .tolist() -> plain ints (numpy>=2 reprs np.int64(..))
                    tit = f'{tit}      (no spike channels: {(zero_chnls + 1).tolist()})'
                ax.set_title(tit, loc='left', pad=2, size=9)

                ax.set_ylim((-1.5, 1.5))
                ax.set_xlim((-10, 200))
                xts = (0, 4, 8, 12, 16, 20, 30, 40, 60, 80, 100, 150, 200)
                ax.set_xticks(xts)
                ax.xaxis.grid(True, which='major')
                if which_ax == 1:
                    ax.set_xlabel('[ms] post stimulus', labelpad=2, size=9)

                # Stagger label heights: a time stamp less than 5 ms after the
                # previous one flips to the other side of the axis and, if that
                # height is already used in the current cluster, moves 0.4
                # further out. Otherwise the height resets to 0.4.
                step = .4
                ycoords = []
                close_coords = []
                for k in range(len(first_ts)):
                    if k and first_ts.iloc[k] - first_ts.iloc[k - 1] < 5:
                        step *= -1
                        if step in close_coords:
                            step = step + .4 if step > 0 else step - .4
                        if len(close_coords) == 0:
                            close_coords.append(step * -1)
                        close_coords.append(step)
                    else:
                        step = .4
                        close_coords = []
                    ycoords.append(step)

                # one line + 1-based channel label per channel
                for y, (chnl, x) in zip(ycoords, first_ts.items()):
                    ax.vlines(x,
                              0,
                              y,
                              linewidth=1,
                              color=const.GENERAL_CMAP['Th'],
                              alpha=.7)
                    ax.annotate(chnl + 1, (x, y),
                                va='center',
                                ha='center',
                                zorder=20)
                # reference: granular layer
                ax.vlines(first_ts_G,
                          -.4,
                          .4,
                          linewidth=1,
                          color=const.GENERAL_CMAP['G'],
                          alpha=.7)
                ax.annotate('G', (first_ts_G, 0),
                            va='center',
                            ha='center',
                            zorder=20,
                            size=12)
            f = f'{const.P["outputPath"]}/{dest_dir_appdx}/{mouse}_{parad}_{peak_stim}_first_ts.png'
            fig.savefig(f)
            # free the figure, otherwise one stays open per mouse/paradigm
            plt.close(fig)
Example #3
0
def firingrate_noise_timeline(dest_dir_appdx='../',
                              fname_postfix='',
                              subtr_noise=False):
    """Check background firing-rate activity along the experimental timeline.

    One row per mouse (the first 4 mice of const.ALL_MICE) and one heatmap
    per paradigm (the first 11 of const.PARAD_ORDER[mouse]), each showing the
    firing rates averaged over the paradigm's stimulus types. The last column
    shows, per paradigm and on average, the proportion of negative firing-rate
    values. `subtr_noise` should either be False or `deviant_alone`,
    `paradigm_wise` (i.e. the noise subtraction method). The figure is saved
    at P["outputPath"]/dest_dir_appdx/firingrate_over_time<fname_postfix>.<PLOT_FORMAT>.
    """
    import copy

    data = fetch()

    ratio = {
        'width_ratios': [.1] * 11 + [.28],
    }
    fig, axes = plt.subplots(4,
                             12,
                             sharex=False,
                             sharey=False,
                             figsize=(15, 13),
                             gridspec_kw=ratio)
    fig.subplots_adjust(hspace=.2,
                        wspace=.03,
                        right=.95,
                        top=.86,
                        left=.1,
                        bottom=.07)

    title = 'noise over experimental paradigm timeline\n(mean over stimulus types)'
    if subtr_noise:
        title += '\n\nNoise subtracted: ' + subtr_noise
    fig.suptitle(title, size=14)
    # Local gamma-adjusted copy of 'gnuplot'. Mutating the registered colormap
    # (plt.cm.get_cmap, removed in matplotlib 3.9) either has no effect on
    # newer matplotlib or changes 'gnuplot' for every other plot.
    cmap = copy.copy(plt.get_cmap('gnuplot'))
    cmap.set_gamma(.8)
    for ax in axes.flatten():
        ax.tick_params(bottom=False,
                       left=False,
                       labelbottom=False,
                       labelleft=False)

    for i, mouse in zip(range(4), const.ALL_MICE):
        mouse_dat = slice_data(data, [mouse],
                               firingrate=True,
                               frate_noise_subtraction=subtr_noise)
        axes[i, 0].set_ylabel(mouse + '\nchannels',
                              size=12,
                              rotation=0,
                              ha='right',
                              va='center')

        neg_frates = []
        for j, parad in zip(range(11), const.PARAD_ORDER[mouse]):
            ax = axes[i, j]
            parad_frates = [
                df for key, df in mouse_dat.items() if parad in key
            ]

            # mean number of negative entries per stimulus type
            neg_frates_counts = [(fr < 0).sum().sum() for fr in parad_frates]
            neg_frates.append(sum(neg_frates_counts) / len(neg_frates_counts))

            frates = (sum(parad_frates) / len(parad_frates)).astype(int)
            im = ax.imshow(frates,
                           cmap=cmap,
                           aspect='auto',
                           extent=[-52.5, 202.5, -.5, 31.5],
                           vmin=0,
                           vmax=500)

            ax.set_title(parad, size=7, pad=2)
            # highlight the deviant-alone paradigms used for subtraction
            if 'DA' in parad and subtr_noise == 'deviant_alone':
                ax.set_title('**' + parad + '**', size=9, pad=2)
                for spine in ax.spines.values():
                    spine.set_color('yellow')
                    spine.set_linewidth(1.5)

            if (i == 0) and (j == 0):
                ax.set_xlim((-52.5, 202.5))
                ax.set_xticks([-50, 0, 80, 160])

                # colorbar and legend
                at = (
                    0.77,
                    .95,
                    .2,
                    .012,
                )
                cb = fig.colorbar(im,
                                  cax=fig.add_axes(at),
                                  orientation='horizontal')
                cb.set_label('Mean Firing Rate in 5ms frame', size=12)
                cb.ax.get_xaxis().set_label_position('top')

        # 1600 = 32 channels x 50 time bins (see the y label below)
        rel_neg_frates = np.array(neg_frates) / 1600
        bar_ax = axes[i, 11]
        bar_ax.bar(range(11), rel_neg_frates)
        bar_ax.bar([12], rel_neg_frates.mean())

        # x axis
        bar_ax.tick_params(labelbottom=True, rotation=35, labelsize=6.5)
        bar_ax.set_xticks(list(range(11)) + [12])
        bar_ax.set_xticklabels(const.PARAD_ORDER[mouse] + ('Average', ),
                               clip_on=False,
                               ha='right',
                               y=.04)

        # y axis
        bar_ax.tick_params(right=True, labelright=True)
        bar_ax.set_ylim((0, 1))
        yticks = np.arange(0, 1.1, .1)
        bar_ax.set_yticks(yticks)
        bar_ax.set_yticklabels(
            [str(yt) if yt in (0, 1) else '' for yt in yticks])
        bar_ax.yaxis.set_label_position("right")
        bar_ax.yaxis.grid(True, alpha=.5)
        bar_ax.set_axisbelow(True)

        if i == 1:
            bar_ax.set_ylabel(
                'Proportion of negative firing rates (of 32 channels x 50 time bins)',
                size=10)

    fig.savefig(
        f'{const.P["outputPath"]}/{dest_dir_appdx}/firingrate_over_time{fname_postfix}.{const.PLOT_FORMAT}'
    )
    plt.close(fig)
Example #4
0
def onset_offset_response(plots_dest_dir_appdx='',
                          csv_dest_dir_appdx='',
                          generate_plots=True,
                          single_channels=True,
                          draw_gmm_fit=True):
    """Start of the onset-offset analysis pipeline: bin negative spike times.

    Iterates all mice, all paradigms and all stimulus types as usual and
    fetches the DataFrame with the negative spike time stamps for each
    existing combination. These are histogrammed over -50 to 200 ms with 2000
    bins (0.125 ms per bin). The 200 bins starting at 0 ms, i.e. 0-25 ms, are
    the input features for the classifier. If the data does not have 200 rows
    (presumably trials), the counts are rescaled to 200 rows and truncated to
    int.

    When `single_channels` is True (default), the 32 channels are not
    collapsed into the previously defined region mapping, and channels
    without any spikes are added as all-zero rows so that every recording
    contributes the same number of samples. The classifier is designed to
    classify single channels, not collapsed regions. Setting this to False
    (regions SG, G, IG, dIG, VPM) is implemented for plotting region
    histograms.

    If `csv_dest_dir_appdx` is truthy, the bin counts are saved at
    P["outputPath"]/csv_dest_dir_appdx/onset_offset_spikebins_<channels|regions>.csv.
    With the default '' (or None), no data is saved. `generate_plots`
    draws the histograms as heatmaps (x axis shown from 0 to 20 ms), saved at
    P["outputPath"]/plots_dest_dir_appdx/onset_offset_<key>_<channels|regions>.png.
    `draw_gmm_fit` overlays the density of a 10-component mixture model
    (`bgmm`) fitted to the 0-20 ms spike times when more than 15 bins contain
    more than one spike.

    Returns:
        DataFrame with one row per `<mouse>-<paradigm>-<stimtype>-<channel or
        region>` label and 200 bin-count columns. The same data is written to
        the CSV.

    NOTE(review): earlier documentation described 0.1 ms bins covering 0-20
    ms. The binning here gives 0.125 ms bins covering 0-25 ms. It is kept
    unchanged because downstream classifiers consume these features; confirm
    which one is intended.
    """

    # get all the available data from the output dir
    if single_channels:
        data = fetch()
        which_region = 'channels'
        plt_spacers = {
            'hspace': 0,
            'right': .97,
            'top': .96,
            'left': .1,
            'bottom': .07
        }
    else:
        data = fetch(collapse_ctx_chnls=True,
                     collapse_th_chnls=True,
                     drop_not_assigned_chnls=True)
        which_region = 'regions'
        plt_spacers = {
            'hspace': 0,
            'right': .97,
            'top': .85,
            'left': .1,
            'bottom': .22
        }

    # due to different base level activity, set heatmap vmax separately
    vmaxs = {
        'mGE82': 20,
        'mGE83': 40,
        'mGE84': 40,
        'mGE85': 40,
    }

    # histogram over -50..200 ms with 2000 bins; bins 400-600 start at 0 ms
    start, stop = 400, 600

    # iter the usual dimensions
    nspikes = 0
    all_spike_bins = []
    for m_id in const.ALL_MICE:
        for parad in const.ALL_PARADIGMS:
            for stim_t in const.ALL_STIMTYPES:
                key = '-'.join([m_id, parad, stim_t])
                # not all combinations of paradigm/stimtype exist
                if key not in data.keys():
                    continue

                # get the negative spike time stamps (a MultiIndex DataFrame)
                spikes = slice_data(data,
                                    m_id,
                                    parad,
                                    stim_t,
                                    neg_spikes=True,
                                    drop_labels=True)[0].sort_index(axis=1)
                nspikes += np.count_nonzero(spikes)
                if not single_channels:
                    spikes = spikes.reindex(['SG', 'G', 'IG', 'dIG', 'VPM'],
                                            axis=1,
                                            level=0)
                regions = spikes.columns.unique(0)

                if generate_plots:
                    # init plot with fitting size (height depends on n channels/ regions)
                    # squeeze=False keeps a 2D array even for a single row
                    fig, axes = plt.subplots(nrows=len(regions),
                                             figsize=(6, .4 * len(regions)),
                                             squeeze=False)
                    axes = axes[:, 0]
                    fig.subplots_adjust(**plt_spacers)
                    for ax in axes:
                        ax.tick_params(bottom=False,
                                       left=False,
                                       labelbottom=False,
                                       labelleft=False)

                    axes[0].set_title(key)
                else:
                    # dummy axes, only used to pair with the regions below
                    axes = [None] * len(regions)

                # Add empty channels as well so that the total number of
                # samples is always consistent. Only for single channels:
                # in region mode the columns are region names, so every
                # channel number would wrongly count as missing.
                if single_channels:
                    for chnl in range(1, 33):
                        if chnl not in regions:
                            all_spike_bins.append(
                                pd.Series(np.zeros(200, dtype=int),
                                          name=f'{key}-{chnl:0>2d}'))

                for region, ax in zip(regions, axes):
                    # get the spike timestamp data sliced to the channel/ region
                    region_spikes = spikes[region].values.flatten()
                    # set 0 to nan -> ignored by np.histogram
                    region_spikes[region_spikes == 0] = np.nan
                    hist = np.histogram(region_spikes,
                                        bins=2000,
                                        range=(-50, 200))

                    # keep the 200 bins starting at 0 ms
                    spike_bins = hist[0][np.newaxis, start:stop]

                    if spikes.shape[0] != 200:
                        # norm spike counts to 200 trials
                        spike_bins = spike_bins / spikes.shape[0]
                        spike_bins = (spike_bins * 200).astype(int)

                    lbl = f'{key}-{region:0>2d}' if single_channels else f'{key}-{region}'
                    all_spike_bins.append(pd.Series(spike_bins[0], name=lbl))

                    if not generate_plots:
                        continue

                    # draw the heatmap, setup axis
                    ax.imshow(spike_bins,
                              aspect='auto',
                              extent=(hist[1][start], hist[1][stop], 0, 1),
                              vmin=0,
                              vmax=vmaxs[m_id])
                    ax.set_ylabel(region, rotation=0, va='center', labelpad=20)
                    ax.set_ylim(0, .4)

                    xt = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
                    ax.set_xticks(xt)
                    ax.set_xlim(xt[0], xt[-1])

                    if draw_gmm_fit and (spike_bins > 1).sum() > 15:
                        model = bgmm(10,
                                     covariance_type='diag',
                                     random_state=1,
                                     mean_prior=(8, ),
                                     covariance_prior=(.1, ),
                                     degrees_of_freedom_prior=10)

                        # fit on the 0-20 ms spike times only (NaN excluded)
                        post_stim = np.logical_and(region_spikes > 0,
                                                   region_spikes < 20)
                        model.fit(region_spikes[post_stim, np.newaxis])

                        x = np.linspace(-50, 200, 2000).reshape(2000, 1)
                        logprob = model.score_samples(x)
                        pdf = np.exp(logprob)
                        ax.plot(x, pdf, '-w', linewidth=.8)

                    # the last plot gets a red stimulus indication bar
                    if region == regions[-1]:
                        ax.tick_params(bottom=True, labelbottom=True)
                        ax.set_xlabel('[ms]')
                        ax.hlines(0,
                                  0,
                                  8,
                                  clip_on=False,
                                  linewidth=6,
                                  color='r')
                        ax.annotate('Stimulus', (2.3, -.6),
                                    color='r',
                                    fontsize=15,
                                    annotation_clip=False)

                if generate_plots:
                    f = (f'{const.P["outputPath"]}/{plots_dest_dir_appdx}/'
                         f'onset_offset_{key}_{which_region}.png')
                    fig.savefig(f)
                    plt.close(fig)

    on_off_scores = pd.concat(all_spike_bins, axis=1).T
    if csv_dest_dir_appdx:
        f = f'{const.P["outputPath"]}/{csv_dest_dir_appdx}/onset_offset_spikebins_{which_region}.csv'
        on_off_scores.to_csv(f)
    print(f'\nSpikes found: {nspikes:,}')
    return on_off_scores
Example #5
0
    def plot_paradigm(parad):
        print(parad)
        if 'C1' in parad or 'C2' in parad:
            dev = parad[-2:]
            std = 'C2' if dev == 'C1' else 'C1'

        to_regions = [
            'collapse_ctx_chnls', 'collapse_th_chnls',
            'drop_not_assigned_chnls'
        ]
        to_regions = dict.fromkeys(to_regions,
                                   True if chnls_to_regions else False)

        if grouping == 'paradigm_wise':
            data = fetch(paradigms=[parad])
        elif grouping == 'whisker_wise':
            if parad != 'MS':
                dev_data = fetch(paradigms=[parad],
                                 stim_types=['Deviant'],
                                 **to_regions)

                std_parad = parad.replace(dev, std)
                stim_types = ['Predeviant', 'UniquePredeviant']
                if all_stimuli:
                    stim_types.extend(
                        ['Standard', 'Postdeviant', 'UniquePostdeviant'])
                std_data = fetch(paradigms=[std_parad],
                                 stim_types=stim_types,
                                 **to_regions)
                if std_parad in ['DOC1', 'DOC2']:
                    std_data = fetch(paradigms=[std_parad],
                                     stim_types=['Standard'],
                                     **to_regions)

                data = {**std_data, **dev_data}
            else:
                data = fetch(paradigms=[parad], **to_regions)

        elif grouping == 'whisker_wise_reduced':
            dev_parads = [
                this_parad for this_parad in const.ALL_PARADIGMS
                if dev in this_parad
            ]
            std_parads = [
                this_parad.replace(dev, std) for this_parad in dev_parads
            ]
            if dev == 'C1':
                order = (
                    'DAC1-Deviant',
                    'O10C2-Predeviant',
                    'O10C1-Deviant',
                    'O25C2-UniquePredeviant',
                    'O25C1-Deviant',
                    'MS-C1',
                    'DOC2-Standard',
                    'DOC1-Deviant',
                )
            else:
                order = ('DAC2-Deviant', 'O10C1-Predeviant', 'O10C2-Deviant',
                         'O25C1-UniquePredeviant', 'O25C2-Deviant', 'MS-C2',
                         'DOC1-Standard', 'DOC2-Deviant')

            dev_data = fetch(paradigms=dev_parads, stim_types=['Deviant'])
            std_data = fetch(paradigms=std_parads,
                             stim_types=['Predeviant', 'UniquePredeviant'])
            std_data.update(
                fetch(paradigms=['DO' + std], stim_types=['Standard']))
            ms_data = fetch(paradigms=['MS'], stim_types=[dev])
            data = {**std_data, **dev_data, **ms_data}
            data = OrderedDict({
                key: data[key]
                for ord_key in order for key in data.keys() if ord_key in key
            })

        if grouping != 'whisker_wise_reduced':
            if parad != 'MS' and not all_stimuli:
                args = {'ncols': 2}
                width = 7
            else:
                args = {'ncols': 4}
                width = 13

        else:
            args = {
                'ncols': 8 + 4,
                'gridspec_kw': {
                    'width_ratios':
                    [.1, .015, .1, .1, .015, .1, .1, .015, .1, .015, .1, .1],
                }
            }
            width = 20
        fig, axes = plt.subplots(4,
                                 **args,
                                 sharex=True,
                                 sharey=True,
                                 figsize=(width, 13))
        fig.subplots_adjust(hspace=.06,
                            wspace=.03,
                            right=.98,
                            top=.86,
                            left=.13,
                            bottom=.07)

        [
            ax.tick_params(bottom=False,
                           left=False,
                           labelbottom=False,
                           labelleft=False) for ax in axes.flatten()
        ]
        if grouping != 'whisker_wise_reduced':
            title = const.PARAD_FULL[
                parad] + '- mean firing rates across 4 mice'
        else:
            title = parad + '- mean firing rates across 4 mice'

        if subtr_noise:
            title += '\nNOISE SUBTRACTED'
        fig.suptitle(title, size=14)
        plt.cm.get_cmap('gnuplot').set_gamma(.8)

        print()
        print()
        print()
        for mouse, i in zip(const.ALL_MICE, range(4)):
            mouse_dat = slice_data(data, [mouse],
                                   firingrate=True,
                                   frate_noise_subtraction=subtr_noise)
            if chnls_to_regions:
                axes[i, 0].tick_params(left=True, labelleft=True)
                lbls = list(mouse_dat.values())[0].index[::-1]
                axes[i, 0].set_yticks(np.arange(6.4 / 2, 32, 6.4))
                axes[i, 0].set_yticklabels(lbls)

            axes[i, 0].set_ylabel(mouse + '\nchannels',
                                  size=12,
                                  rotation=0,
                                  ha='right',
                                  va='center',
                                  x=-50)

            which_ax = 0
            for (key, frates), j in zip(mouse_dat.items(),
                                        range(args['ncols'])):

                im = axes[i, which_ax].imshow(frates,
                                              cmap='gnuplot',
                                              aspect='auto',
                                              extent=[-52.5, 202.5, -.5, 31.5],
                                              vmin=0,
                                              vmax=500)
                axes[i, which_ax].vlines(0,
                                         -.5,
                                         31.5,
                                         color='w',
                                         alpha=.6,
                                         linewidth=1)
                if i == 0:
                    stim_t = key[key.rfind('-') + 1:]
                    col = 'k' if stim_t != 'Deviant' else const.COLORS[
                        'deep_red']
                    if grouping == 'paradigm_wise' or parad == 'MS':
                        axes[i, which_ax].set_title(stim_t, color=col)
                    elif grouping == 'whisker_wise':
                        axes[i, which_ax].set_title(f'{parad[-2:]} {stim_t}',
                                                    color=col)
                    elif grouping == 'whisker_wise_reduced':
                        pard_full = const.PARAD_FULL[
                            key[key.find('-') + 1:key.rfind('-')]][:-3]
                        title = f'{dev} {stim_t}\n{pard_full}'
                        if 'MS' not in key:
                            axes[i, which_ax].set_title(title, color=col)
                        else:
                            axes[i, which_ax].set_title(stim_t +
                                                        '\nMany Standards',
                                                        color=col)

                elif i == 3:
                    axes[i, which_ax].tick_params(bottom=True,
                                                  labelbottom=True)
                    axes[i, which_ax].set_xlabel('ms')
                if (i == 0) and (which_ax == 0):
                    axes[i, which_ax].set_xlim((-52.5, 202.5))
                    axes[i, which_ax].set_xticks([-50, 0, 80, 160])

                    # colorbar and legend
                    at = (
                        .58,
                        .9,
                        2.5 / width,
                        .012,
                    )
                    cb = fig.colorbar(im,
                                      cax=fig.add_axes(at),
                                      orientation='horizontal')
                    cb.set_label('Mean Firing Rate in 5ms frame', size=12)
                    cb.ax.get_xaxis().set_label_position('top')

                which_ax += 1
                if grouping == 'whisker_wise_reduced' and which_ax in [
                        1, 4, 7, 9
                ]:
                    axes[i, which_ax].set_visible(False)
                    which_ax += 1
        return fig