def pareto(d, modelnames=None, groups=None, ax=None, **options):
    """
    Plot mean model score against parameter count, one line per model group.

    d : pandas.DataFrame
        One column per model. The row labelled 'n_parms' holds each model's
        parameter count; all other rows are scores that are averaged per model.
    modelnames : list, optional
        Defaults to d.columns; currently not used for selection.
    groups : list of str, optional
        Regular expressions. Each one selects (with re.match against the model
        names) the models that form a single line. If None, every model is
        drawn as one group (iterating over None used to raise TypeError).
    ax : matplotlib axes, optional
        Axes to draw into; defaults to plt.gca().

    Returns the axes.
    """
    if ax is None:
        ax = plt.gca()
    n_parms = d.loc['n_parms']
    if modelnames is None:
        modelnames = d.columns
    if groups is None:
        groups = ['.*']

    mean_score = d.drop(index='n_parms').mean()
    for g in groups:
        r = re.compile(g)
        m = [i for i in mean_score.index if r.match(i)]
        xc = np.array([mean_score.loc[i] for i in m])
        n = np.array([n_parms[i] for i in m])
        # Sort by parameter count so each line runs left to right.
        order = np.argsort(n)
        ax.plot(n[order], xc[order], '-')

    ax_remove_box(ax)
    return ax
def plot_nl_io(module=None, xbounds=None, ax=None):
    """
    Draw the input/output curve of a NEMS nonlinearity module.

    module : dict or None
        Module spec with 'fn' (dotted path "package.module.name") and 'phi'
        (keyword parameters). The private function "_name" from that Python
        module is evaluated on a grid of inputs. Returns None without drawing
        when module is None.
    xbounds : indexable of two numbers, optional
        Input range sampled with 100 points; defaults to [-1, 1].
    ax : matplotlib axes, optional
        Made the current axes when given, otherwise plt.gca() is used.

    Returns the axes drawn into.
    """
    if module is None:
        return

    if xbounds is None:
        xbounds = np.array([-1, 1])

    if not ax:
        ax = plt.gca()
    else:
        plt.sca(ax)

    mod_path, fn_name = module['fn'].rsplit('.', 1)
    nl_fn = getattr(importlib.import_module(mod_path), '_' + fn_name)

    phi = module['phi']
    first_key = list(phi.keys())[0]
    n_chans = len(phi[first_key])

    # One row of inputs per channel when the module has several channels.
    x = np.linspace(xbounds[0], xbounds[1], 100)
    if n_chans > 1:
        x = np.tile(x, (n_chans, 1))

    y = nl_fn(x, **phi)
    plt.plot(x.T, y.T)
    ax_remove_box(ax)
    return ax
def perf_per_cell(modelspec, channels=0, ax=None, **options):
    """
    Plot per-cell prediction correlations stored in modelspec.meta.

    Draws r_test (black), r_fit (blue dashed) and r_floor (gray dash-dot)
    against cell index. It marks the cell at index `channels` and annotates
    the mean r_test and that cell's r_test value.

    modelspec : object whose .meta holds 'r_test', 'r_fit', 'r_floor' and
        either 'cellids' or 'cellid' (used for the x tick labels).
    channels : int
        Index of the cell to highlight.
    ax : matplotlib axes, optional; defaults to plt.gca().

    Returns the axes.
    """
    if ax is None:
        ax = plt.gca()
    meta = modelspec.meta
    cellids = meta.get('cellids', [meta.get('cellid', None)])
    r_test = np.asarray(meta['r_test'])

    ax.plot(r_test, color='black')
    ax.plot(meta['r_fit'], ls='--', color='blue')
    ax.plot(meta['r_floor'], ls='-.', color='gray')
    ax.plot(channels, r_test[channels], 'o')
    ax.set_xticks(np.arange(len(cellids)))
    ax.set_xticklabels(cellids)

    # r_test may be stored as (n_cells,) or (n_cells, 1). Take a scalar either
    # way; indexing [channels, 0] used to fail for 1-D arrays.
    r_sel = r_test[channels, 0] if r_test.ndim > 1 else r_test[channels]
    hoffset = 0 if channels == 0 else 1
    ax.text(0.1 + hoffset, 0.1,
            'mean r={:.3f}'.format(np.mean(r_test)))
    ax.text(channels + 0.1, r_sel, '{:.3f}'.format(r_sel))
    ax_remove_box(ax)
    return ax
def state_gain_plot(modelspec, ax=None, clim=None, title=None):
    """
    Plot baseline (d), gain (g) and modulation index per state channel.

    Coefficients come from the last 'state_dc_gain' or 'state_dexp' module
    in modelspec. For state_dc_gain, row 0 of phi['g'] and phi['d'] is used.
    The modulation index and channel names come from
    modelspec[0]['meta']['state_mod'] and ['state_chans'].

    ax : matplotlib axes, optional; defaults to plt.gca().
    clim : unused.
    title : optional plot title.

    Raises ValueError if modelspec contains no state_dc_gain/state_dexp module.
    Before this check the function failed with UnboundLocalError.
    """
    g = None
    d = None
    for m in modelspec:
        if 'state_dc_gain' in m['fn']:
            g = m['phi']['g'][0, :]
            d = m['phi']['d'][0, :]
        elif 'state_dexp' in m['fn']:
            # hack, sdexp currently only supports single output channel
            g = m['phi']['g']
            d = m['phi']['d']
    if g is None:
        raise ValueError(
            'state_gain_plot: modelspec has no state_dc_gain or state_dexp module')

    MI = modelspec[0]['meta']['state_mod']
    state_chans = modelspec[0]['meta']['state_chans']
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    plt.plot(d)
    plt.plot(g)
    plt.plot(MI)
    plt.xticks(np.arange(len(state_chans)), state_chans, fontsize=6)
    plt.legend(('baseline', 'gain', 'MI'))
    # Dashed zero reference line across all state channels.
    plt.plot(np.arange(len(state_chans)), np.zeros(len(state_chans)),
             'k--', linewidth=0.5)
    if title:
        plt.title(title)
    ax_remove_box(ax)
def update_figure(self):
    """
    Redraw self.signal for the parent's current [start_time, stop_time) window.

    Reads self.recording, self.signal and self.parent (start_time/stop_time,
    in seconds) and draws into self.axes:
    - point-process signals are drawn as a grayscale raster;
    - TiledSignal or '*stim*' signals are drawn as an image;
    - anything else is drawn as line traces over time in seconds.
    Channels named 'baseline' are left out.
    """
    p = self.parent
    sig = self.recording[self.signal]
    c_count = sig.shape[0]
    fs = sig.fs
    start_bin = int(p.start_time * fs)
    stop_bin = int(p.stop_time * fs)

    # Skip some channels and get the names of the rest. Check chans for None
    # before slicing it; slicing None used to raise TypeError before this
    # check could run.
    skip_channels = ['baseline']
    if sig.chans is not None:
        all_names = list(sig.chans[:c_count])
        keep = np.array([n not in skip_channels for n in all_names],
                        dtype=bool)
        channel_names = [n for n, k in zip(all_names, keep) if k]
    else:
        keep = np.ones(c_count, dtype=bool)
        channel_names = None

    d = sig.as_continuous()[keep, start_bin:stop_bin]

    point = isinstance(sig, nems.signal.PointProcess)
    tiled = (isinstance(sig, nems.signal.TiledSignal)
             or 'stim' in sig.name)
    if point:
        self.axes.imshow(d, aspect='auto', cmap='Greys',
                         interpolation='nearest', origin='lower')
        self.axes.get_yaxis().set_visible(False)
    elif tiled:
        self.axes.imshow(d, aspect='auto', origin='lower')
    else:
        t = np.linspace(p.start_time, p.stop_time, d.shape[1])
        self.axes.plot(t, d.T)
        if (channel_names is not None) and len(channel_names) > 1:
            self.axes.legend(channel_names, frameon=False)

    self.axes.set_ylabel(self.signal)
    self.axes.autoscale(enable=True, axis='x', tight=True)
    ax_remove_box(self.axes)
    self.draw()

    if point or tiled:
        # Image x positions are bin indices rather than seconds, so hide them.
        tick_labels = self.axes.get_xticklabels()
        new_labels = [''] * len(tick_labels)
        self.axes.set_xticklabels(new_labels)
        self.draw()
def plot_timeseries(times, values, xlabel='Time', ylabel='Value', legend=None,
                    linestyle='-', linewidth=1, ax=None, title=None,
                    colors=None, **options):
    '''
    Plots a simple timeseries with one line for each pair of time and value
    vectors. Lines will be auto-colored according to matplotlib defaults.

    times : list of vectors
    values : list of vectors, or 2-D arrays with one line per column;
        non-finite samples are dropped before plotting
    xlabel : str
    ylabel : str
    legend : list of strings
    linestyle, linewidth : pass-through options to ax.plot()
    ax : axes to draw into (defaults to plt.gca())
    title : str
    colors : list of colors, cycled per (times, values) pair

    Returns (ax, h), where h is the list of Line2D handles that were drawn.
    '''
    if ax is None:
        ax = plt.gca()

    h = []
    mintime = np.inf
    maxtime = -np.inf
    for cc, (t, v) in enumerate(zip(times, values)):
        opt = {}
        if colors is not None:
            opt = {'color': colors[cc % len(colors)]}  # wrap around, no crash
        if v.ndim == 1:
            v = v[:, np.newaxis]
        for idx in range(v.shape[1]):
            gidx = np.isfinite(v[:, idx])
            h += ax.plot(t[gidx], v[gidx, idx], linestyle=linestyle,
                         linewidth=linewidth, **opt)
            # Track the time extent of every plotted column, not only the
            # last one.
            if gidx.any():
                mintime = min(mintime, np.min(t[gidx]))
                maxtime = max(maxtime, np.max(t[gidx]))

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Skip the limits when nothing finite was plotted; they would be +/-inf.
    if np.isfinite(mintime) and np.isfinite(maxtime):
        ax.set_xlim([mintime, maxtime])
    if legend:
        ax.legend(legend)
    if title:
        ax.set_title(title)
    ax_remove_box(ax)
    return ax, h
def plot_spectrogram(array, fs=None, ax=None, title=None, time_offset=0,
                     cmap=None, clim=None, extent=True, time_range=None,
                     **options):
    """
    Show a 2-D (channel x time-bin) array as an image.

    fs : sampling rate. When given, time_range is read in seconds and the
        extent-based x axis is in seconds (shifted by -time_offset).
        Otherwise both are in bins.
    time_range : optional (start, stop). It takes precedence over `extent`:
        only that window of columns is shown, without an extent.
    extent : when truthy (and no time_range), set the image extent to
        [first time, last time, 1, n_channels]. The top is raised to 2 for a
        single channel.
    cmap, clim : passed to imshow.

    Returns the axes.
    """
    if not ax:
        ax = plt.gca()

    imshow_kw = dict(origin='lower', interpolation='none', aspect='auto',
                     cmap=cmap, clim=clim)

    if time_range is not None:
        if fs is not None:
            time_range = np.round(np.array(time_range) * fs).astype(int)
        log.debug('bin range: {}-{}'.format(time_range[0], time_range[1]))
        window = np.arange(time_range[0], time_range[1])
        ax.imshow(array[:, window], **imshow_kw)
    elif not extent:
        # maybe something had a bug and couldn't plot in seconds?
        ax.imshow(array, **imshow_kw)
    else:
        bins = np.arange(0, array.shape[1])
        times = bins if fs is None else bins / fs - time_offset
        top = 2 if array.shape[0] == 1 else array.shape[0]
        ax.imshow(array, extent=[times[0], times[-1], 1, top], **imshow_kw)

    ax.margins(x=0)
    # TODO: Is there a way the colorbar can overlap part of the image
    # rather than shift it over?
    ax.set_xlabel('Time')
    ax.set_ylabel('Channel')
    if title:
        ax.set_title(title)
    ax_remove_box(ax)
    return ax
def plot_timeseries(times, values, xlabel='Time', ylabel='Value', legend=None,
                    linestyle='-', linewidth=1, ax=None, title=None,
                    colors=None):
    '''
    Plots a simple timeseries with one line for each pair of time and value
    vectors. Lines will be auto-colored according to matplotlib defaults.

    times : list of vectors (may differ in length)
    values : list of vectors
    xlabel : str
    ylabel : str
    legend : list of strings
    linestyle, linewidth : pass-through options to ax.plot()
    ax : axes to draw into; it is also made the current pyplot axes.
        Defaults to plt.gca().
    title : str (drawn with fontsize 8)
    colors : list of colors, cycled when there are more series than colors
    '''
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    for cc, (t, v) in enumerate(zip(times, values)):
        opt = {}
        if colors is not None:
            # Wrap around instead of raising IndexError.
            opt = {'color': colors[cc % len(colors)]}
        ax.plot(t, v, linestyle=linestyle, linewidth=linewidth, **opt)

    ax.margins(x=0)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # np.min(times) fails on ragged lists; take the extent series by series
    # and skip it entirely when there is no data.
    nonempty = [np.asarray(t) for t in times if np.size(t)]
    if nonempty:
        ax.set_xlim([min(np.min(t) for t in nonempty),
                     max(np.max(t) for t in nonempty)])
    if legend:
        ax.legend(legend)
    if title:
        ax.set_title(title, fontsize=8)
    ax_remove_box(ax)
def state_gain_plot(modelspec, ax=None, colors=None, clim=None, title=None,
                    **options):
    """
    Plot baseline (d) and gain (g) coefficients of the module located by
    find_module('state', modelspec).

    Means come from modelspec.phi_mean and SEMs from modelspec.phi_sem.

    - If d has more than one row, the function draws one dashed d line and
      one solid g line per column, using `colors` cyclically when given.
    - Otherwise d (blue) and g (red) are drawn with SEM error bars. State
      channels with |g/SEM| > 2 (or else |d/SEM| > 2) are labelled with their
      name from modelspec[0]['meta']['state_chans'].

    ax : matplotlib axes, optional; defaults to plt.gca().
    clim : unused.
    title : optional plot title.
    """
    state_idx = find_module('state', modelspec)
    g = modelspec.phi_mean[state_idx]['g']
    d = modelspec.phi_mean[state_idx]['d']
    ge = modelspec.phi_sem[state_idx]['g']
    de = modelspec.phi_sem[state_idx]['d']
    state_chans = modelspec[0]['meta']['state_chans']

    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    if d.shape[0] > 1:
        opt = {}
        for cc in range(d.shape[1]):
            if colors is not None:
                # Wrap around instead of raising IndexError.
                opt = {'color': colors[cc % len(colors)]}
            ax.plot(d[:, cc], '--', **opt)
            ax.plot(g[:, cc], **opt)
    else:
        ax.errorbar(np.arange(len(d[0, :])), d[0, :], de[0, :], color='blue')
        ax.errorbar(np.arange(len(g[0, :])), g[0, :], ge[0, :], color='red')
        # A zero SEM yields inf/nan z-scores; silence the numpy warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            dz = np.abs(d[0, :] / de[0, :])
            gz = np.abs(g[0, :] / ge[0, :])
        for i in range(len(gz)):
            if gz[i] > 2:
                ax.text(i, g[0, i] + np.sign(g[0, i]) * ge[0, i],
                        state_chans[i], color='red', ha='center', fontsize=6)
            elif dz[i] > 2:
                ax.text(i, d[0, i] + np.sign(d[0, i]) * de[0, i],
                        state_chans[i], color='blue', ha='center', fontsize=6)

    ax.legend(('baseline', 'gain'), frameon=False)
    ax.plot(np.arange(len(state_chans)), np.zeros(len(state_chans)),
            'k--', linewidth=0.5)
    if title:
        ax.set_title(title)
    ax_remove_box(ax)
def update_figure(self):
    """
    Redraw the cached signal for the parent's current time window.

    Uses attributes prepared when the canvas was built:
    - recording, signal and fs;
    - keep (boolean channel mask) and channel_names;
    - the point/tiled flags;
    - ymin/ymax, for line-trace signals.
    parent.start_time / stop_time are in seconds.
    """
    window = self.parent
    first_bin = int(window.start_time * self.fs)
    last_bin = int(window.stop_time * self.fs)
    data = self.recording[self.signal].as_continuous()[self.keep,
                                                       first_bin:last_bin]
    axes = self.axes

    if self.point:
        axes.imshow(data, aspect='auto', cmap='Greys',
                    interpolation='nearest', origin='lower')
        axes.get_yaxis().set_visible(False)
    elif self.tiled:
        axes.imshow(data, aspect='auto', origin='lower')
    else:
        t = np.linspace(window.start_time, window.stop_time, data.shape[1])
        axes.plot(t, data.T)
        names = self.channel_names
        if names is not None and len(names) > 1:
            axes.legend(names, frameon=False)
        axes.set_xlim(window.start_time, window.stop_time)
        # ymin/ymax only exist for line-trace signals.
        axes.set_ylim(ymin=self.ymin, ymax=self.ymax)

    axes.set_ylabel(self.signal)
    ax_remove_box(axes)
    self.draw()

    if self.point or self.tiled:
        # Image x positions are bin indices, not seconds; blank the labels.
        blank = [''] * len(axes.get_xticklabels())
        axes.set_xticklabels(blank)
        self.draw()
def __init__(self, recording=None, signal='stim', parent=None, *args, **kwargs):
    """
    Build a canvas that displays one signal of `recording`, then draw it once.

    recording : recording object. If it has a 'mask' signal, the result of
        recording.apply_mask() is displayed instead. Although the default is
        None, a recording is required: `recording.signals` is read
        unconditionally.
    signal : name of the signal to show (default 'stim').
    parent : object exposing start_time / stop_time, used for the initial
        x limits (multiplied by fs, i.e. in bins).
    *args, **kwargs : forwarded to MyMplCanvas.__init__.

    Sets on self:
    - recording, signal, signal_obj, fs and parent;
    - max_time (signal length / fs);
    - point and tiled display flags;
    - ymin and ymax (only when the signal is neither point nor tiled);
    - keep (boolean mask of channels not named 'baseline') and
      channel_names.
    """
    MyMplCanvas.__init__(self, *args, **kwargs)
    if 'mask' in recording.signals:
        self.recording = recording.apply_mask()
    else:
        self.recording = recording
    self.signal = signal
    self.signal_obj = self.recording[self.signal]
    self.fs = self.signal_obj.fs
    self.parent = parent
    print("creating canvas: {}".format(signal))
    sig_array = self.signal_obj.as_continuous()
    # Chop off end of array (where it's all nan'd out after processing)
    # TODO: Make this smarter incase there are intermediate nans?
    self.max_time = sig_array.shape[-1] / self.recording[self.signal].fs
    # Display mode: point processes -> raster; TiledSignal, or names
    # containing 'stim'/'contrast' -> image; everything else -> line traces.
    point = (isinstance(self.recording[self.signal], nems.signal.PointProcess))
    tiled = (isinstance(self.recording[self.signal], nems.signal.TiledSignal)
             or 'stim' in self.recording[self.signal].name
             or 'contrast' in self.recording[self.signal].name)
    if (not point) and (not tiled):
        # y limits with 25% headroom; ymin is never above 0.
        # NOTE(review): np.nanmax/nanmin return nan for an all-NaN signal,
        # which set_ylim would then reject.
        self.ymax = np.nanmax(sig_array) * 1.25
        self.ymin = min(0, np.nanmin(sig_array) * 1.25)
    self.point = point
    self.tiled = tiled
    # skip some channels, get names
    c_count = self.recording[self.signal].shape[0]
    if self.recording[self.signal].chans is None:
        channel_names = [''] * c_count
    else:
        channel_names = self.recording[self.signal].chans[:c_count]
    skip_channels = ['baseline']
    # channel_names is always a list at this point (None chans were replaced
    # by blanks above), so the else branch below is not reached here.
    if channel_names is not None:
        # NOTE(review): assumes chans has at least c_count entries; a shorter
        # list would raise IndexError in the comprehension.
        keep = np.array([(n not in skip_channels) for n in channel_names])
        channel_names = [
            channel_names[i] for i in range(c_count) if keep[i]
        ]
    else:
        keep = np.ones(c_count, dtype=bool)
        channel_names = None
    self.keep = keep
    self.channel_names = channel_names
    p = self.parent
    d = sig_array[self.keep, :]
    if self.point:
        self.axes.imshow(d, aspect='auto', cmap='Greys',
                         interpolation='nearest', origin='lower')
        self.axes.get_yaxis().set_visible(False)
    elif self.tiled:
        self.axes.imshow(d, aspect='auto', origin='lower')
    else:
        # Line traces are plotted against bin index (no time vector here).
        self.axes.plot(d.T)
        if self.channel_names is not None:
            if len(self.channel_names) > 1:
                self.axes.legend(self.channel_names, frameon=False)
        self.axes.set_ylim(ymin=self.ymin, ymax=self.ymax)
    # The x axis is in bins here, hence start/stop times are multiplied by fs.
    self.axes.set_xlim(p.start_time * self.fs, p.stop_time * self.fs)
    self.axes.set_ylabel(self.signal)
    ax_remove_box(self.axes)
    self.draw()
    tick_labels = self.axes.get_xticklabels()
    if self.point or self.tiled:
        new_labels = [''] * len(tick_labels)
        self.axes.set_xticklabels(new_labels)
        self.draw()
    else:
        # TODO: Still not working... Should turn bins to seconds
        # NOTE(review): the formatter is installed on the *y* axis, but the
        # bins it is meant to convert are on the x axis (see set_xlim above).
        # Confirm before switching to xaxis: other redraw paths may plot x in
        # seconds, and the formatter would then mislabel them.
        fmt = tkr.FuncFormatter(self.seconds_formatter())
        self.axes.yaxis.set_major_formatter(fmt)
        self.draw()
def state_vars_psth_all(rec, epoch, psth_name='resp', psth_name2='pred',
                        state_sig='state_raw', ax=None, colors=None,
                        channel=None, decimate_by=1, files_only=False,
                        modelspec=None):
    """
    Plot low- vs high-state PSTHs for every state channel, side by side.

    For each channel of rec['state'], state_mod_split() splits the
    `psth_name` response to `epoch` into low-state and high-state averages.
    The segments for all channels are concatenated along time, separated by
    NaN gaps. Shaded bands show the SEM (scipy.stats.sem) of the
    resp - pred error signal. If `psth_name2` is given, it is overlaid with
    dashed lines. A black bar above each segment marks the stimulus period
    and carries the channel name, plus the modulation index from modelspec
    meta when modelspec is given.

    files_only : bool
        Only use state channels whose names start with FILE, ACTIVE or
        PASSIVE. The first such channel also contributes a leading segment of
        its low-state response. For ACTIVE* channels, the high-state half is
        drawn in the 'active' colors; for the others it is drawn in the
        'passive' colors.
    decimate_by : int
        Passed to scipy.signal.decimate as q when > 1 (see NOTE below).
    colors : unused; colors come from the module-level line_colors and
        fill_colors dicts.
    """
    # TODO: Does using epochs make sense for these?
    if ax is not None:
        plt.sca(ax)
    else:
        # ax is used directly below (fill_between, plot, text), so it must be
        # a real Axes; ax=None used to raise AttributeError.
        ax = plt.gca()

    # Error signal (resp - pred) whose SEM is drawn as the shaded band.
    newrec = rec.copy()
    fn = lambda x: x - newrec['pred']._data
    newrec['error'] = rec['resp'].transform(fn, 'error')

    fs = rec[psth_name].fs
    d = rec[psth_name].get_epoch_bounds('PreStimSilence')
    PreStimSilence = np.mean(np.diff(d)) - 0.5 / fs
    d = rec[psth_name].get_epoch_bounds('PostStimSilence')
    dd = np.diff(d) if d.size > 0 else np.array([])
    dd = dd[dd > 0]
    PostStimSilence = np.min(dd) - 0.5 / fs if dd.size > 0 else 0

    state_chan_list = rec['state'].chans
    low = np.zeros([0, 1])
    high = np.zeros([0, 1])
    lowE = np.zeros([0, 1])
    highE = np.zeros([0, 1])
    low2 = np.zeros([0, 1])
    high2 = np.zeros([0, 1])
    _high2 = None
    limitset = []

    if files_only:
        state_chan_list = [s for s in state_chan_list
                           if s.startswith(('FILE', 'ACTIVE', 'PASSIVE'))]

    for state_chan in state_chan_list:
        _low, _high = state_mod_split(rec, epoch=epoch, psth_name=psth_name,
                                      channel=channel, state_sig=state_sig,
                                      state_chan=state_chan)
        _lowE, _highE = state_mod_split(newrec, epoch=epoch,
                                        psth_name='error', channel=channel,
                                        state_sig=state_sig,
                                        state_chan=state_chan,
                                        stat=scipy.stats.sem)
        if psth_name2 is not None:
            _low2, _high2 = state_mod_split(rec, epoch=epoch,
                                            psth_name=psth_name2,
                                            channel=channel,
                                            state_sig=state_sig,
                                            state_chan=state_chan)

        # NaN gap of 1/10 the segment length between segments.
        gapdur = _low.shape[0] / fs / 10
        gap = np.ones([int(np.ceil(fs * gapdur)), 1]) * np.nan
        pgap = np.ones(_low.shape) * np.nan

        if files_only:
            if state_chan == state_chan_list[0]:
                # Leading segment: this channel's low-state response alone
                # (the accumulators are still empty here).
                low = np.concatenate((high, gap, _low, gap), axis=0)
                lowE = np.concatenate((highE, gap, _lowE, gap), axis=0)
                high = np.concatenate((high, gap, pgap, gap), axis=0)
                highE = np.concatenate((highE, gap, pgap, gap), axis=0)
                if psth_name2 is not None:
                    low2 = np.concatenate((low2, gap, _low2, gap), axis=0)
                    high2 = np.concatenate((high2, gap, pgap, gap), axis=0)
            current_start = high.shape[0] / fs + gapdur
            if state_chan.startswith('ACTIVE'):
                _low = pgap
                _low2 = pgap
            else:
                # Draw the high-state half in the passive trace. Its SEM band
                # must move with it; previously the low-state SEM was kept.
                _low = _high.copy()
                _lowE = _highE.copy()
                _high = pgap
                _highE = pgap
                if psth_name2 is not None:
                    # _high2 is None when psth_name2 is None; .copy() used
                    # to raise AttributeError here.
                    _low2 = _high2.copy()
                    _high2 = pgap
        else:
            current_start = low.shape[0] / fs + gapdur

        low = np.concatenate((low, gap, _low, gap), axis=0)
        high = np.concatenate((high, gap, _high, gap), axis=0)
        lowE = np.concatenate((lowE, gap, _lowE, gap), axis=0)
        highE = np.concatenate((highE, gap, _highE, gap), axis=0)
        if psth_name2 is not None:
            low2 = np.concatenate((low2, gap, _low2, gap), axis=0)
            high2 = np.concatenate((high2, gap, _high2, gap), axis=0)

        # Stimulus period (excluding pre/post silence) of this segment, in s.
        limitset += [[
            current_start + PreStimSilence,
            current_start + _low.shape[0] / fs - PostStimSilence
        ]]

    if decimate_by > 1:
        # NOTE(review): these arrays are (time, 1), so axis=1 looks wrong
        # (time is axis 0). lowE/highE are also not decimated to match, and
        # the NaN gaps would spread through the decimation filter. Left
        # unchanged pending a decision on how to resample around the gaps.
        low = scipy.signal.decimate(low, q=decimate_by, axis=1)
        high = scipy.signal.decimate(high, q=decimate_by, axis=1)
        if psth_name2 is not None:
            low2 = scipy.signal.decimate(low2, q=decimate_by, axis=1)
            high2 = scipy.signal.decimate(high2, q=decimate_by, axis=1)
        fs /= decimate_by

    tt = np.arange(high.shape[0]) / fs
    ax.fill_between(tt, low[:, 0] - lowE[:, 0], low[:, 0] + lowE[:, 0],
                    color=fill_colors['passive'])
    l1, = ax.plot(tt, low, ls='-', lw=1, color=line_colors['passive'])
    ax.fill_between(tt, high[:, 0] - highE[:, 0], high[:, 0] + highE[:, 0],
                    color=fill_colors['active'])
    l2, = ax.plot(tt, high, ls='-', lw=1, color=line_colors['active'])
    if psth_name2 is not None:
        ax.plot(tt, low2, ls='--', lw=1, color=line_colors['passive'])
        ax.plot(tt, high2, ls='--', lw=1, color=line_colors['active'])
    if not files_only:
        ax.legend((l1, l2), ('Lo', 'Hi'))
    ax.set_ylabel('sp/sec')

    ylim = ax.get_ylim()
    for lims, s in zip(limitset, state_chan_list):
        if modelspec is not None:
            sc = modelspec[0]['meta']['state_chans']
            mi = modelspec[0]['meta']['state_mod']
            sn = "{} ({:.2f})".format(s, mi[sc.index(s)])
        else:
            sn = s
        ax.plot(lims, [ylim[1], ylim[1]], 'k-', linewidth=2)
        lc = np.mean(lims)
        ax.text(lc, ylim[1], sn, ha='center', va='bottom', fontsize=6)

    ax.set_ylim([ylim[0], ylim[1] * 1.1])
    ax_remove_box(ax)
def update_figure(self):
    """
    Redraw the epoch timeline for the parent's current time window.

    Epochs with start >= parent.start_time and end < parent.stop_time are
    grouped by a lower-cased name prefix: the text before the first '_' and
    then the first ','. Names shorter than 5 characters go in group 'X'.
    PreStimSilence, PostStimSilence, Reference and Target epochs are skipped.
    Each group gets its own row and color. Each epoch is drawn as a
    horizontal segment labelled with its name. Group keys persist across
    refreshes, so rows keep their positions.
    """
    self.axes.cla()
    epochs = self.recording.epochs
    p = self.parent
    valid_epochs = epochs[(epochs['start'] >= p.start_time) &
                          (epochs['end'] < p.stop_time)]
    if valid_epochs.size == 0:
        print('no valid epochs')
        return

    # On each refresh, keep the same keys but reform the lists of indices.
    self.epoch_groups = {k: [] for k in self.epoch_groups}
    skip_prefixes = ('prestimsilence', 'poststimsilence', 'reference',
                     'target')
    for idx, row in valid_epochs.iterrows():
        name = row['name']
        prefix = name.split('_')[0].split(',')[0].strip(' ').lower()
        if len(name) < 5:
            prefix = 'X'
        if prefix in skip_prefixes:
            continue
        self.epoch_groups.setdefault(prefix, []).append(idx)

    colors = [
        'Red', 'Orange', 'Green', 'LightBlue', 'DarkBlue', 'Purple', 'Pink',
        'Black', 'Gray'
    ]
    for row_i, g in enumerate(self.epoch_groups):
        color = colors[row_i % len(colors)]
        for j in self.epoch_groups[g]:
            n = valid_epochs['name'][j]
            s = valid_epochs['start'][j]
            e = valid_epochs['end'][j]
            try:
                n2 = valid_epochs['name'][j + 1]
                s2 = valid_epochs['start'][j + 1]
                e2 = valid_epochs['end'][j + 1]
            except KeyError:
                # j is already the last epoch in the list
                n2, s2, e2 = n, s, e
            # If the following epoch has the same name and starts before this
            # one ends, extend this segment to the end of that epoch.
            if n == n2 and s2 < e:
                e = e2
            # Don't plot outside the plot limits. Clip both ends
            # independently; the old elif skipped the end clip whenever the
            # start was clipped.
            s = max(s, p.start_time)
            e = min(e, p.stop_time)
            self.axes.plot([s, e], [row_i, row_i], '-', color=color)
            self.axes.text(s, row_i, n, va='bottom', fontsize='small',
                           color=color)

    self.axes.set_xlim([p.start_time, p.stop_time])
    # One row per group. Previously a stale loop index could leak in here.
    self.axes.set_ylim([-0.5, max(len(self.epoch_groups) - 1, 0) + 0.5])
    ax_remove_box(self.axes)
    self.draw()
def update_figure(self):
    """
    Redraw the epoch timeline for the parent's current time window.

    Epochs with start >= parent.start_time and end < parent.stop_time are
    grouped by the part of their name before the first '_'. Each group gets
    its own row. Each epoch is drawn as a black horizontal segment labelled
    with its name. Group keys persist across refreshes, so rows keep their
    positions. X tick labels are blanked.
    """
    p = self.parent
    epochs = self.recording.epochs
    valid_epochs = epochs[(epochs['start'] >= p.start_time) &
                          (epochs['end'] < p.stop_time)]

    # Axes.hold() was removed from matplotlib. Clearing up front reproduces
    # the old hold(True) ... hold(False) bracketing, which made the next
    # redraw start from empty axes.
    self.axes.cla()

    # On each refresh, keep the same keys but reform the lists of indices.
    # Columns are read by name; positional itertuples unpacking depended on
    # the DataFrame's column order.
    self.epoch_groups = {k: [] for k in self.epoch_groups}
    for idx, row in valid_epochs.iterrows():
        prefix = row['name'].split('_')[0]
        self.epoch_groups.setdefault(prefix, []).append(idx)

    for row_i, g in enumerate(self.epoch_groups):
        for j in self.epoch_groups[g]:
            n = valid_epochs['name'][j]
            s = valid_epochs['start'][j]
            e = valid_epochs['end'][j]
            try:
                n2 = valid_epochs['name'][j + 1]
                s2 = valid_epochs['start'][j + 1]
                e2 = valid_epochs['end'][j + 1]
            except KeyError:
                # j is already the last epoch in the list. Fall back to the
                # current epoch; n2 used to be unbound or stale here.
                n2, s2, e2 = n, s, e
            # If the following epoch has the same name and starts before this
            # one ends, extend this segment to the end of that epoch.
            if n == n2 and s2 < e:
                e = e2
            # Don't plot outside the plot limits. This used to assign
            # `p = p.stop_time`, which clobbered the parent reference.
            s = max(s, p.start_time)
            e = min(e, p.stop_time)
            self.axes.plot([s, e], [row_i, row_i], 'k-')
            self.axes.text(s, row_i, n, va='bottom')

    self.axes.set_xlim([p.start_time, p.stop_time])
    self.axes.set_ylim([-0.5, max(len(self.epoch_groups) - 1, 0) + 0.5])
    self.axes.set_ylabel('epochs')
    ax_remove_box(self.axes)
    self.draw()

    tick_labels = self.axes.get_xticklabels()
    new_labels = [''] * len(tick_labels)
    self.axes.set_xticklabels(new_labels)
    self.draw()
def plot_scatter(sig1, sig2, ax=None, title=None, smoothing_bins=False,
                 channels=0, xlabel=None, ylabel=None, legend=True, text=None,
                 force_square=False, module=None, **options):
    '''
    Scatter channel `channels` of sig1 (x axis) against the same channel of
    sig2 (y axis). If sig1 has only one channel, that channel is used for x.
    Time bins where either plotted channel is non-finite are dropped. A
    warning is logged when either signal has more than one channel.

    Optional arguments:
    ax : axes to draw into (made current); defaults to plt.gca()
    smoothing_bins: int
        If set, x is divided into this many equal-width bins, and the mean
        (x, y) of each non-empty bin is plotted instead of the raw points.
    channels : int, channel index to plot
    xlabel, ylabel : axis labels (default to the signal names)
    legend : show a legend when sig2 has more than one channel
    text : str placed in the upper-left corner of the axes
    force_square : set the aspect ratio so the axes box is square
    module - NEMS module that applies an input-output transformation on data
        plotted from x to y axes. overlay data with curve from the module

    Returns the axes.
    '''
    if (sig1.nchans > 1) or (sig2.nchans > 1):
        log.warning('sig1 or sig2 chancount > 1, using chan 0')
    if ax:
        plt.sca(ax)
    ax = plt.gca()

    m1 = sig1.as_continuous()
    m2 = sig2.as_continuous()
    # Single-channel sig1 may be paired with any channel of sig2.
    c1 = channels if m1.shape[0] > channels else 0

    if channels < m2.shape[0]:
        x = m1[c1, :]
        y = m2[channels, :]
        # Remove NaNs, checking the channel that is actually plotted
        # (previously channel 0 was always checked).
        keepidx = np.isfinite(x) & np.isfinite(y)
        x = x[keepidx]
        y = y[keepidx]

        if smoothing_bins and x.size:
            # Average the points inside equal-width x bins. Empty bins are
            # dropped instead of being drawn at (0, 0), and the last bin is
            # closed so the largest x value is kept.
            bincount = int(min(smoothing_bins, x.size))
            minx = np.min(x)
            stepsize = (np.max(x) - minx) / bincount
            x0 = np.zeros(bincount)
            y0 = np.zeros(bincount)
            filled = np.zeros(bincount, dtype=bool)
            for bb in range(bincount):
                kk = x >= minx + bb * stepsize
                if bb < bincount - 1:
                    kk &= x < minx + (bb + 1) * stepsize
                if kk.any():
                    x0[bb] = np.mean(x[kk])
                    y0[bb] = np.mean(y[kk])
                    filled[bb] = True
            x = x0[filled]
            y = y0[filled]

        chan_name = ('Channel {}'.format(channels) if not sig2.chans
                     else sig2.chans[channels])
        ax.scatter(x, y, label=chan_name, s=2, color='darkgray')

    if module is not None:
        xbounds = ax.get_xbound()
        plot_nl_io(module, xbounds, ax)

    ax.set_xlabel(xlabel if xlabel else sig1.name)
    ax.set_ylabel(ylabel if ylabel else sig2.name)
    if legend and sig2.nchans > 1:
        ax.legend(loc='lower right')
    if title:
        ax.set_title(title)

    if text is not None:
        # Place the text box near the upper-left corner of the data limits.
        ymin, ymax = ax.get_ylim()
        xmin, xmax = ax.get_xlim()
        if ymin == ymax:
            ymax = ymin + 1
        if xmin == xmax:
            xmax = xmin + 1
        x_coord = xmin + (xmax - xmin) / 50
        y_coord = ymax - (ymax - ymin) / 20
        ax.text(x_coord, y_coord, text, verticalalignment='top')

    if force_square:
        ymin, ymax = ax.get_ylim()
        xmin, xmax = ax.get_xlim()
        ax.set_aspect(abs(xmax - xmin) / abs(ymax - ymin))

    ax_remove_box(ax)
    return ax
def strf_heatmap(modelspec, ax=None, clim=None, show_factorized=True,
                 title='STRF', fs=None, chans=None, wc_idx=0, fir_idx=0,
                 interpolation='none', absolute_value=False, cmap='RdYlBu_r',
                 manual_extent=None, show_cbar=True, **options):
    """
    Plot the spectro-temporal receptive field implied by a modelspec.

    The STRF is the product of the weight-channel coefficients (module
    wc_idx) and the FIR coefficients (module fir_idx). If only one of them
    exists, it is shown alone. For 'filter_bank' FIR modules, each bank's
    STRF is normalized to max |value| 1, and the banks are shown side by
    side, separated by a NaN column. With show_factorized, the rescaled WC
    and FIR coefficients are drawn around the STRF.

    clim : optional; only used to scale the factorized WC/FIR panels.
        Defaults to the STRF's max |value|.
    fs : sampling rate for the lag axis. When not given, it is taken from
        modelspec.recording['stim'] if available.
    chans: list
        if not None, label each row of the strf with the corresponding
        channel name
    interpolation: string, tuple
        if string, passed on as parameter to imshow
        if tuple, ndimage "zoom" by a factor of (x,y) on each dimension
    cmap : currently ignored; the FILTER_CMAP setting is used.

    Returns None. Logs a warning and draws nothing when neither WC nor FIR
    coefficients are found.
    """
    if fs is None:
        try:
            fs = modelspec.recording['stim'].fs
        except Exception:
            # Best effort: fs stays None if no 'stim' recording is attached.
            pass

    wcc = _get_wc_coefficients(modelspec, idx=wc_idx)
    firc = _get_fir_coefficients(modelspec, idx=fir_idx, fs=fs)

    if wcc is None and firc is None:
        log.warning('Unable to generate STRF.')
        return

    if wcc is None:
        strf = np.array(firc)
        show_factorized = False
    elif firc is None:
        strf = np.array(wcc).T
        show_factorized = False
    else:
        wc_coefs = np.array(wcc).T
        fir_coefs = np.array(firc)
        # Look up the FIR module only once both coefficient sets exist.
        # Doing it unconditionally raised before the "no STRF" check.
        fir_mod = find_module('fir', modelspec,
                              find_all_matches=True)[fir_idx]
        if wc_coefs.shape[1] != fir_coefs.shape[0]:
            # Incompatible shapes: show the FIR coefficients alone.
            strf = fir_coefs
            show_factorized = False
        elif 'filter_bank' in modelspec[fir_mod]['fn']:
            bank_count = modelspec[fir_mod]['fn_kwargs']['bank_count']
            bank_chans = int(wc_coefs.shape[1] / bank_count)
            strfs = []
            for i in range(bank_count):
                cols = slice(bank_chans * i, bank_chans * (i + 1))
                bank = wc_coefs[:, cols] @ fir_coefs[cols, :]
                m = np.max(np.abs(bank))
                if m:
                    bank = bank / m
                if i > 0:
                    # NaN column between banks. The bank used to be divided
                    # by its max a second time here (0/0 for zero banks).
                    gap = np.full([bank.shape[0], 1], np.nan)
                    bank = np.concatenate((gap, bank), axis=1)
                strfs.append(bank)
            strf = np.concatenate(strfs, axis=1)
            # The side-by-side banks are wider than the FIR panel, so the
            # factorized layout cannot be assembled.
            show_factorized = False
        else:
            strf = wc_coefs @ fir_coefs

    if clim:
        cscale = np.max(np.abs(clim))
    else:
        cscale = np.nanmax(np.abs(strf))

    if isinstance(interpolation, str):
        if interpolation != 'none':
            show_factorized = False
    else:
        strf = zoom(strf, interpolation)
        if fs is not None:
            fs = fs * interpolation[1]
        interpolation = 'none'
        # The zoomed STRF no longer lines up with the unzoomed WC/FIR panels.
        show_factorized = False

    if show_factorized:
        # Never rescale the STRF or CLIM!
        # The STRF should be the final word and respect input colormap and clim
        # However: rescaling WC and FIR coefs to make them more visible is ok
        wc_max = np.nanmax(np.abs(wc_coefs))
        fir_max = np.nanmax(np.abs(fir_coefs))
        if wc_max:
            wc_coefs = wc_coefs * (cscale / wc_max)
        if fir_max:
            fir_coefs = fir_coefs * (cscale / fir_max)

        n_inputs, _ = wc_coefs.shape
        nchans, ntimes = fir_coefs.shape
        gap = np.full([nchans + 1, nchans + 1], np.nan)
        horz_space = np.full([1, ntimes], np.nan)
        vert_space = np.full([n_inputs, 1], np.nan)
        top_right = np.concatenate([fir_coefs, horz_space], axis=0)
        top_left = np.concatenate([wc_coefs, vert_space], axis=1)
        bot = np.concatenate([top_left, strf], axis=1)
        top = np.concatenate([gap, top_right], axis=1)
        everything = np.concatenate([top, bot], axis=0)
        skip = nchans + 1
    else:
        everything = strf
        skip = 0

    if absolute_value:
        everything = np.abs(everything)

    plot_heatmap(everything, xlabel='Lag (s)', ylabel='Channel In', ax=ax,
                 skip=skip, title=title, fs=fs, interpolation=interpolation,
                 cmap=get_setting('FILTER_CMAP'), manual_extent=manual_extent,
                 show_cbar=show_cbar)
    ax_remove_box(left=True, bottom=True, ticks=True)

    if chans is not None:
        for i, c in enumerate(chans):
            # STRF rows start below the FIR panel (skip rows) when it is
            # shown. `nchans` used to be undefined in the non-factorized case.
            plt.text(0, i + skip, c, verticalalignment='center')
def state_vars_timeseries(rec, modelspec, ax=None, state_colors=None,
                          decimate_by=1, channel=None):
    """
    Plot response and prediction over time, with state variables underneath.

    rec must contain 'resp' and 'pred', and may contain 'state'. Bins where
    pred is non-finite are dropped from everything. resp (gray) and pred
    (black) are multiplied by fs. Each state channel except the first is
    drawn below the traces:
    - binary channels (exactly two values) as thick segments where the
      channel is high;
    - other channels as a trace scaled to 80% of max(pred), each in its own
      row and labelled with d/g from the last 'state_dc_gain' module, if any.

    ax : axes made current when given.
    state_colors : colors for state channels 1..n-1.
    decimate_by : int; continuous traces are decimated with
        scipy.signal.decimate when > 1.
    channel : passed to get_channel_number() to pick the resp/pred channel.
    """
    if ax is not None:
        plt.sca(ax)

    pred = rec['pred']
    resp = rec['resp']
    fs = resp.fs
    chanidx = get_channel_number(resp, channel)
    r1 = resp.as_continuous()[chanidx, :].T * fs
    p1 = pred.as_continuous()[chanidx, :].T * fs
    nnidx = np.isfinite(p1)
    r1 = r1[nnidx]
    p1 = p1[nnidx]

    # Binary state segments are located in undecimated bins, so keep the
    # original rate for converting them to seconds.
    bin_fs = fs
    if decimate_by > 1:
        r1 = scipy.signal.decimate(r1, q=decimate_by, axis=0)
        p1 = scipy.signal.decimate(p1, q=decimate_by, axis=0)
        fs /= decimate_by

    t = np.arange(len(r1)) / fs
    plt.plot(t, r1, linewidth=1, color='gray')
    plt.plot(t, p1, linewidth=1, color='black')
    mmax = np.nanmax(p1) * 0.8

    if 'state' in rec.signals.keys():
        g = None
        d = None
        for m in modelspec:
            if 'state_dc_gain' in m['fn']:
                g = np.array(m['phi']['g'])
                d = np.array(m['phi']['d'])
        # (A summary string built from g/d was never displayed, and len(None)
        # raised TypeError without a state_dc_gain module; it was removed.)

        num_vars = rec['state'].shape[0]
        ts = rec['state'].as_continuous().copy()
        if state_colors is None:
            state_colors = [None] * num_vars

        offset = -1.25 * mmax
        for i in range(1, num_vars):
            # Apply the same finite-pred mask as the traces so the time axes
            # line up.
            st = ts[i, :].T[nnidx]
            if len(np.unique(st)) == 2:
                # special, binary variable, keep in one row
                m = np.array([np.min(st)])
                padded = np.concatenate((m, st, m))
                dinc = np.flatnonzero(np.diff(padded) > 0)
                ddec = np.flatnonzero(np.diff(padded) < 0)
                for x0, x1 in zip(dinc, ddec):
                    plt.plot([x0 / bin_fs, x1 / bin_fs], [offset, offset],
                             lw=2, color=state_colors[i - 1])
                tstr = "{}".format(rec['state'].chans[i])
                plt.text(dinc[-1] / bin_fs, offset, tstr, fontsize=6)
            else:
                # non-binary variable, plot in own row
                # figure out gain
                if g is not None:
                    if g.ndim == 1:
                        tstr = "{} (d={:.3f},g={:.3f})".format(
                            rec['state'].chans[i], d[i], g[i])
                    else:
                        tstr = "{} (d={:.3f},g={:.3f})".format(
                            rec['state'].chans[i], d[0, i], g[0, i])
                else:
                    tstr = "{}".format(rec['state'].chans[i])
                if decimate_by > 1:
                    st = scipy.signal.decimate(st, q=decimate_by, axis=0)
                st = st / np.nanmax(st) * mmax + offset
                plt.plot(t, st, linewidth=1, color=state_colors[i - 1])
                plt.text(t[0], offset, tstr, fontsize=6)
                offset -= 1.25 * mmax

    ax = plt.gca()
    plt.xlabel('time (s)')
    plt.axis('tight')
    ax_remove_box(ax)
t = "n={}/{}\np={:.3e}\nmd={:.4f}".format(np.sum(si), si.shape[0], p, md) ax.text(tx, ty, t, va='top') ax.set_xlabel('FG-BG gain') ax = axes[0, 2] ax.plot(rdiff[nsi], gdiff[nsi], 'o', color='gray', mec='w', mew=1) ax.plot(rdiff[si], gdiff[si], 'o', color='#83428c', mec='w', mew=1) r, p = pearsonr(rdiff[nsi + si], gdiff[nsi + si]) x = np.polyfit(rdiff, gdiff, 1) x0 = np.array(ax.get_xlim()) y0 = x0 * x[0] + x[1] ax.plot(x0, y0, 'k--') ax_remove_box(ax) ax.set_xlabel('deltaR') ax.set_ylabel('deltaG') ax.set_title('R={:.3f} p={:.4e}'.format(r, p)) ax = axes[1, 0] beta_comp(b_S, f_S, n1='bg_S', n2='fg_S', ax=ax, hist_range=[-bound, bound], highlight=si, title=bs + " (shf)") ax = axes[1, 1]
def state_gain_plot(modelspec, rec, state_sig='state_raw', ax=None,
                    colors=None, clim=None, title=None, **options):
    """
    Plot the gain (g) and, when present, baseline (d) coefficients of the
    module located by find_module('state', modelspec).

    Means come from modelspec.phi_mean and SEMs from modelspec.phi_sem.

    For 'nems.modules.state.state_gain' there is no d:
    - the module's fn_kwargs['gainoffset'] is added to g;
    - the legend lists rec[state_sig].chans;
    - the dashed reference line is drawn at gainoffset.
    Otherwise the legend is ('baseline', 'gain') and the reference line is
    at 0.

    If g has more than one row, the function draws one line per column,
    using `colors` cyclically when given. Otherwise it draws error bars and
    labels state channels with |coef / SEM| > 2.

    ax : matplotlib axes, optional; defaults to plt.gca().
    clim : unused.
    title : optional plot title.
    """
    state_chan_list = rec[state_sig].chans
    state_idx = find_module('state', modelspec)
    g = modelspec.phi_mean[state_idx]['g'].copy()
    ge = modelspec.phi_sem[state_idx]['g']
    if modelspec[state_idx]['fn'] == 'nems.modules.state.state_gain':
        d = None
        de = None
        gainoffset = modelspec[state_idx]['fn_kwargs']['gainoffset']
        g += gainoffset
    else:
        d = modelspec.phi_mean[state_idx]['d']
        de = modelspec.phi_sem[state_idx]['d']
    state_chans = modelspec[0]['meta']['state_chans']

    if ax is None:
        ax = plt.gca()

    if g.shape[0] > 1:
        opt = {}
        for cc in range(g.shape[1]):
            if colors is not None:
                # Wrap around instead of raising IndexError.
                opt = {'color': colors[cc % len(colors)]}
            if d is not None:
                ax.plot(d[:, cc], '--', **opt)
            ax.plot(g[:, cc], **opt)
    else:
        # A zero SEM yields inf/nan z-scores; silence the numpy warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            if d is not None:
                ax.errorbar(np.arange(len(d[0, :])), d[0, :], de[0, :],
                            color='blue')
                dz = np.abs(d[0, :] / de[0, :])
            ax.errorbar(np.arange(len(g[0, :])), g[0, :], ge[0, :],
                        color='red')
            gz = np.abs(g[0, :] / ge[0, :])
        for i in range(len(gz)):
            if gz[i] > 2:
                ax.text(i, g[0, i] + np.sign(g[0, i]) * ge[0, i],
                        state_chans[i], color='red', ha='center', fontsize=6)
            elif d is not None and dz[i] > 2:
                ax.text(i, d[0, i] + np.sign(d[0, i]) * de[0, i],
                        state_chans[i], color='blue', ha='center', fontsize=6)

    if d is None:
        ax.legend(state_chan_list, frameon=False)
        ax.set_ylabel('Gain')
        ax.plot(np.arange(len(state_chans)),
                gainoffset * np.ones(len(state_chans)),
                'k--', linewidth=0.5)
    else:
        ax.legend(('baseline', 'gain'), frameon=False)
        ax.plot(np.arange(len(state_chans)), np.zeros(len(state_chans)),
                'k--', linewidth=0.5)
    if title:
        # Axes.title is a Text attribute, not a method; ax.title(title)
        # raised TypeError.
        ax.set_title(title)
    ax_remove_box(ax)
def test_DRC_with_contrast(ms=30, normalize=True, fs=100, bands=1,
                           percentile=70, n_segments=8, example_batch=289,
                           example_cell='TAR010c-13-1', voc_batch=263,
                           voc_cell='tul034b-b1'):
    '''
    Draw a 3x2 figure comparing two recordings.

    The left column is the natural-sound example and the right column is
    vocalizations interleaved with noise. Rows show the stimulus
    spectrogram, the computed continuous contrast, and the contrast summed
    across channels (scaled so its maximum is 1). A second figure carries a
    text legend describing the layout.

    ``n_segments`` is accepted but not used.

    Returns
    -------
    (fig, fig2) : the panel figure and the text-legend figure.
    '''
    loadkey = 'ozgf.fs%d.ch18' % fs
    nat_bins, voc_bins = 1050, 1110
    nat_seconds = np.arange(0, nat_bins) / fs
    voc_seconds = np.arange(0, voc_bins) / fs

    def _load_with_contrast(cellid, batch):
        # Load a recording with its stimulus and attach a 'continuous'
        # contrast signal.
        uri = generate_recording_uri(cellid=cellid, batch=batch,
                                     loadkey=loadkey, stim=True)
        loaded = load_recording(uri)
        return make_contrast_signal(loaded, name='continuous',
                                    continuous=True, ms=ms,
                                    percentile=percentile, bands=bands,
                                    normalize=normalize)

    def _excerpt(source, idx, n_bins):
        # Select the chosen time bins, truncate to n_bins, and build the
        # channel-summed contrast scaled to a maximum of 1.
        stim = source['stim'].as_continuous()[:, idx][:, :n_bins]
        contrast = source['continuous'].as_continuous()[:, idx][:, :n_bins]
        summed = np.sum(contrast, axis=0)
        summed /= np.max(summed)  # norm 0 to 1 just to match axes
        return stim, contrast, summed

    # --- Natural sound: each stimulus without its silences, padded by 20 bins.
    nat_rec = _load_with_contrast(example_cell, example_batch)
    nat_epochs = nat_rec['resp'].epochs
    nat_names = ep.epoch_names_matching(nat_epochs, 'STIM_')
    pre_silence = silence_duration(nat_epochs, 'PreStimSilence')
    post_silence = silence_duration(nat_epochs, 'PostStimSilence')
    nat_chunks = [np.arange(20, dtype=np.int32)]
    for stim_name in nat_names:
        match = nat_epochs[nat_epochs.name == stim_name]
        onset = int((match['start'].values[0] + pre_silence) * fs)
        offset = int((match['end'].values[0] - post_silence) * fs)
        nat_chunks.append(
            np.arange(onset - 20, offset + 20, dtype=np.int32))
    nat_stim, nat_contrast, nat_summed = _excerpt(
        nat_rec, np.concatenate(nat_chunks), nat_bins)

    # --- Vocalizations in noise: force voc and noise epochs to alternate.
    voc_rec = _load_with_contrast(voc_cell, voc_batch)
    voc_epochs = voc_rec.epochs
    voc_names = ep.epoch_names_matching(voc_epochs, 'STIM_')
    vocs = sorted([s for s in voc_names if '0dB' not in s])
    noise = sorted(list(set(voc_names) - set(vocs)))
    voc_chunks = [np.array([], dtype=np.int32)]
    for pair in zip(vocs, noise):
        for stim_name in pair:
            match = voc_epochs[voc_epochs.name == stim_name]
            voc_chunks.append(
                np.arange(int(match['start'].values[0] * fs),
                          int(match['end'].values[0] * fs),
                          dtype=np.int32))
    voc_stim, voc_contrast, voc_summed = _excerpt(
        voc_rec, np.concatenate(voc_chunks), voc_bins)

    # --- Figure: left column natural sound, right column voc in noise.
    fig, ((a2, a3), (a5, a6), (a8, a9)) = plt.subplots(3, 2,
                                                       figsize=(3.5, 4))
    columns = (
        (a2, a5, a8, nat_stim, nat_contrast, nat_seconds, nat_summed),
        (a3, a6, a9, voc_stim, voc_contrast, voc_seconds, voc_summed),
    )
    for spec_ax, con_ax, sum_ax, stim, contrast, seconds, summed in columns:
        for target, image, colormap in ((spec_ax, stim, spectrogram_cmap),
                                        (con_ax, contrast, contrast_cmap)):
            plt.sca(target)
            plt.imshow(image, cmap=colormap, aspect='auto', origin='lower')
            target.get_xaxis().set_visible(False)
            target.get_yaxis().set_visible(False)
            ax_remove_box(target)
        plt.sca(sum_ax)
        plt.plot(seconds, summed, color='black', linewidth=0.5)
        sum_ax.set_ylim(-0.1, 1.1)
        sum_ax.get_yaxis().set_visible(False)
        sum_ax.set_xlim(seconds.min(), seconds.max())
        ax_remove_box(sum_ax)

    fig2 = plt.figure()
    text = ("top: spectrogram\n"
            "middle: contrast\n"
            "bottom: summed contrast\n"
            "left: 289, right: 263\n"
            "x: Time (s)\n"
            "summed 0 to 1, contrast arb., spec freq increasing")
    plt.text(0.1, 0.5, text)

    return fig, fig2
def plot_heatmap(array, xlabel='Time', ylabel='Channel', ax=None, cmap=None,
                 clim=None, skip=0, title=None, fs=None, interpolation='none',
                 manual_extent=None, show_cbar=True, fontsize=7, **options):
    '''
    Draw ``array`` as an image via ``plt.imshow`` with consistent formatting.

    Parameters
    ----------
    array : array-like
        Converted with ``np.array`` before plotting (so lists are accepted).
    xlabel, ylabel : str
        Axis labels.
    ax : matplotlib Axes, optional
        Made the current axes if given; otherwise ``plt.gca()`` is used.
    cmap : colormap, optional
        Defaults to ``get_setting('WEIGHTS_CMAP')``.
    clim : pair, optional
        Color limits. Defaults to +/- the largest absolute value (NaNs
        ignored).
    skip : int
        Unused.
    title : str, optional
        Title for the axes.
    fs : number, optional
        When given (and ``manual_extent`` is not), the x extent is
        scaled by 1/fs.
    interpolation : str
        Passed through to ``imshow``.
    manual_extent : sequence, optional
        Explicit ``extent``. Takes precedence over ``fs``.
    show_cbar : bool
        Whether to add a colorbar labelled 'Gain'.
    fontsize : int
        Font size for the colorbar ticks and label.
    **options
        Ignored.

    Returns
    -------
    ax : matplotlib Axes
    '''
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    cmap = get_setting('WEIGHTS_CMAP') if cmap is None else cmap

    data = np.array(array)

    if clim is None:
        # Symmetric limits so zero sits at the colormap midpoint.
        peak = np.nanmax(np.abs(data.ravel()))
        clim = [-peak, peak]

    if manual_extent is not None:
        extent = manual_extent
    elif fs is None:
        extent = None
    else:
        n_rows = data.shape[0]
        n_cols = data.shape[1]
        extent = [0.5 / fs, (n_cols + 0.5) / fs, 0.5, n_rows + 0.5]

    # pyplot-level imshow so the image becomes "current" for plt.colorbar().
    plt.imshow(data, aspect='auto', origin='lower', cmap=cmap, clim=clim,
               interpolation=interpolation, extent=extent)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    if show_cbar:
        colorbar = plt.colorbar()
        colorbar.ax.tick_params(labelsize=fontsize)
        colorbar.ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        colorbar.set_label('Gain', fontsize=fontsize)
        colorbar.outline.set_edgecolor('white')

    if title is not None:
        plt.title(title)

    ax_remove_box(ax)
    return ax