Ejemplo n.º 1
0
def pareto(d, modelnames=None, groups=None, ax=None, **options):
    """
    Plot mean model score against parameter count, one line per model group.

    Parameters
    ----------
    d : pandas.DataFrame
        One column per model. The row labelled 'n_parms' holds each model's
        parameter count; every other row is a score, and those rows are
        averaged per model (column).
    modelnames : optional
        Currently ignored; accepted for API compatibility.
    groups : list of str, optional
        Regular expressions. For each one, the models whose names
        ``re.match`` it are drawn as a single line ordered by parameter count.
        Defaults to one group containing every model (iterating the old
        ``None`` default raised TypeError).
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    **options
        Ignored.

    Returns
    -------
    ax
    """
    if ax is None:
        ax = plt.gca()
    if groups is None:
        groups = ['.*']

    n_parms = d.loc['n_parms']
    mean_score = d.drop(index='n_parms').mean()

    for g in groups:
        r = re.compile(g)
        members = [i for i in mean_score.index if r.match(i)]
        n = np.array([n_parms[i] for i in members])
        xc = np.array([mean_score.loc[i] for i in members])
        # order the points by model complexity so the line reads left-to-right
        order = np.argsort(n)
        ax.plot(n[order], xc[order], '-')

    ax_remove_box(ax)

    return ax
Ejemplo n.º 2
0
def plot_nl_io(module=None, xbounds=None, ax=None):
    """
    Draw the input/output curve of a NEMS module's function.

    ``module['fn']`` ("package.module.name") is resolved to the private
    implementation ``_name`` in ``package.module``. That function is
    evaluated on 100 evenly spaced inputs spanning ``xbounds`` (default
    [-1, 1]), with ``module['phi']`` passed as keyword arguments. When the
    first phi entry has more than one element, the input row is tiled once
    per channel.

    Returns the axes drawn on, or None when `module` is None.
    """
    if module is None:
        return None

    bounds = np.array([-1, 1]) if xbounds is None else xbounds
    if not ax:
        ax = plt.gca()
    else:
        plt.sca(ax)

    mod_path, fn_name = module['fn'].rsplit('.', 1)
    nl_fn = getattr(importlib.import_module(mod_path), '_' + fn_name)

    phi = module['phi']
    n_chans = len(phi[list(phi.keys())[0]])
    x = np.linspace(bounds[0], bounds[1], 100)
    if n_chans > 1:
        x = np.tile(x, (n_chans, 1))
    y = nl_fn(x, **phi)
    plt.plot(x.T, y.T)

    ax_remove_box(ax)
    return ax
Ejemplo n.º 3
0
def perf_per_cell(modelspec, channels=0, ax=None, **options):
    """
    Plot per-cell prediction accuracy stored in ``modelspec.meta``.

    Draws r_test (solid black), r_fit (dashed blue) and r_floor (dash-dot
    gray) against cell index. It then marks the selected cell and annotates
    both the mean r_test and the selected cell's r_test.

    Parameters
    ----------
    modelspec : NEMS modelspec
        ``meta`` must contain 'r_test', 'r_fit', 'r_floor' and either
        'cellids' or 'cellid'.
    channels : int
        Index of the cell to highlight.
    ax : matplotlib Axes, optional
        Defaults to ``plt.gca()``.

    Returns
    -------
    ax
    """
    if ax is None:
        ax = plt.gca()

    meta = modelspec.meta
    cellids = meta.get('cellids', [meta.get('cellid', None)])
    cellcount = len(cellids)

    r_test = np.asarray(meta['r_test'])
    # r_test may be a 1-D vector or a (cells, k) array. The old code indexed
    # it as r_test[channels] in one place and r_test[channels, 0] in another;
    # the second form fails on 1-D input. Use the first column when 2-D.
    r_cells = r_test[:, 0] if r_test.ndim > 1 else r_test

    ax.plot(meta['r_test'], color='black')
    ax.plot(meta['r_fit'], ls='--', color='blue')
    ax.plot(meta['r_floor'], ls='-.', color='gray')

    ax.plot(channels, r_cells[channels], 'o')
    ax.set_xticks(np.arange(cellcount))
    ax.set_xticklabels(cellids)

    # shift the mean label right when the highlighted cell is not cell 0
    hoffset = 0 if channels == 0 else 1
    ax.text(0.1 + hoffset, 0.1,
            'mean r={:.3f}'.format(np.mean(r_test)))
    ax.text(channels + 0.1, r_cells[channels],
            '{:.3f}'.format(r_cells[channels]))

    ax_remove_box(ax)

    return ax
Ejemplo n.º 4
0
Archivo: state.py Proyecto: nadoss/NEMS
def state_gain_plot(modelspec, ax=None, clim=None, title=None):
    """
    Plot the state-dependent baseline (d), gain (g) and MI per state channel.

    d and g come from the last module in `modelspec` whose 'fn' contains
    'state_dc_gain' (first output row of its phi) or 'state_dexp' (phi used
    as-is). MI ('state_mod') and the x tick labels ('state_chans') come from
    ``modelspec[0]['meta']``.

    Parameters
    ----------
    modelspec : NEMS modelspec
    ax : matplotlib Axes, optional
        Defaults to the current axes.
    clim : unused
    title : str, optional

    Returns
    -------
    ax

    Raises
    ------
    ValueError
        If no state_dc_gain / state_dexp module is present. Previously this
        surfaced as an UnboundLocalError.
    """
    g = d = None
    for m in modelspec:
        if 'state_dc_gain' in m['fn']:
            g = m['phi']['g'][0, :]
            d = m['phi']['d'][0, :]
        elif 'state_dexp' in m['fn']:
            # hack, sdexp currently only supports single output channel
            g = m['phi']['g']
            d = m['phi']['d']
    if g is None:
        raise ValueError('modelspec has no state_dc_gain or state_dexp module')

    MI = modelspec[0]['meta']['state_mod']
    state_chans = modelspec[0]['meta']['state_chans']

    # Resolve ax so ax_remove_box always receives real axes (it was passed
    # None when ax was omitted).
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    plt.plot(d)
    plt.plot(g)
    plt.plot(MI)
    plt.xticks(np.arange(len(state_chans)), state_chans, fontsize=6)
    plt.legend(('baseline', 'gain', 'MI'))
    # dashed zero reference line
    plt.plot(np.arange(len(state_chans)),
             np.zeros(len(state_chans)),
             'k--',
             linewidth=0.5)
    if title:
        plt.title(title)

    ax_remove_box(ax)

    return ax
Ejemplo n.º 5
0
    def update_figure(self):
        """
        Redraw this canvas for the parent's current time window.

        Reads ``self.parent.start_time`` / ``stop_time`` (seconds) and slices
        the matching bins of ``self.recording[self.signal]``. Channels named
        'baseline' are dropped. Point-process signals are drawn as a grey
        raster with the y axis hidden, and tiled or 'stim' signals as an
        image. Anything else is drawn as lines against time, with a legend
        when there are several named channels. Image-style plots get blank
        x tick labels.
        """
        p = self.parent
        sig = self.recording[self.signal]

        c_count = sig.shape[0]
        fs = sig.fs
        start_bin = int(p.start_time * fs)
        stop_bin = int(p.stop_time * fs)

        # Skip some channels and get names. ``chans`` can be None, so test it
        # before slicing; slicing None used to raise TypeError here.
        skip_channels = ['baseline']
        if sig.chans is not None:
            all_names = sig.chans[:c_count]
            # dtype=bool keeps an empty mask a valid boolean index
            keep = np.array([(n not in skip_channels) for n in all_names],
                            dtype=bool)
            channel_names = [all_names[i] for i in range(c_count) if keep[i]]
        else:
            keep = np.ones(c_count, dtype=bool)
            channel_names = None

        d = sig.as_continuous()[keep, start_bin:stop_bin]

        point = isinstance(sig, nems.signal.PointProcess)
        tiled = (isinstance(sig, nems.signal.TiledSignal)
                 or 'stim' in sig.name)
        if point:
            self.axes.imshow(d,
                             aspect='auto',
                             cmap='Greys',
                             interpolation='nearest',
                             origin='lower')
            self.axes.get_yaxis().set_visible(False)
        elif tiled:
            self.axes.imshow(d, aspect='auto', origin='lower')
        else:
            t = np.linspace(p.start_time, p.stop_time, d.shape[1])
            self.axes.plot(t, d.T)
            if (channel_names is not None) and len(channel_names) > 1:
                self.axes.legend(channel_names, frameon=False)

        self.axes.set_ylabel(self.signal)
        self.axes.autoscale(enable=True, axis='x', tight=True)
        ax_remove_box(self.axes)
        self.draw()

        if point or tiled:
            # image x axes are in bins, which are not meaningful to the user
            tick_labels = self.axes.get_xticklabels()
            self.axes.set_xticklabels([''] * len(tick_labels))
            self.draw()
Ejemplo n.º 6
0
def plot_timeseries(times, values, xlabel='Time', ylabel='Value', legend=None,
                    linestyle='-', linewidth=1,
                    ax=None, title=None, colors=None, **options):
    '''
    Plot one line per (time vector, value array) pair.

    Lines are auto-colored by matplotlib unless `colors` is given. A 2-D
    value array is drawn as one line per column. Non-finite samples are
    omitted from each line, and the x limits span the finite samples of
    every line.

    times : list of vectors
    values : list of vectors or (time, n) arrays
    xlabel : str
    ylabel : str
    legend : list of strings
    linestyle, linewidth : pass-through options to ax.plot()
    ax : matplotlib Axes, default current axes
    title : str
    colors : list of colors, reused cyclically if shorter than the traces

    Returns (ax, h), where h is the list of Line2D handles created.
    '''
    if ax is None:
        ax = plt.gca()

    opt = {}
    h = []
    mintime = np.inf
    # Start at -inf, not 0, so all-negative time axes get the right upper limit.
    maxtime = -np.inf
    for cc, (t, v) in enumerate(zip(times, values)):
        if colors is not None:
            opt = {'color': colors[cc % len(colors)]}  # Wraparound to avoid crash
        if v.ndim == 1:
            v = v[:, np.newaxis]
        for idx in range(v.shape[1]):
            gidx = np.isfinite(v[:, idx])
            h += ax.plot(t[gidx], v[gidx, idx], linestyle=linestyle,
                         linewidth=linewidth, **opt)
            # Track the time range over every column, not just the last one.
            if gidx.any():
                mintime = min(mintime, np.min(t[gidx]))
                maxtime = max(maxtime, np.max(t[gidx]))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Skip when nothing finite was plotted (the limits would be [inf, -inf]).
    if np.isfinite(mintime) and np.isfinite(maxtime):
        ax.set_xlim([mintime, maxtime])
    if legend:
        ax.legend(legend)
    if title:
        ax.set_title(title)

    ax_remove_box(ax)

    return ax, h
Ejemplo n.º 7
0
def plot_spectrogram(array, fs=None, ax=None, title=None, time_offset=0,
                     cmap=None, clim=None, extent=True, time_range=None, **options):
    """
    Show a channel-by-time array as an image.

    Parameters
    ----------
    array : 2-D array
        Rows are drawn along y (channels), columns along x (time bins).
    fs : float, optional
        Sampling rate. When given, `time_range` is interpreted in seconds and
        the x extent is labelled in seconds shifted by `time_offset`.
    ax : matplotlib Axes, optional
        Defaults to ``plt.gca()`` when falsy.
    title : str, optional
    time_offset : float
        Subtracted from the x coordinates when `fs` is given.
    cmap, clim : passed to ``imshow``.
    extent : bool
        When truthy (and `time_range` is None), the image is placed in
        time/channel coordinates. Otherwise raw bin indices are shown.
    time_range : (start, stop), optional
        Show only these columns (bins, or seconds when `fs` is given),
        without an extent.

    Returns
    -------
    ax
    """
    ax = ax if ax else plt.gca()

    imshow_kw = dict(origin='lower', interpolation='none', aspect='auto',
                     cmap=cmap, clim=clim)

    if time_range is not None:
        if fs is not None:
            time_range = np.round(np.array(time_range) * fs).astype(int)
        first_bin, last_bin = time_range[0], time_range[1]
        log.debug('bin range: {}-{}'.format(first_bin, last_bin))
        ax.imshow(array[:, np.arange(first_bin, last_bin)], **imshow_kw)
    elif extent:
        n_chans, n_bins = array.shape[0], array.shape[1]
        times = np.arange(0, n_bins)
        if fs is not None:
            times = times / fs - time_offset
        bounds = [times[0], times[-1], 1, n_chans]
        # a single-channel image would otherwise have zero height
        if bounds[2] == bounds[3]:
            bounds[3] = 2
        ax.imshow(array, extent=bounds, **imshow_kw)
    else:
        # fallback: plot in bin units when seconds are not wanted/available
        ax.imshow(array, **imshow_kw)

    ax.margins(x=0)
    ax.set_xlabel('Time')
    ax.set_ylabel('Channel')
    if title:
        ax.set_title(title)

    ax_remove_box(ax)
    return ax
Ejemplo n.º 8
0
def plot_timeseries(times,
                    values,
                    xlabel='Time',
                    ylabel='Value',
                    legend=None,
                    linestyle='-',
                    linewidth=1,
                    ax=None,
                    title=None,
                    colors=None):
    '''
    Plot one line per (time vector, value vector) pair.

    Lines are auto-colored by matplotlib unless `colors` is given.

    times : list of vectors
    values : list of vectors
    xlabel : str
    ylabel : str
    legend : list of strings
    linestyle, linewidth : pass-through options to plt.plot()
    ax : matplotlib Axes, made current if given; default current axes
    title : str
    colors : list of colors, reused cyclically if shorter than the traces

    Returns ax.
    '''
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    opt = {}
    for cc, (t, v) in enumerate(zip(times, values)):
        if colors is not None:
            # wrap around instead of raising IndexError when there are
            # fewer colors than traces
            opt = {'color': colors[cc % len(colors)]}
        plt.plot(t, v, linestyle=linestyle, linewidth=linewidth, **opt)

    plt.margins(x=0)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    # np.min(times) fails on a ragged list (vectors of different lengths) and
    # yields nan when any time is nan, so pool the finite times explicitly.
    all_t = [np.ravel(t) for t in times]
    all_t = np.concatenate(all_t) if all_t else np.array([])
    all_t = all_t[np.isfinite(all_t)]
    if all_t.size:
        ax.set_xlim([np.min(all_t), np.max(all_t)])

    if legend:
        plt.legend(legend)
    if title:
        plt.title(title, fontsize=8)

    ax_remove_box(ax)

    return ax
Ejemplo n.º 9
0
def state_gain_plot(modelspec, ax=None, colors=None, clim=None, title=None, **options):
    """
    Plot the state module's baseline (d) and gain (g) parameters.

    Values are ``modelspec.phi_mean`` and error bars ``modelspec.phi_sem`` of
    the module found by ``find_module('state', modelspec)``.

    - When d has more than one row, each column cc is drawn as d[:, cc]
      (dashed) and g[:, cc] (solid), using ``colors`` cyclically if given.
    - Otherwise the single row of d (blue) and g (red) is drawn with error
      bars. Channels whose |g/sem| (else |d/sem|) exceeds 2 are labelled with
      their name from ``modelspec[0]['meta']['state_chans']``.

    clim and options are unused. Returns the axes drawn on.
    """
    state_idx = find_module('state', modelspec)
    g = modelspec.phi_mean[state_idx]['g']
    d = modelspec.phi_mean[state_idx]['d']
    ge = modelspec.phi_sem[state_idx]['g']
    de = modelspec.phi_sem[state_idx]['d']

    state_chans = modelspec[0]['meta']['state_chans']
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    if d.shape[0] > 1:
        opt = {}
        for cc in range(d.shape[1]):
            if colors is not None:
                # wrap around instead of raising IndexError on short lists
                opt = {'color': colors[cc % len(colors)]}
            plt.plot(d[:, cc], '--', **opt)
            plt.plot(g[:, cc], **opt)
    else:
        plt.errorbar(np.arange(len(d[0, :])), d[0, :], de[0, :], color='blue')
        plt.errorbar(np.arange(len(g[0, :])), g[0, :], ge[0, :], color='red')
        # a zero sem gives inf/nan ratios; silence the divide warnings
        with np.errstate(divide='ignore', invalid='ignore'):
            dz = np.abs(d[0, :] / de[0, :])
            gz = np.abs(g[0, :] / ge[0, :])
        for i in range(len(gz)):
            if gz[i] > 2:
                ax.text(i, g[0, i] + np.sign(g[0, i]) * ge[0, i], state_chans[i],
                        color='red', ha='center', fontsize=6)
            elif dz[i] > 2:
                ax.text(i, d[0, i] + np.sign(d[0, i]) * de[0, i], state_chans[i],
                        color='blue', ha='center', fontsize=6)

    plt.legend(('baseline', 'gain'), frameon=False)
    # dashed zero reference line
    plt.plot(np.arange(len(state_chans)), np.zeros(len(state_chans)), 'k--',
             linewidth=0.5)
    if title:
        plt.title(title)

    ax_remove_box(ax)

    return ax
Ejemplo n.º 10
0
    def update_figure(self):
        """
        Redraw the signal for the parent's current time window.

        Uses the state cached by the constructor: self.fs, self.keep,
        self.point, self.tiled, self.channel_names, self.ymin and self.ymax.

        - Point-process signals are drawn as a grey raster with the y axis
          hidden.
        - Tiled signals are drawn as an image.
        - Anything else is drawn as one line per kept channel against time.
          There is a legend when more than one channel name exists, and the
          view is limited to the window and to [self.ymin, self.ymax].

        Image-style plots get blank x tick labels.
        """
        win = self.parent
        first = int(win.start_time * self.fs)
        last = int(win.stop_time * self.fs)
        data = self.recording[self.signal].as_continuous()[self.keep, first:last]
        axes = self.axes

        if self.point:
            axes.imshow(data, aspect='auto', cmap='Greys',
                        interpolation='nearest', origin='lower')
            axes.get_yaxis().set_visible(False)
        elif self.tiled:
            axes.imshow(data, aspect='auto', origin='lower')
        else:
            t = np.linspace(win.start_time, win.stop_time, data.shape[1])
            axes.plot(t, data.T)
            names = self.channel_names
            if names is not None and len(names) > 1:
                axes.legend(names, frameon=False)
            axes.set_xlim(win.start_time, win.stop_time)
            axes.set_ylim(ymin=self.ymin, ymax=self.ymax)

        axes.set_ylabel(self.signal)
        ax_remove_box(axes)
        self.draw()

        if self.point or self.tiled:
            labels = axes.get_xticklabels()
            axes.set_xticklabels([''] * len(labels))
            self.draw()
Ejemplo n.º 11
0
    def __init__(self,
                 recording=None,
                 signal='stim',
                 parent=None,
                 *args,
                 **kwargs):
        """
        Canvas that plots one signal of a recording.

        If the recording contains a 'mask' signal, the masked recording is
        used. The constructor caches state for later redraws: fs, max_time
        (full duration in seconds), the point/tiled plot-style flags, the y
        limits (line plots only), the boolean channel mask `keep` (channels
        not named 'baseline') and the kept `channel_names`. It then draws the
        whole signal with the x axis in bins, limited to the parent's
        [start_time, stop_time] window.
        """
        MyMplCanvas.__init__(self, *args, **kwargs)
        if 'mask' in recording.signals:
            self.recording = recording.apply_mask()
        else:
            self.recording = recording
        self.signal = signal
        self.signal_obj = self.recording[self.signal]
        self.fs = self.signal_obj.fs
        self.parent = parent
        print("creating canvas: {}".format(signal))

        sig_array = self.signal_obj.as_continuous()
        # Full signal duration in seconds.
        # NOTE(review): an earlier comment spoke of chopping off trailing NaNs
        # here, but nothing is trimmed.
        self.max_time = sig_array.shape[-1] / self.fs

        self.point = isinstance(self.signal_obj, nems.signal.PointProcess)
        self.tiled = (isinstance(self.signal_obj, nems.signal.TiledSignal)
                      or 'stim' in self.signal_obj.name
                      or 'contrast' in self.signal_obj.name)

        if (not self.point) and (not self.tiled):
            # leave 25% headroom above the data
            self.ymax = np.nanmax(sig_array) * 1.25
            self.ymin = min(0, np.nanmin(sig_array) * 1.25)

        # Skip 'baseline' channels. chans may be None, in which case every
        # channel is kept under a blank name (so channel_names is never None).
        c_count = self.signal_obj.shape[0]
        chans = self.signal_obj.chans
        channel_names = [''] * c_count if chans is None else chans[:c_count]
        skip_channels = ['baseline']
        # dtype=bool keeps an empty mask a valid boolean index
        self.keep = np.array([(n not in skip_channels) for n in channel_names],
                             dtype=bool)
        self.channel_names = [
            channel_names[i] for i in range(c_count) if self.keep[i]
        ]

        p = self.parent
        d = sig_array[self.keep, :]

        if self.point:
            self.axes.imshow(d,
                             aspect='auto',
                             cmap='Greys',
                             interpolation='nearest',
                             origin='lower')
            self.axes.get_yaxis().set_visible(False)
        elif self.tiled:
            self.axes.imshow(d, aspect='auto', origin='lower')
        else:
            self.axes.plot(d.T)
            if len(self.channel_names) > 1:
                self.axes.legend(self.channel_names, frameon=False)
            self.axes.set_ylim(ymin=self.ymin, ymax=self.ymax)

        # data were plotted against bin index, so the window is in bins
        self.axes.set_xlim(p.start_time * self.fs, p.stop_time * self.fs)
        self.axes.set_ylabel(self.signal)
        ax_remove_box(self.axes)
        self.draw()

        if self.point or self.tiled:
            tick_labels = self.axes.get_xticklabels()
            self.axes.set_xticklabels([''] * len(tick_labels))
            self.draw()
        else:
            # The x axis is in bins, so the bins->seconds formatter belongs on
            # the x axis. It was attached to the y axis, which is why the
            # conversion "was not working".
            fmt = tkr.FuncFormatter(self.seconds_formatter())
            self.axes.xaxis.set_major_formatter(fmt)
            self.draw()
Ejemplo n.º 12
0
Archivo: state.py Proyecto: nadoss/NEMS
def state_vars_psth_all(rec,
                        epoch,
                        psth_name='resp',
                        psth_name2='pred',
                        state_sig='state_raw',
                        ax=None,
                        colors=None,
                        channel=None,
                        decimate_by=1,
                        files_only=False,
                        modelspec=None):
    """
    Plot low- vs high-state PSTHs for every state channel, side by side.

    For each channel of ``rec['state'].chans``, ``state_mod_split`` splits the
    `epoch` PSTH of `psth_name` (and of `psth_name2`, if given) into low and
    high state halves. When `files_only` is set, only FILE*/ACTIVE*/PASSIVE*
    channels are used. The halves of all channels are concatenated along time
    with NaN gaps between them. They are drawn as lines (dashed for
    `psth_name2`) with shaded bands from the SEM of resp - pred. Each segment
    gets a bar labelled with its channel name and, when `modelspec` is given,
    its value from ``modelspec[0]['meta']['state_mod']``.

    Parameters
    ----------
    rec : recording containing 'resp', 'pred', 'state' and `psth_name`.
    epoch : epoch name passed to state_mod_split.
    psth_name, psth_name2 : signals to split; psth_name2 may be None.
    state_sig : passed to state_mod_split.
    ax : matplotlib Axes, optional (default: current axes).
    colors : unused.
    channel : passed to state_mod_split.
    decimate_by : int; if > 1, downsample the traces in time by this factor.
    files_only : restrict to file/behavior-condition state channels.
    modelspec : optional NEMS modelspec used to annotate MI values.

    Returns
    -------
    ax
    """
    # TODO: Does using epochs make sense for these?
    if ax is not None:
        plt.sca(ax)
    else:
        # ax is used directly below; None used to raise AttributeError
        ax = plt.gca()

    # prediction error (resp - pred); only its SEM is used, for the bands
    newrec = rec.copy()
    fn = lambda x: x - newrec['pred']._data
    newrec['error'] = rec['resp'].transform(fn, 'error')

    fs = rec[psth_name].fs

    # Silence durations (s), used to trim each segment's label bar.
    # Default to 0 when the epochs are absent (mean of empty gave nan).
    d = rec[psth_name].get_epoch_bounds('PreStimSilence')
    if d.size > 0:
        PreStimSilence = np.mean(np.diff(d)) - 0.5 / fs
    else:
        PreStimSilence = 0
    d = rec[psth_name].get_epoch_bounds('PostStimSilence')
    dd = np.diff(d) if d.size > 0 else np.array([])
    dd = dd[dd > 0]
    if dd.size > 0:
        PostStimSilence = np.min(dd) - 0.5 / fs
    else:
        PostStimSilence = 0

    state_chan_list = rec['state'].chans
    # traces are (time, 1) columns grown along axis 0
    low = np.zeros([0, 1])
    high = np.zeros([0, 1])
    lowE = np.zeros([0, 1])
    highE = np.zeros([0, 1])
    low2 = np.zeros([0, 1])
    high2 = np.zeros([0, 1])
    _high2 = None
    limitset = []
    if files_only:  #state_chan_list =['a','p','PASSIVE_1']
        state_chan_list = [
            s for s in state_chan_list
            if s.startswith(('FILE', 'ACTIVE', 'PASSIVE'))
        ]

    for state_chan in state_chan_list:

        _low, _high = state_mod_split(rec,
                                      epoch=epoch,
                                      psth_name=psth_name,
                                      channel=channel,
                                      state_sig=state_sig,
                                      state_chan=state_chan)
        _lowE, _highE = state_mod_split(newrec,
                                        epoch=epoch,
                                        psth_name='error',
                                        channel=channel,
                                        state_sig=state_sig,
                                        state_chan=state_chan,
                                        stat=scipy.stats.sem)
        if psth_name2 is not None:
            _low2, _high2 = state_mod_split(rec,
                                            epoch=epoch,
                                            psth_name=psth_name2,
                                            channel=channel,
                                            state_sig=state_sig,
                                            state_chan=state_chan)

        # NaN gap of one tenth of the segment's duration between segments
        gapdur = _low.shape[0] / fs / 10
        gap = np.ones([int(np.ceil(fs * gapdur)), 1]) * np.nan
        pgap = np.ones(_low.shape) * np.nan
        if files_only:
            if state_chan == state_chan_list[0]:
                # first channel: also append its low-state trace as a
                # standalone segment (high side blank)
                low = np.concatenate((low, gap, _low, gap), axis=0)
                lowE = np.concatenate((lowE, gap, _lowE, gap), axis=0)
                high = np.concatenate((high, gap, pgap, gap), axis=0)
                highE = np.concatenate((highE, gap, pgap, gap), axis=0)
                if psth_name2 is not None:
                    low2 = np.concatenate((low2, gap, _low2, gap), axis=0)
                    high2 = np.concatenate((high2, gap, pgap, gap), axis=0)
            current_start = high.shape[0] / fs + gapdur
            if state_chan.startswith('ACTIVE'):
                _low = pgap
                _low2 = pgap
            else:
                _low = _high.copy()
                _high = pgap
                # _high2 is None when psth_name2 is None (used to crash here)
                if psth_name2 is not None:
                    _low2 = _high2.copy()
                    _high2 = pgap

        else:
            current_start = low.shape[0] / fs + gapdur

        low = np.concatenate((low, gap, _low, gap), axis=0)
        high = np.concatenate((high, gap, _high, gap), axis=0)
        lowE = np.concatenate((lowE, gap, _lowE, gap), axis=0)
        highE = np.concatenate((highE, gap, _highE, gap), axis=0)

        if psth_name2 is not None:
            low2 = np.concatenate((low2, gap, _low2, gap), axis=0)
            high2 = np.concatenate((high2, gap, _high2, gap), axis=0)

        limitset += [[
            current_start + PreStimSilence,
            current_start + _low.shape[0] / fs - PostStimSilence
        ]]

    if decimate_by > 1:
        def _decimate(x):
            # Time runs along axis 0; axis=1 was the length-1 channel axis.
            # An IIR filter would spread the NaN gaps over the whole trace,
            # so filter with gaps zero-filled and re-insert NaN at the
            # decimated gap positions.
            nan_mask = np.isnan(x)
            y = scipy.signal.decimate(np.where(nan_mask, 0, x),
                                      q=decimate_by, axis=0)
            y[nan_mask[::decimate_by][:y.shape[0]]] = np.nan
            return y

        low = _decimate(low)
        high = _decimate(high)
        # error traces must stay aligned with the means for fill_between
        lowE = _decimate(lowE)
        highE = _decimate(highE)
        if psth_name2 is not None:
            low2 = _decimate(low2)
            high2 = _decimate(high2)
        fs /= decimate_by

    tt = np.arange(high.shape[0]) / fs
    ax.fill_between(tt,
                    low[:, 0] - lowE[:, 0],
                    low[:, 0] + lowE[:, 0],
                    color=fill_colors['passive'])
    l1, = ax.plot(tt, low, ls='-', lw=1, color=line_colors['passive'])
    ax.fill_between(tt,
                    high[:, 0] - highE[:, 0],
                    high[:, 0] + highE[:, 0],
                    color=fill_colors['active'])
    l2, = ax.plot(tt, high, ls='-', lw=1, color=line_colors['active'])
    if psth_name2 is not None:
        ax.plot(tt, low2, ls='--', lw=1, color=line_colors['passive'])
        ax.plot(tt, high2, ls='--', lw=1, color=line_colors['active'])

    if not files_only:
        ax.legend((l1, l2), ('Lo', 'Hi'))

    ax.set_ylabel('sp/sec')
    ylim = ax.get_ylim()

    # label bar + name above each segment
    for lims, s in zip(limitset, state_chan_list):
        if modelspec is not None:
            sc = modelspec[0]['meta']['state_chans']
            mi = modelspec[0]['meta']['state_mod']
            sn = "{} ({:.2f})".format(s, mi[sc.index(s)])
        else:
            sn = s
        ax.plot(lims, [ylim[1], ylim[1]], 'k-', linewidth=2)
        ax.text(np.mean(lims), ylim[1], sn, ha='center', va='bottom', fontsize=6)
    # headroom for the labels
    ax.set_ylim([ylim[0], ylim[1] * 1.1])
    ax_remove_box(ax)

    return ax
Ejemplo n.º 13
0
    def update_figure(self):
        """
        Redraw the epoch timeline for the parent's current time window.

        Only epochs lying entirely inside [start_time, stop_time) are shown.
        They are grouped by name prefix: the text before the first '_' and
        ',', stripped and lower-cased. Names shorter than 5 characters go to
        group 'X', and PreStimSilence/PostStimSilence/Reference/Target
        epochs are skipped. Each group occupies one row drawn in its own
        color. Each epoch is a horizontal line labelled with its name. A
        following same-name epoch that overlaps is merged into the current
        one.
        """
        self.axes.cla()

        epochs = self.recording.epochs
        p = self.parent
        valid_epochs = epochs[(epochs['start'] >= p.start_time)
                              & (epochs['end'] < p.stop_time)]
        if valid_epochs.size == 0:
            print('no valid epochs')
            return

        # On each refresh, keep the same keys but reform the lists of indices.
        self.epoch_groups = {k: [] for k in self.epoch_groups}
        skip_prefixes = ('prestimsilence', 'poststimsilence', 'reference',
                         'target')
        for i, n in valid_epochs['name'].items():
            prefix = n.split('_')[0].split(',')[0].strip(' ').lower()
            if len(n) < 5:
                prefix = 'X'
            if prefix in skip_prefixes:
                continue
            self.epoch_groups.setdefault(prefix, []).append(i)

        colors = [
            'Red', 'Orange', 'Green', 'LightBlue', 'DarkBlue', 'Purple',
            'Pink', 'Black', 'Gray'
        ]
        for row, g in enumerate(self.epoch_groups):
            color = colors[row % len(colors)]
            merged = set()
            for j in self.epoch_groups[g]:
                if j in merged:
                    # already folded into the previous same-name epoch
                    continue
                n = valid_epochs['name'][j]
                s = valid_epochs['start'][j]
                e = valid_epochs['end'][j]

                # If the next epoch has the same name and starts before this
                # one ends, extend this one to cover it and skip it. (The old
                # `j += 1` inside the for loop never skipped anything, and
                # `e = e2` could shrink the epoch.)
                try:
                    n2 = valid_epochs['name'][j + 1]
                    s2 = valid_epochs['start'][j + 1]
                    e2 = valid_epochs['end'][j + 1]
                except KeyError:
                    # j is already the last epoch in the list
                    n2 = None
                if n2 == n and s2 < e:
                    e = max(e, e2)
                    merged.add(j + 1)

                # Don't plot outside the plot limits (clip both ends).
                s = max(s, p.start_time)
                e = min(e, p.stop_time)

                self.axes.plot([s, e], [row, row], '-', color=color)
                self.axes.text(s,
                               row,
                               n,
                               va='bottom',
                               fontsize='small',
                               color=color)

        # One row per group. This used to read a leftover loop variable,
        # which was a DataFrame index label when there were no groups.
        n_rows = max(len(self.epoch_groups), 1)
        self.axes.set_xlim([p.start_time, p.stop_time])
        self.axes.set_ylim([-0.5, n_rows - 0.5])
        ax_remove_box(self.axes)
        self.draw()
Ejemplo n.º 14
0
    def update_figure(self):
        """
        Redraw the epoch timeline for the parent's current time window.

        Only epochs lying entirely inside [start_time, stop_time) are shown.
        They are grouped by the part of their name before the first '_'.
        Each group gets one row, and each epoch is a black horizontal line
        labelled with its name. A following same-name epoch that overlaps is
        merged into the current one. X tick labels are blanked.
        """
        p = self.parent
        # Axes.hold() was removed in Matplotlib 3.0. Clearing up front has the
        # effect the trailing hold(False) had: each refresh starts from empty
        # axes instead of stacking on the previous drawing.
        self.axes.cla()

        epochs = self.recording.epochs

        valid_epochs = epochs[(epochs['start'] >= p.start_time)
                              & (epochs['end'] < p.stop_time)]

        # On each refresh, keep the same keys but reform the lists of indices.
        # Select the 'name' column explicitly rather than unpacking
        # itertuples(), which assumed a fixed (name, start, end) column order.
        self.epoch_groups = {k: [] for k in self.epoch_groups}
        for i, n in valid_epochs['name'].items():
            prefix = n.split('_')[0]
            self.epoch_groups.setdefault(prefix, []).append(i)

        for row, g in enumerate(self.epoch_groups):
            merged = set()
            for j in self.epoch_groups[g]:
                if j in merged:
                    # already folded into the previous same-name epoch
                    continue
                n = valid_epochs['name'][j]
                s = valid_epochs['start'][j]
                e = valid_epochs['end'][j]

                # If the next epoch has the same name and starts before this
                # one ends, extend this one to cover it and skip it.
                # (n2/s2/e2 used to be unset, or stale from the previous
                # iteration, when j was the last epoch.)
                try:
                    n2 = valid_epochs['name'][j + 1]
                    s2 = valid_epochs['start'][j + 1]
                    e2 = valid_epochs['end'][j + 1]
                except KeyError:
                    # j is already the last epoch in the list
                    n2 = None
                if n2 == n and s2 < e:
                    e = max(e, e2)
                    merged.add(j + 1)

                # Don't plot outside the plot limits. This used to assign
                # `p = p.stop_time`, which broke p.start_time afterwards.
                s = max(s, p.start_time)
                e = min(e, p.stop_time)

                self.axes.plot([s, e], [row, row], 'k-')
                self.axes.text(s, row, n, va='bottom')

        # one row per group; never read a leftover loop variable
        n_rows = max(len(self.epoch_groups), 1)
        self.axes.set_xlim([p.start_time, p.stop_time])
        self.axes.set_ylim([-0.5, n_rows - 0.5])
        self.axes.set_ylabel('epochs')
        ax_remove_box(self.axes)
        self.draw()

        tick_labels = self.axes.get_xticklabels()
        self.axes.set_xticklabels([''] * len(tick_labels))
        self.draw()
Ejemplo n.º 15
0
def plot_scatter(sig1,
                 sig2,
                 ax=None,
                 title=None,
                 smoothing_bins=False,
                 channels=0,
                 xlabel=None,
                 ylabel=None,
                 legend=True,
                 text=None,
                 force_square=False,
                 module=None,
                 **options):
    '''
    Scatter one channel of sig2 (y axis) against sig1 (x axis).

    Row `channels` of sig2 is plotted. sig1 uses the same row when it has
    more than one channel, otherwise its only channel. Time bins where
    either value is non-finite are dropped.

    Optional arguments:
        ax
        smoothing_bins: int
            If set, split the x range into this many equal-width bins and
            plot the mean (x, y) of each non-empty bin instead of raw points.
        xlabel, ylabel: axis labels (default: the signals' names)
        legend: show a legend when sig2 has more than one channel
        text: str drawn near the upper-left corner
        force_square: set the aspect ratio so the data box is square
        module - NEMS module that applies an input-output transformation
          on data plotted from x to y axes. overlay data with curve from the
          module

    Returns ax.
    '''
    if (sig1.nchans > 1) or (sig2.nchans > 1):
        log.warning('sig1 or sig2 chancount > 1, using chan %d', channels)
    if ax:
        plt.sca(ax)
    ax = plt.gca()

    m1 = sig1.as_continuous()
    m2 = sig2.as_continuous()

    # sig1 may be single-channel while sig2 has several channels
    c1 = channels if m1.shape[0] > 1 else 0
    # Remove NaNs using the channels actually plotted (this was always row 0).
    keepidx = np.isfinite(m1[c1, :]) & np.isfinite(m2[channels, :])
    x = m1[c1, keepidx]
    y = m2[channels, keepidx]

    if smoothing_bins and x.size:
        bincount = min(smoothing_bins, x.size)
        minx = np.min(x)
        stepsize = (np.max(x) - minx) / bincount
        x0 = []
        y0 = []
        for bb in range(bincount):
            lo = minx + bb * stepsize
            if bb < bincount - 1:
                kk = (x >= lo) & (x < lo + stepsize)
            else:
                # close the last bin on the right so max(x) is not dropped
                kk = x >= lo
            # skip empty bins instead of plotting them at (0, 0)
            if np.any(kk):
                x0.append(np.mean(x[kk]))
                y0.append(np.mean(y[kk]))
        x = np.array(x0)
        y = np.array(y0)

    chan_name = ('Channel {}'.format(channels) if not sig2.chans
                 else sig2.chans[channels])
    plt.scatter(x, y, label=chan_name, s=2, color='darkgray')

    if module is not None:
        xbounds = ax.get_xbound()
        plot_nl_io(module, xbounds, ax)

    xlabel = xlabel if xlabel else sig1.name
    plt.xlabel(xlabel)

    ylabel = ylabel if ylabel else sig2.name
    plt.ylabel(ylabel)

    if legend and sig2.nchans > 1:
        plt.legend(loc='lower right')

    if title:
        plt.title(title)

    if text is not None:
        # Place the text box just inside the upper-left corner.
        ymin, ymax = ax.get_ylim()
        xmin, xmax = ax.get_xlim()
        if ymin == ymax:
            ymax = ymin + 1
        if xmin == xmax:
            xmax = xmin + 1
        x_coord = xmin + (xmax - xmin) / 50
        y_coord = ymax - (ymax - ymin) / 20
        plt.text(x_coord, y_coord, text, verticalalignment='top')

    if force_square:
        ymin, ymax = ax.get_ylim()
        xmin, xmax = ax.get_xlim()
        ax.set_aspect(abs(xmax - xmin) / abs(ymax - ymin))

    ax_remove_box(ax)

    return ax
Ejemplo n.º 16
0
def strf_heatmap(modelspec,
                 ax=None,
                 clim=None,
                 show_factorized=True,
                 title='STRF',
                 fs=None,
                 chans=None,
                 wc_idx=0,
                 fir_idx=0,
                 interpolation='none',
                 absolute_value=False,
                 cmap='RdYlBu_r',
                 manual_extent=None,
                 show_cbar=True,
                 **options):
    """
    Plot the STRF built from a modelspec's weight-channel and FIR coefficients.

    How the STRF image is formed:
      * only FIR coefficients  -> the FIR matrix itself;
      * only WC coefficients   -> the transposed WC matrix;
      * both, 'filter_bank' FIR -> one ``wc @ fir`` product per bank, each
        scaled to a peak absolute value of 1 and separated by a NaN column;
      * both, otherwise        -> ``wc_coefs @ fir_coefs`` when shapes agree,
        else the FIR matrix alone.

    Parameters
    ----------
    modelspec : modelspec
        Model whose wc/fir coefficients are plotted.
    ax : matplotlib Axes, optional
        Axes to draw into (current axes if None).
    clim : sequence of 2 numbers, optional
        Color limits. If falsy, symmetric limits ``[-m, m]`` are used, where
        ``m`` is the largest absolute STRF value.
    show_factorized : bool
        If True and both coefficient sets are available and compatible,
        also draw the (rescaled) WC and FIR coefficients next to the STRF.
    title : str
        Axes title.
    fs : float, optional
        Sampling rate used for the lag axis. If None, it is read from
        ``modelspec.recording['stim'].fs`` when that is available.
    chans : list, optional
       if not None, label each row of the strf with the corresponding
       channel name
    wc_idx, fir_idx : int
        Which weight-channel / FIR module to use.
    interpolation : str or tuple
       if string, passed on as parameter to imshow
       if tuple, scipy.ndimage ``zoom`` factors (rows, lags); ``fs`` is
       multiplied by the lag factor
    absolute_value : bool
        Plot the absolute value of the image.
    cmap : str
        NOTE(review): currently unused; the colormap comes from the
        'FILTER_CMAP' setting.
    manual_extent, show_cbar :
        Passed through to ``plot_heatmap``.

    Returns
    -------
    matplotlib Axes, or None when no coefficients could be found.
    """
    if fs is None:
        try:
            fs = modelspec.recording['stim'].fs
        except Exception:
            # Best effort only: the recording or its 'stim' signal may be
            # missing, in which case the lag axis is left unscaled.
            pass
    wcc = _get_wc_coefficients(modelspec, idx=wc_idx)
    firc = _get_fir_coefficients(modelspec, idx=fir_idx, fs=fs)

    if wcc is None and firc is None:
        log.warning('Unable to generate STRF.')
        return None

    if wcc is None:
        strf = np.array(firc)
        show_factorized = False
    elif firc is None:
        strf = np.array(wcc).T
        show_factorized = False
    else:
        wc_coefs = np.array(wcc).T
        fir_coefs = np.array(firc)
        # Only look up the fir module when FIR coefficients actually exist.
        fir_mod = find_module('fir', modelspec, find_all_matches=True)[fir_idx]

        if 'filter_bank' in modelspec[fir_mod]['fn']:
            show_factorized = False
            if wc_coefs.shape[1] == fir_coefs.shape[0]:
                bank_count = modelspec[fir_mod]['fn_kwargs']['bank_count']
                bank_chans = wc_coefs.shape[1] // bank_count
                strfs = []
                for i in range(bank_count):
                    chan_slice = slice(bank_chans * i, bank_chans * (i + 1))
                    bank_strf = wc_coefs[:, chan_slice] @ fir_coefs[chan_slice, :]
                    peak = np.max(np.abs(bank_strf))
                    if peak:
                        bank_strf = bank_strf / peak
                    if i > 0:
                        # NaN column visually separates consecutive banks.
                        gap = np.full([bank_strf.shape[0], 1], np.nan)
                        bank_strf = np.concatenate((gap, bank_strf), axis=1)
                    strfs.append(bank_strf)
                strf = np.concatenate(strfs, axis=1)
            else:
                strf = fir_coefs
        elif wc_coefs.shape[1] == fir_coefs.shape[0]:
            strf = wc_coefs @ fir_coefs
        else:
            strf = fir_coefs
            show_factorized = False

    if isinstance(interpolation, str):
        if interpolation != 'none':
            show_factorized = False
    else:
        unzoomed_shape = strf.shape
        strf = zoom(strf, interpolation)
        if strf.shape != unzoomed_shape:
            # The resampled STRF no longer lines up with the raw wc/fir
            # coefficients, so the factorized layout cannot be assembled.
            show_factorized = False
        if fs is not None:
            fs = fs * interpolation[1]
        interpolation = 'none'

    if not clim:
        cscale = np.nanmax(np.abs(strf.reshape(-1)))
        clim = [-cscale, cscale]
    else:
        cscale = np.max(np.abs(clim))

    if show_factorized:
        # Never rescale the STRF or CLIM!
        # The STRF should be the final word and respect input colormap and clim
        # However: rescaling WC and FIR coefs to make them more visible is ok
        wc_max = np.nanmax(np.abs(wc_coefs))
        fir_max = np.nanmax(np.abs(fir_coefs))
        wc_coefs = wc_coefs * (cscale / wc_max)
        fir_coefs = fir_coefs * (cscale / fir_max)

        n_inputs, _ = wc_coefs.shape
        nchans, ntimes = fir_coefs.shape
        gap = np.full([nchans + 1, nchans + 1], np.nan)
        horz_space = np.full([1, ntimes], np.nan)
        vert_space = np.full([n_inputs, 1], np.nan)
        top_right = np.concatenate([fir_coefs, horz_space], axis=0)
        top_left = np.concatenate([wc_coefs, vert_space], axis=1)
        bot = np.concatenate([top_left, strf], axis=1)
        top = np.concatenate([gap, top_right], axis=1)
        everything = np.concatenate([top, bot], axis=0)
        # Rows occupied by the FIR panel + spacer before the STRF rows.
        skip = nchans + 1
    else:
        everything = strf
        skip = 0

    if absolute_value:
        everything = np.abs(everything)

    ax = plot_heatmap(everything,
                      xlabel='Lag (s)',
                      ylabel='Channel In',
                      ax=ax,
                      clim=clim,
                      skip=skip,
                      title=title,
                      fs=fs,
                      interpolation=interpolation,
                      cmap=get_setting('FILTER_CMAP'),
                      manual_extent=manual_extent,
                      show_cbar=show_cbar)
    ax_remove_box(ax, left=True, bottom=True, ticks=True)

    if chans is not None:
        # STRF rows start after the `skip` rows used by the FIR panel.
        for i, c in enumerate(chans):
            ax.text(0, i + skip, c, verticalalignment='center')

    return ax
Ejemplo n.º 17
0
Archivo: state.py Proyecto: nadoss/NEMS
def state_vars_timeseries(rec,
                          modelspec,
                          ax=None,
                          state_colors=None,
                          decimate_by=1,
                          channel=None):
    """
    Plot response and prediction traces with state variables underneath.

    The selected channel of rec['resp'] (gray) and rec['pred'] (black) is
    multiplied by the sampling rate and plotted against time. Samples where
    the prediction is not finite are dropped. If ``rec`` has a 'state'
    signal, state channels 1..N-1 are drawn below the traces:
      * variables with exactly two unique values are drawn as horizontal
        bars (one per "on" run), all sharing a single row;
      * other variables are scaled by their max and drawn on their own row,
        labelled with d/g values from a 'state_dc_gain' module if present.

    Parameters
    ----------
    rec : recording
        Must contain 'resp' and 'pred'; optionally 'state'.
    modelspec : modelspec
        Searched for a 'state_dc_gain' module to annotate state rows.
    ax : matplotlib Axes, optional
        Axes to draw into (current axes if None).
    state_colors : list, optional
        Color for state channel i is ``state_colors[i - 1]``.
    decimate_by : int
        If > 1, resp/pred and non-binary state traces are decimated by this
        factor with ``scipy.signal.decimate``.
    channel : optional
        Response channel, resolved with ``get_channel_number``.

    Returns
    -------
    matplotlib Axes
    """
    if ax is not None:
        plt.sca(ax)
    pred = rec['pred']
    resp = rec['resp']
    raw_fs = resp.fs
    fs = raw_fs

    chanidx = get_channel_number(resp, channel)

    r1 = resp.as_continuous()[chanidx, :].T * fs
    p1 = pred.as_continuous()[chanidx, :].T * fs
    nnidx = np.isfinite(p1)
    r1 = r1[nnidx]
    p1 = p1[nnidx]

    if decimate_by > 1:
        r1 = scipy.signal.decimate(r1, q=decimate_by, axis=0)
        p1 = scipy.signal.decimate(p1, q=decimate_by, axis=0)
        fs = raw_fs / decimate_by

    t = np.arange(len(r1)) / fs

    plt.plot(t, r1, linewidth=1, color='gray')
    plt.plot(t, p1, linewidth=1, color='black')
    # Vertical spacing unit for the state rows.
    mmax = np.nanmax(p1) * 0.8

    if 'state' in rec.signals.keys():
        g = None
        d = None
        for m in modelspec:
            if 'state_dc_gain' in m['fn']:
                g = np.array(m['phi']['g'])
                d = np.array(m['phi']['d'])

        num_vars = rec['state'].shape[0]
        ts = rec['state'].as_continuous().copy()
        if state_colors is None:
            state_colors = [None] * num_vars
        offset = -1.25 * mmax
        for i in range(1, num_vars):

            st = ts[i, :].T
            if len(np.unique(st)) == 2:
                # special, binary variable, keep in one row.
                # This trace is neither NaN-filtered nor decimated, so its
                # sample indices convert to seconds with the raw rate.
                m = np.array([np.min(st)])
                st = np.concatenate((m, st, m))
                dinc = np.argwhere(np.diff(st) > 0)[:, 0]
                ddec = np.argwhere(np.diff(st) < 0)[:, 0]
                for x0, x1 in zip(dinc, ddec):
                    plt.plot([x0 / raw_fs, x1 / raw_fs], [offset, offset],
                             lw=2,
                             color=state_colors[i - 1])
                tstr = "{}".format(rec['state'].chans[i])
                if len(dinc):
                    plt.text(dinc[-1] / raw_fs, offset, tstr, fontsize=6)
            else:
                # non-binary variable, plot in own row
                # figure out gain
                if g is not None:
                    if g.ndim == 1:
                        tstr = "{} (d={:.3f},g={:.3f})".format(
                            rec['state'].chans[i], d[i], g[i])
                    else:
                        tstr = "{} (d={:.3f},g={:.3f})".format(
                            rec['state'].chans[i], d[0, i], g[0, i])
                else:
                    tstr = "{}".format(rec['state'].chans[i])
                if decimate_by > 1:
                    st = scipy.signal.decimate(st[nnidx],
                                               q=decimate_by,
                                               axis=0)
                else:
                    st = st[nnidx]

                st = st / np.nanmax(st) * mmax + offset
                plt.plot(t, st, linewidth=1, color=state_colors[i - 1])
                plt.text(t[0], offset, tstr, fontsize=6)

                offset -= 1.25 * mmax

    ax = plt.gca()
    plt.xlabel('time (s)')
    plt.axis('tight')

    ax_remove_box(ax)

    return ax
Ejemplo n.º 18
0
    t = "n={}/{}\np={:.3e}\nmd={:.4f}".format(np.sum(si), si.shape[0], p, md)
    ax.text(tx, ty, t, va='top')
    ax.set_xlabel('FG-BG gain')

    ax = axes[0, 2]

    ax.plot(rdiff[nsi], gdiff[nsi], 'o', color='gray', mec='w', mew=1)
    ax.plot(rdiff[si], gdiff[si], 'o', color='#83428c', mec='w', mew=1)
    r, p = pearsonr(rdiff[nsi + si], gdiff[nsi + si])

    x = np.polyfit(rdiff, gdiff, 1)
    x0 = np.array(ax.get_xlim())
    y0 = x0 * x[0] + x[1]

    ax.plot(x0, y0, 'k--')
    ax_remove_box(ax)
    ax.set_xlabel('deltaR')
    ax.set_ylabel('deltaG')
    ax.set_title('R={:.3f} p={:.4e}'.format(r, p))

    ax = axes[1, 0]
    beta_comp(b_S,
              f_S,
              n1='bg_S',
              n2='fg_S',
              ax=ax,
              hist_range=[-bound, bound],
              highlight=si,
              title=bs + " (shf)")

    ax = axes[1, 1]
Ejemplo n.º 19
0
def state_gain_plot(modelspec,
                    rec,
                    state_sig='state_raw',
                    ax=None,
                    colors=None,
                    clim=None,
                    title=None,
                    **options):
    """
    Plot the fitted state gain (and baseline, if fitted) per state channel.

    Values come from ``modelspec.phi_mean`` / ``modelspec.phi_sem`` of the
    module found by ``find_module('state', modelspec)``. For
    'nems.modules.state.state_gain' only the gain ``g`` exists, and the
    module's ``gainoffset`` is added to it. Otherwise baseline ``d`` is
    plotted too.

    With a single row of coefficients, error bars (SEM) are drawn and
    channels whose |value / SEM| exceeds 2 are labelled with their name.
    With several rows, one line is drawn per column.

    Parameters
    ----------
    modelspec : modelspec
    rec : recording
        ``rec[state_sig].chans`` supplies legend names (gain-only models).
    state_sig : str
        Name of the state signal in ``rec``.
    ax : matplotlib Axes, optional
        Axes to draw into (current axes if None).
    colors : list, optional
        Per-column colors for the multi-row case.
    clim : unused
    title : str, optional
        Axes title.

    Returns
    -------
    matplotlib Axes
    """
    state_chan_list = rec[state_sig].chans
    state_idx = find_module('state', modelspec)
    g = modelspec.phi_mean[state_idx]['g'].copy()
    ge = modelspec.phi_sem[state_idx]['g']
    if modelspec[state_idx]['fn'] == 'nems.modules.state.state_gain':
        d = None
        de = None
        gainoffset = modelspec[state_idx]['fn_kwargs']['gainoffset']
        g += gainoffset
    else:
        d = modelspec.phi_mean[state_idx]['d']
        de = modelspec.phi_sem[state_idx]['d']

    state_chans = modelspec[0]['meta']['state_chans']
    if ax is None:
        ax = plt.gca()
    if g.shape[0] > 1:
        opt = {}
        for cc in range(g.shape[1]):
            if colors is not None:
                opt = {'color': colors[cc]}
            if d is not None:
                ax.plot(d[:, cc], '--', **opt)
            ax.plot(g[:, cc], **opt)
    else:
        if d is not None:
            ax.errorbar(np.arange(len(d[0, :])),
                        d[0, :],
                        de[0, :],
                        color='blue')
            dz = np.abs(d[0, :] / de[0, :])
        ax.errorbar(np.arange(len(g[0, :])), g[0, :], ge[0, :], color='red')
        # NOTE(review): for state_gain models g already includes gainoffset,
        # so this measures distance from 0, not from the offset — confirm.
        gz = np.abs(g[0, :] / ge[0, :])
        for i in range(len(gz)):
            if gz[i] > 2:
                ax.text(i,
                        g[0, i] + np.sign(g[0, i]) * ge[0, i],
                        state_chans[i],
                        color='red',
                        ha='center',
                        fontsize=6)
            elif d is not None and dz[i] > 2:
                ax.text(i,
                        d[0, i] + np.sign(d[0, i]) * de[0, i],
                        state_chans[i],
                        color='blue',
                        ha='center',
                        fontsize=6)

    if d is None:
        ax.legend(state_chan_list, frameon=False)
        ax.set_ylabel('Gain')
        # Dashed reference line at the gain offset.
        ax.plot(np.arange(len(state_chans)),
                gainoffset * np.ones(len(state_chans)),
                'k--',
                linewidth=0.5)
    else:
        ax.legend(('baseline', 'gain'), frameon=False)
        ax.plot(np.arange(len(state_chans)),
                np.zeros(len(state_chans)),
                'k--',
                linewidth=0.5)
    if title:
        # Axes.title is a Text attribute, not a method; set_title is the API.
        ax.set_title(title)

    ax_remove_box(ax)

    return ax
Ejemplo n.º 20
0
Archivo: drc.py Proyecto: LBHB/nems_db
def test_DRC_with_contrast(ms=30,
                           normalize=True,
                           fs=100,
                           bands=1,
                           percentile=70,
                           n_segments=8,
                           example_batch=289,
                           example_cell='TAR010c-13-1',
                           voc_batch=263,
                           voc_cell='tul034b-b1'):
    '''
    Show example stimuli alongside their continuous contrast signal.

    Builds a 3x2 grid:
      * left column: ``example_cell`` / ``example_batch``;
      * right column: ``voc_cell`` / ``voc_batch``, with vocalization and
        noise epochs interleaved.
    Rows are the stimulus spectrogram, the 'continuous' contrast signal
    produced by ``make_contrast_signal``, and the contrast summed across
    channels (divided by its maximum).

    ``n_segments`` is accepted but not used.

    Returns
    -------
    (fig, fig2) : the panel figure and a figure holding a text legend.
    '''
    loadkey = 'ozgf.fs%d.ch18' % fs
    nat_plot_bins = 1050
    voc_plot_bins = 1110
    nat_seconds = np.arange(0, nat_plot_bins) / fs
    voc_seconds = np.arange(0, voc_plot_bins) / fs

    def _load_contrast_rec(cellid, batch):
        # Load a recording with stimulus and attach a 'continuous' contrast
        # signal computed with this function's settings.
        uri = generate_recording_uri(cellid=cellid,
                                     batch=batch,
                                     loadkey=loadkey,
                                     stim=True)
        return make_contrast_signal(load_recording(uri),
                                    name='continuous',
                                    continuous=True,
                                    ms=ms,
                                    percentile=percentile,
                                    bands=bands,
                                    normalize=normalize)

    def _select_bins(rec, idx, n_bins):
        # Gather the chosen time bins, keep the first n_bins, and sum
        # contrast over channels, scaled so the maximum is 1.
        stim = rec['stim'].as_continuous()[:, idx][:, :n_bins]
        contrast = rec['continuous'].as_continuous()[:, idx][:, :n_bins]
        summed = np.sum(contrast, axis=0)
        summed /= np.max(summed)
        return stim, contrast, summed

    # --- natural sound recording: stimulus bins padded by 20 on each side
    nat_rec = _load_contrast_rec(example_cell, example_batch)
    epochs = nat_rec['resp'].epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    pre_silence = silence_duration(epochs, 'PreStimSilence')
    post_silence = silence_duration(epochs, 'PostStimSilence')
    nat_idx = np.arange(20, dtype=np.int32)
    for epoch_name in stim_epochs:
        row = epochs[epochs.name == epoch_name]
        first_bin = int((row['start'].values[0] + pre_silence) * fs)
        last_bin = int((row['end'].values[0] - post_silence) * fs)
        nat_idx = np.append(
            nat_idx,
            np.arange(first_bin - 20, last_bin + 20, dtype=np.int32))
    nat_stim, nat_contrast, nat_summed = _select_bins(nat_rec, nat_idx,
                                                      nat_plot_bins)

    # --- vocalization recording: interleave each voc with a noise epoch
    voc_rec = _load_contrast_rec(voc_cell, voc_batch)
    epochs = voc_rec.epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    vocs = sorted(name for name in stim_epochs if '0dB' not in name)
    noise = sorted(set(stim_epochs) - set(vocs))
    voc_idx = np.array([], dtype=np.int32)
    for pair in zip(vocs, noise):
        for epoch_name in pair:
            row = epochs[epochs.name == epoch_name]
            first_bin = int(row['start'].values[0] * fs)
            last_bin = int(row['end'].values[0] * fs)
            voc_idx = np.append(
                voc_idx, np.arange(first_bin, last_bin, dtype=np.int32))
    voc_stim, voc_contrast, voc_summed = _select_bins(voc_rec, voc_idx,
                                                      voc_plot_bins)

    fig, ((a2, a3), (a5, a6), (a8, a9)) = plt.subplots(3, 2, figsize=(3.5, 4))

    def _image_panel(panel, data, colormap):
        # Bare image panel: no axes, no box.
        plt.sca(panel)
        plt.imshow(data, cmap=colormap, aspect='auto', origin='lower')
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
        ax_remove_box(panel)

    def _summed_panel(panel, seconds, summed):
        # Summed-contrast trace on a fixed [-0.1, 1.1] y range.
        plt.sca(panel)
        plt.plot(seconds, summed, color='black', linewidth=0.5)
        panel.set_ylim(-0.1, 1.1)
        panel.get_yaxis().set_visible(False)
        panel.set_xlim(seconds.min(), seconds.max())
        ax_remove_box(panel)

    _image_panel(a2, nat_stim, spectrogram_cmap)
    _image_panel(a5, nat_contrast, contrast_cmap)
    _summed_panel(a8, nat_seconds, nat_summed)

    _image_panel(a3, voc_stim, spectrogram_cmap)
    _image_panel(a6, voc_contrast, contrast_cmap)
    _summed_panel(a9, voc_seconds, voc_summed)

    fig2 = plt.figure()
    legend_text = ("top: spectrogram\n"
                   "middle: contrast\n"
                   "bottom: summed contrast\n"
                   "left: 289, right: 263\n"
                   "x: Time (s)\n"
                   "summed 0 to 1, contrast arb., spec freq increasing")
    plt.text(0.1, 0.5, legend_text)

    return fig, fig2
Ejemplo n.º 21
0
def plot_heatmap(array,
                 xlabel='Time',
                 ylabel='Channel',
                 ax=None,
                 cmap=None,
                 clim=None,
                 skip=0,
                 title=None,
                 fs=None,
                 interpolation='none',
                 manual_extent=None,
                 show_cbar=True,
                 fontsize=7,
                 **options):
    '''
    Draw a 2-D array with ``plt.imshow`` using shared heatmap formatting.

    Parameters
    ----------
    array : array-like
        Image data; converted with ``np.array``.
    xlabel, ylabel : str
        Axis labels.
    ax : matplotlib Axes, optional
        Axes to draw into; made current if given, else the current axes.
    cmap : colormap, optional
        Defaults to the 'WEIGHTS_CMAP' setting.
    clim : sequence of 2 numbers, optional
        Defaults to symmetric ``[-m, m]``, m = largest absolute value.
    skip : int
        Accepted but not used.
    title : str, optional
    fs : float, optional
        When given (and ``manual_extent`` is not), columns are placed at
        ``k / fs`` for k = 1..ncols and rows at 1..nrows.
    interpolation : str
        Passed to ``imshow``.
    manual_extent : list, optional
        Explicit ``imshow`` extent; takes priority over ``fs``.
    show_cbar : bool
        Add a colorbar labelled 'Gain'.
    fontsize : int
        Colorbar tick/label font size.

    Returns
    -------
    matplotlib Axes
    '''
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    colormap = get_setting('WEIGHTS_CMAP') if cmap is None else cmap

    # Lists are accepted as well as ndarrays.
    data = np.array(array)

    if clim is None:
        peak = np.nanmax(np.abs(data))
        clim = [-peak, peak]

    if manual_extent is not None:
        extent = manual_extent
    elif fs is None:
        extent = None
    else:
        n_rows = data.shape[0]
        n_cols = data.shape[1]
        extent = [0.5 / fs, (n_cols + 0.5) / fs, 0.5, n_rows + 0.5]

    plt.imshow(data,
               aspect='auto',
               origin='lower',
               cmap=colormap,
               clim=clim,
               interpolation=interpolation,
               extent=extent)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    if show_cbar:
        colorbar = plt.colorbar()
        colorbar.ax.tick_params(labelsize=fontsize)
        colorbar.ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        colorbar.set_label('Gain', fontsize=fontsize)
        colorbar.outline.set_edgecolor('white')

    if title is not None:
        plt.title(title)

    ax_remove_box(ax)
    return ax