def plot_model_fits(y, yhat, fitparams, err=None, palettes=None, save=False, cdf=True, bw=.01):
    """ Main plotting function for displaying model fit predictions over data.

    Draws one row of three panels per condition level. The first panel is an
    accuracy plot: ``plot_accuracy`` when there is a single SSD, otherwise
    ``scurves``. The remaining two panels are filled by ``plot_quantiles``.

    Args:
        y: observed data vector; unpacked per level with ``unpack_vector``.
        yhat: model-predicted vector laid out like ``y``.
        fitparams: fit-parameter container. It must support item access
            (``'nlevels'``, ``'model_id'``, ``'fit_on'``) and attribute access
            (``ssd_info``, ``quantiles``).
        err: optional error vector laid out like ``y``. When given, each
            level's first component is passed as the accuracy error and its
            second/third components as the quantile errors.
        palettes: optional sequence of palette callables, one per level. Each
            is called as ``pal(2)`` and the result is used as that level's
            colors. Only the first ``nlevels`` entries are used. Defaults to
            ``colors.get_cpals(aslist=True)``.
        save: if True, save the figure as ``<model_id>.png`` at 600 dpi. When
            ``fitparams['fit_on'] == 'subjects'``, also close all figures.
        cdf: currently unused; kept for backward compatibility.
        bw: bandwidth forwarded to ``unpack_vector`` and ``plot_quantiles``.

    Raises:
        ValueError: if ``palettes`` is supplied with fewer than ``nlevels``
            entries.
    """
    sns.set(style='darkgrid', rc={'figure.facecolor': 'white'}, font_scale=1.5)

    # extract model and fit info from fitparams
    nlevels = fitparams['nlevels']
    ssd, nssd, nss, nss_per, ssd_ix = fitparams.ssd_info
    quantiles = fitparams.quantiles

    # make colors and labels; use caller-supplied palettes when given
    if palettes is None:
        palettes = colors.get_cpals(aslist=True)
    elif len(palettes) < nlevels:
        raise ValueError("palettes must contain at least %d entries (got %d)"
                         % (nlevels, len(palettes)))
    palettes = palettes[:nlevels]
    clrs = [pal(2) for pal in palettes]
    lbls = get_plot_labels(fitparams)

    f, axes = plt.subplots(nlevels, 3, figsize=(14, 5 * nlevels), sharey=True)
    # plt.subplots returns a 1-D array when nlevels == 1; force 2-D so every
    # iteration below unpacks a row of exactly three axes
    axes = axes.reshape(nlevels, 3)

    # a single SSD gets a plain accuracy plot, multiple SSDs use scurves
    plot_acc = plot_accuracy if nssd == 1 else scurves

    y_dat = unpack_vector(y, fitparams, bw=bw)
    y_kde = unpack_vector(y, fitparams, bw=bw, kde=True)
    yhat_dat = unpack_vector(yhat, fitparams, bw=bw)

    if err is not None:
        y_err = unpack_vector(err, fitparams, bw=bw)
        sc_err = [e[0] for e in y_err]
        qp_err = [[e[1], e[2]] for e in y_err]
    else:
        # two independent lists so the names never alias one object
        sc_err = [None] * nlevels
        qp_err = [None] * nlevels

    for i, (ax1, ax2, ax3) in enumerate(axes):
        accdata = [y_dat[i][0], yhat_dat[i][0]]
        qpdata = [y_dat[i], yhat_dat[i]]
        plot_acc(accdata, err=sc_err[i], ssd=ssd[i], colors=clrs[i],
                 labels=lbls[i], ax=ax1)
        plot_quantiles(qpdata, err=qp_err[i], quantiles=quantiles,
                       colors=clrs[i], axes=[ax2, ax3], kde=y_kde[i], bw=bw)

    format_axes(axes)

    if save:
        plt.savefig(fitparams['model_id'] + '.png', dpi=600)
        if fitparams['fit_on'] == 'subjects':
            # avoid piling up open figures when saving one plot per subject
            plt.close('all')
def plot_traces(self, style='HL', ax=None, label_x=True, save=False):
    """ Plot go (``self.go_traces``) and ng (``self.ng_traces``) traces as
    filled lines on a single axis.

    With ``style`` in ('HL', 'HML'), one column is drawn from each of four
    conditions:
      - first go condition: a randomly chosen column;
      - last go condition: a randomly chosen column;
      - first ng condition: a randomly chosen column;
      - last ng condition: the column whose last-row value is smallest.
    For any other style, each condition's column-wise mean trace is drawn.

    The x values of each trace start at its entry in ``self.onset`` and
    advance by ``self.dt`` per sample. The y axis spans 0 to just above
    ``self.p['a'].max()``.

    Args:
        style: 'HL' or 'HML' for single example traces; anything else plots
            mean traces.
        ax: matplotlib axis to draw on; a new 5x5 figure is created if None.
        label_x: if True, label the x axis 'Time'.
        save: if True, save the plot as PNG (300 dpi) and SVG. The file name
            is built from ``describe_model(self.depends_on)`` and
            ``self.decay``.

    Returns:
        The axis that was drawn on.
    """
    sns.set(style='white', font_scale=1.5)
    if ax is None:
        _, ax = plt.subplots(1, figsize=(5, 5))
    tr = self.onset
    titl = describe_model(self.depends_on)

    if style not in ['HL', 'HML']:
        # average across columns of each condition's traces
        gmu = [ggt.mean(axis=1) for ggt in self.go_traces]
        nmu = [ngt.mean(axis=1) for ngt in self.ng_traces]
    else:
        go_counts = [gt.shape[1] for gt in self.go_traces]
        ng_counts = [nt.shape[1] for nt in self.ng_traces]
        go_lo_ri = randint(0, high=go_counts[0])
        go_hi_ri = randint(0, high=go_counts[-1])
        ng_lo_ri = randint(0, high=ng_counts[0])
        # last ng condition: pick the column with the smallest last-row value
        ng_hi_ri = np.argmin(self.ng_traces[-1].iloc[-1, :])
        glow = self.go_traces[0].iloc[:, go_lo_ri].dropna().values
        ghi = self.go_traces[-1].iloc[:, go_hi_ri].dropna().values
        nglow = self.ng_traces[0].iloc[:, ng_lo_ri].dropna().values
        nghi = self.ng_traces[-1].iloc[:, ng_hi_ri].dropna().values
        gmu = [glow, ghi]
        nmu = [nglow, nghi]
        tr = [tr[0], tr[-1]]

    # Two base colors/linestyles per trace type. They are indexed cyclically
    # because the mean-trace branch can produce more than two conditions,
    # which previously raised IndexError.
    gc = ["#40ac5b", '#10ac1d']
    nc = ['#dc3c3c', '#d61b1b']
    ls = ['-', '-']

    gx = [tr[i] + np.arange(len(g)) * self.dt for i, g in enumerate(gmu)]
    nx = [tr[i] + np.arange(len(n)) * self.dt for i, n in enumerate(nmu)]

    for i, (x, g) in enumerate(zip(gx, gmu)):
        color = gc[i % len(gc)]
        ax.plot(x, g, linestyle=ls[i % len(ls)], lw=1, color=color)
        ax.fill_between(x, g, y2=0, lw=1, color=color, alpha=.25)
    for i, (x, n) in enumerate(zip(nx, nmu)):
        color = nc[i % len(nc)]
        ax.plot(x, n, linestyle=ls[i % len(ls)], lw=1, color=color)
        ax.fill_between(x, n, y2=0, lw=1, color=color, alpha=.25)

    ax.set_ylim(0, self.p['a'].max() * 1.01)
    ax.set_xlim(gx[-1].min() * .98, nx[0].max() * 1.05)
    if label_x:
        ax.set_xlabel('Time', fontsize=26)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_ylabel('$\\theta_{G}$', fontsize=30)
    sns.despine()
    plt.tight_layout()

    if save:
        plt.savefig('_'.join([titl, self.decay, 'traces.png']), dpi=300)
        # savefig has no `rasterized` option; newer matplotlib rejects
        # unknown keywords, so it is not passed here
        plt.savefig('_'.join([titl, self.decay, 'traces.svg']), format='svg')
    return ax
#!/usr/local/bin/env python from __future__ import division import sys from copy import deepcopy import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from radd.tools import colors, analyze from IPython.display import display, Latex import warnings warnings.simplefilter('ignore', np.RankWarning) warnings.filterwarnings("ignore", module="matplotlib") cdict = colors.get_cpals('all') rpal = cdict['rpal'] bpal = cdict['bpal'] gpal = cdict['gpal'] ppal = cdict['ppal'] heat = cdict['heat'] cool = cdict['cool'] slate = cdict['slate'] sns.set(style='darkgrid', rc={'figure.facecolor':'white'}, font_scale=1.2) def plot_model_fits(y, yhat, fitparams, err=None, palettes=[gpal, bpal], save=False, cdf=True, bw=.01): """ main plotting function for displaying model fit predictions over data """ sns.set(style='darkgrid', rc={'figure.facecolor':'white'}, font_scale=1.5) # extract model and fit info from fitparams nlevels = fitparams['nlevels'] ssd, nssd, nss, nss_per, ssd_ix = fitparams.ssd_info