def plotCorr(ssubs, compare, rng=slice(0, T), ca_or_spikes='ca', labels=None,
             title='Shapes obtained on', colors=[orange, blue, vermillon],
             ls='-'):
    """Plot the mean correlation with reference traces versus spatial decimation.

    map based on spatial distance, removed neurons are set to corr=0

    For the i-th entry ``ssub`` of ``ssubs`` and every decimation factor ``ds``
    in the module-level ``dsls``, the Pearson correlation between each row of
    ``ssub[ds][0]`` (``ca_or_spikes == 'ca'``) or ``ssub[ds][2]`` (otherwise)
    and the matching row of ``compare`` is computed over the samples selected
    by ``rng``. The mean over rows is drawn as a line, with error bars of
    ``std / sqrt(number of rows)``.

    Parameters
    ----------
    ssubs : sequence
        At most three results indexed by decimation factor. When the number of
        rows of entry ``i`` differs from ``compare``, the row mapping
        ``[None, idxH, idxQ][i]`` (module-level globals) is used.
    compare : 2-D array
        Reference traces, one row per neuron.
    rng : slice
        Time window used for every correlation.
    ca_or_spikes : str
        ``'ca'`` selects index 0 of ``ssub[ds]``; anything else selects index 2.
    labels : sequence of str, optional
        Legend labels, one per entry of ``ssubs``. No legend is drawn if None.
    title : str
        Legend title.
    colors : sequence
        Colour for each entry of ``ssubs``.
    ls : str
        Line style of the mean curve.

    Returns
    -------
    The legend when ``labels`` is given, otherwise None.
    """
    which = 0 if ca_or_spikes == 'ca' else 2

    def foo(ssub, comp, idx=None):
        """Return an (N, len(dsls)) array of correlations; NaN/missing -> 0."""
        N, _ = comp.shape
        cor = np.zeros((N, len(dsls)))
        for i, ds in enumerate(dsls):
            est = ssub[ds][which]
            if len(ssub[ds][0]) == len(comp):
                cor[:, i] = np.nan_to_num([
                    np.corrcoef(est[n, rng], comp[n, rng])[0, 1]
                    for n in range(N)])
            else:
                # necessary if component has been removed on small batch:
                # idx[n] is the row of `est` matching comp[n], or None
                for n, k in enumerate(idx):
                    if k is not None:  # component removed -> keep corr=0
                        cor[n, i] = np.nan_to_num(
                            np.corrcoef(est[k, rng], comp[n, rng])[0, 1])
        return cor

    lg = None  # stays None when no legend is requested
    for i, ssub in enumerate(ssubs):
        cor = foo(ssub, compare, [None, idxH, idxQ][i])
        mean = np.mean(cor, 0)
        plt.plot(dsls, mean, ls=ls, lw=4, c=colors[i],
                 label=None if labels is None else labels[i], clip_on=False)
        noclip(plt.errorbar(dsls, mean, yerr=np.std(cor, 0) / np.sqrt(len(cor)),
                            lw=3, capthick=2, fmt='o', c=colors[i],
                            clip_on=False))
    plt.xlabel('Spatial decimation')
    simpleaxis(plt.gca())
    plt.xticks(dsls, ['1', '', '', '', '', '8x8', '', '16x16', '24x24', '32x32'])
    plt.xlim(dsls[0], dsls[-1])
    # Fix the y-range first so the ticks are derived from the final limits,
    # and round them so labels read 0.4 rather than 0.39999999999999997.
    plt.ylim(.25, 1)
    yt = np.round(np.arange(np.round(plt.ylim()[1], 1), plt.ylim()[0], -.2), 1)
    plt.yticks(yt, yt)
    plt.ylabel('Correlation w/ undecimated $C_1$/$S_1$', y=.42, labelpad=1)
    plt.subplots_adjust(.13, .15, .94, .96)
    if labels is not None:
        lg = plt.legend(loc=(-.01, -.01), title=title)
    return lg
def adjust():
    """Apply the shared axis layout used by the wall-time panels.

    The y-range comes from the module-level bounds ``l`` and ``u``.
    """
    plt.xlabel('Wall time [s]')
    xt = [0, 5, 10]
    plt.xticks(xt, xt)
    plt.yticks([1, 1.01], ['1.00', 1.01])
    # limits are set after the ticks so they override any tick-driven expansion
    plt.xlim(0, 8.5)
    plt.ylim(l, u)
    simpleaxis(plt.gca())
    plt.subplots_adjust(.1, .155, .99, .99)
def plotCorr(ssubs, compare, dsls, labels=None, colors=[cyan, orange, green],
             ca_or_spikes='ca', ls='-'):
    """Plot mean correlation with reference traces against decimation factor.

    For each entry ``ssub`` of ``ssubs`` and each ``ds`` in ``dsls``, the
    Pearson correlation between every row of ``ssub[ds][0]`` (or
    ``ssub[ds][2]`` when ``ca_or_spikes != 'ca'``) and the matching row of
    ``compare`` is computed. The mean over rows is drawn as a line, with error
    bars of ``std / sqrt(number of rows)``.

    If the row counts differ, each estimated row is matched to the reference
    row it correlates with most strongly. Unmatched references keep 0.

    Parameters
    ----------
    ssubs : sequence
        Results indexed by decimation factor, one per curve.
    compare : 2-D array
        Reference traces, one row per neuron.
    dsls : sequence
        Decimation factors (also the x-values).
    labels : sequence of str, optional
        Legend labels; no legend when None.
    colors : sequence
        Colour of each curve.
    ca_or_spikes : str
        ``'ca'`` selects index 0 of ``ssub[ds]``, anything else index 2.
    ls : str
        Line style of the mean curve.

    Returns
    -------
    The legend when ``labels`` is given, otherwise None.
    """
    which = 0 if ca_or_spikes == 'ca' else 2

    def foo(ssub, comp):
        """Return an (N, len(dsls)) array of correlations; NaN/missing -> 0."""
        N, _ = comp.shape
        cor = np.zeros((N, len(dsls)))
        for i, ds in enumerate(dsls):
            est = ssub[ds][which]
            if len(ssub[ds][0]) == len(comp):
                cor[:, i] = np.nan_to_num([
                    np.corrcoef(est[n], comp[n])[0, 1] for n in range(N)])
            else:
                # necessary if update_spatial_components removed a component:
                # map every estimated row to its best-correlated reference row
                mapIdx = [np.argmax([np.corrcoef(e, tC)[0, 1] for tC in comp])
                          for e in est]
                for n in range(len(ssub[ds][0])):
                    m = mapIdx[n]
                    cor[m, i] = np.nan_to_num(np.corrcoef(est[n], comp[m])[0, 1])
        return cor

    lg = None  # stays None when no legend is requested
    for i, ssub in enumerate(ssubs):
        cor = foo(ssub, compare)
        mean = np.mean(cor, 0)
        plt.plot(dsls, mean, ls=ls, lw=4, c=colors[i],
                 label=None if labels is None else labels[i], clip_on=False)
        noclip(plt.errorbar(dsls, mean, yerr=np.std(cor, 0) / np.sqrt(len(cor)),
                            lw=3, capthick=2, fmt='o', c=colors[i],
                            clip_on=False))
    plt.xlabel('Spatial decimation')
    simpleaxis(plt.gca())
    plt.xticks(dsls, ['1', '', '', '', '', '8x8', '', '16x16', '24x24', '32x32'])
    plt.yticks(*[[.4, .6, .8, 1.0]] * 2)
    plt.xlim(dsls[0], dsls[-1])
    plt.ylabel('Correlation w/ undecimated $C_1$/$S_1$', y=.42, labelpad=1)
    if labels is not None:
        lg = plt.legend(loc=(0, 0))
    return lg
def plotChoice(w):
    """Plot measured and simulated probability of choosing B vs. log(V(B)/V(A)).

    Parameters
    ----------
    w : str
        Dataset suffix, '3b' or '3e'. Selects the module-level ``offers<w>``,
        ``ratio<w>`` and ``pB<w>`` and the simulation results ``s[w]``.

    Returns
    -------
    choice : ndarray of bool
        ``rts[:, :, 1] > rts[:, :, 2]`` for the simulated trials. Its mean over
        axis 0 is plotted as the model's B-choice fraction.

    The first and last offers are drawn at the outer ticks labelled -inf and
    +inf, separated from the rest by broken-axis marks. The figure is saved to
    ``'Padoa' + w + '.pdf'`` if the global ``savefig`` is true, else shown.
    """
    # Look the dataset up by name rather than eval()-ing a constructed string.
    data = globals()
    offers = data['offers' + w]
    ratio = data['ratio' + w]
    pB = data['pB' + w]
    v = np.array([offers[1:-1, 0], ratio * offers[1:-1, 1]]).T
    logratio = np.log(v[:, 0] / v[:, 1])
    x = np.transpose(s[w], (1, 0, 2, 3))
    rts = np.array([[get_t(a) for a in trials] for trials in x], dtype=int)
    choice = rts[:, :, 1] > rts[:, :, 2]
    pBsim = np.mean(choice, 0)

    # x positions of the outer points, labelled -inf / +inf
    lo = logratio[0] - .5
    hi = logratio[-1] + .5
    pl.figure()
    pl.plot(logratio, pB[1:-1], '--', marker='s', ms=30, mfc='w', mec=col[0],
            mew=3, color=col[0], zorder=10, clip_on=False)
    pl.plot(logratio, pBsim[1:-1], color=col[1], marker='s', ms=22,
            mec=col[1], clip_on=False, zorder=10)
    pl.plot([lo, hi], [pB[0], pB[-1]], 's', ms=30, mfc='w', mec=col[0],
            mew=3, color=col[0], zorder=10, clip_on=False)
    pl.plot([lo, hi], [pBsim[0], pBsim[-1]], 's', color=col[1], ms=22,
            mec=col[1], clip_on=False, zorder=10)
    pl.xticks([lo, -1, 0, 1, hi], [r"$-\infty$", -1, 0, 1, r"$\infty$"])
    for tick in pl.gca().get_xaxis().majorTicks:
        tick.set_pad(17)
    pl.yticks([0, .5, 1], [0, 50, 100])
    pl.xlim(logratio[0] - .65, hi)
    pl.ylim(0, 1)
    pl.xlabel('log(V(B)/V(A))', labelpad=-2)
    # raw string: '\%' is an invalid escape in a plain literal (same value)
    pl.ylabel(r'B choice [\%]', labelpad=-15)
    pl.text(.05, .83, 'A=' + str(ratio) + 'B', transform=pl.gca().transAxes)
    pl.subplots_adjust(.17, .23, .96, .95)
    simpleaxis(pl.gca())
    # broken axis: a white gap on the x-axis plus two black strokes at each end
    pl.plot([logratio[0] - .27, logratio[0] - .23], [0, 0], c='w', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[0] - .29, logratio[0] - .25], [-.02, .02], c='k', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[0] - .25, logratio[0] - .21], [-.02, .02], c='k', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[-1] + .28, logratio[-1] + .24], [0, 0], c='w', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[-1] + .3, logratio[-1] + .26], [.02, -.02], c='k',
            lw=2, clip_on=False, zorder=11)
    pl.plot([logratio[-1] + .26, logratio[-1] + .22], [.02, -.02], c='k',
            lw=2, clip_on=False, zorder=11)
    if savefig:
        pl.savefig('Padoa' + w + '.pdf', dpi=600)
    else:
        pl.show()
    return choice
def plotCorr(ssub, ssub0, r=pearsonr, ds1phase=[1, 2, 3, 4], loc=(.1, .01)):
    """Compare 1- and 2-phase imaging: correlation with ground truth vs. decimation.

    The median (line) and interquartile band (``IQRfill``) of the per-neuron
    correlation ``r`` between estimated and true traces are drawn.
    - Solid curves: denoised traces (index 0 of ``ssub[ds]``) vs. ``trueC``.
    - Dashed curves: deconvolved traces (index 2) vs. ``trueSpikes``.
    - ``ssub0`` (1 phase, cyan) is evaluated at the factors ``ds1phase``.
    - ``ssub`` (2 phase, orange) is evaluated at the module-level ``dsls``.

    Parameters
    ----------
    ssub, ssub0 : mapping
        Results indexed by decimation factor.
    r : callable
        Correlation function whose result has the statistic at index 0.
    ds1phase : sequence
        Decimation factors evaluated for ``ssub0``.
    loc : tuple
        Legend location.

    Returns
    -------
    l1, l2, l3, l4
        Line handles for '1 phase imaging', '2 phase imaging', 'denoised' and
        'deconvolved'.
    """
    def foo(ssub, comp, dsls=dsls, ca_or_spikes='ca'):
        """Return an (N, len(dsls)) array of correlations; NaN/missing -> 0."""
        which = 0 if ca_or_spikes == 'ca' else 2
        N, _ = comp.shape
        cor = np.zeros((N, len(dsls))) * np.nan
        for i, ds in enumerate(dsls):
            est = ssub[ds][which]
            if len(ssub[ds][0]) == len(comp):
                cor[:, i] = np.array([r(est[n], comp[n])[0] for n in range(N)])
            else:
                # necessary if update_spatial_components removed a component:
                # map every estimated row to its best-correlated reference row
                mapIdx = [np.argmax([np.corrcoef(e, tC)[0, 1] for tC in comp])
                          for e in est]
                for n in range(len(ssub[ds][0])):
                    cor[mapIdx[n], i] = np.array(r(est[n], comp[mapIdx[n]])[0])
        return np.nan_to_num(cor)

    # 1-phase results only exist at ds1phase, so their IQR band must use the
    # same x-values as their median line (previously drawn against dsls).
    cor = foo(ssub0, trueC, ds1phase)
    l1, = plt.plot(ds1phase, np.median(cor, 0), lw=4, c=cyan,
                   label='1 phase imaging')
    IQRfill(cor, ds1phase, cyan)
    cor = foo(ssub, trueC)
    l2, = plt.plot(dsls, np.median(cor, 0), lw=4, c=orange,
                   label='2 phase imaging')
    IQRfill(cor, dsls, orange)
    cor = foo(ssub0, trueSpikes, ds1phase, ca_or_spikes='spikes')
    plt.plot(ds1phase, np.median(cor, 0), lw=4, c=cyan, ls='--')
    IQRfill(cor, ds1phase, cyan, ls='--', hatch='///')
    cor = foo(ssub, trueSpikes, ca_or_spikes='spikes')
    plt.plot(dsls, np.median(cor, 0), lw=4, c=orange, ls='--')
    IQRfill(cor, dsls, orange, ls='--', hatch='\\\\\\')
    # off-screen dummy lines that provide the black solid/dashed legend keys
    l3, = plt.plot([0, 1], [-1, -1], lw=4, c='k', label='denoised')
    l4, = plt.plot([0, 1], [-1, -1], lw=4, c='k', ls='--', label='deconvolved')
    plt.xlabel('Spatial decimation')
    simpleaxis(plt.gca())
    plt.xticks(dsls, ['1', '', '', '', '', '8x8', '', '16x16', '24x24', '32x32'])
    plt.ylim(.3, 1)
    plt.yticks(
        *[np.round(np.arange(np.round(plt.ylim()[1], 1), plt.ylim()[0], -.2), 1)] * 2)
    plt.xlim(dsls[0], dsls[-1])
    plt.ylabel('Correlation w/ undecimated $C_1$/$S_1$', y=.42, labelpad=1)
    plt.legend(handles=[l3, l4, l1, l2], loc=loc, ncol=1)
    plt.subplots_adjust(.13, .15, .94, .96)
    return l1, l2, l3, l4
def plot_trace(n=0, lg=False):
    """Plot truth, estimate and data traces, annotated with the spike correlation.

    Reads the module-level ``trueC``, ``trueSpikes``, ``solution``, ``y``,
    ``g`` and ``col``.

    Parameters
    ----------
    n : int
        Row of ``trueC`` / ``trueSpikes`` to compare against.
    lg : bool
        Draw a legend when True.
    """
    plt.plot(trueC[n], c=col[2], clip_on=False, zorder=5, label='Truth')
    plt.plot(solution, c=col[0], clip_on=False, zorder=7, label='Estimate')
    plt.plot(y, c=col[7], alpha=.7, clip_on=False, zorder=-10, label='Data')
    if lg:
        plt.legend(frameon=False, ncol=3, loc=(.1, .62), columnspacing=.8)
    # spks[t] = solution[t] - g * solution[t - 1] for t >= 1, and 0 at t = 0
    spks = np.append(0, solution[1:] - g * solution[:-1])
    plt.text(800, 2.2, 'Correlation: %.3f' %
             (np.corrcoef(trueSpikes[n], spks)[0, 1]), size=24)
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.ylim(0, 2.85)
    plt.xlim(0, 1500)
    plt.yticks([0, 2], [0, 2])
    # one (empty) label per tick; matplotlib rejects mismatched counts
    xt = [300, 600, 900, 1200]
    plt.xticks(xt, [''] * len(xt))
def cb(y, active_set, counter, current):
    """Redraw the two-panel frame (fluorescence in ``ax1``, spikes in ``ax2``).

    Each entry ``(v, w, f, l)`` of ``active_set`` fills
    ``solution[f:f + l] = v' / w * g ** arange(l)``, where ``v'`` is
    ``max(v, 0)`` for the first entry and ``v`` otherwise. Every other entry is
    shaded, the point ``current`` is marked with a blue '+', and the true spike
    times are drawn in red. The frame is saved to ``video/%03d.pdf`` when
    ``save_figs`` is set, then both axes are cleared. Uses the module-level
    ``g``, ``ax1``, ``ax2``, ``fig``, ``trueSpikes`` and ``save_figs``.
    """
    solution = np.empty(len(y))
    for pos, (val, weight, start, length) in enumerate(active_set):
        first = max(val, 0) if pos == 0 else val
        solution[start:start + length] = first / weight * g**np.arange(length)
    spk = np.r_[[0], solution[1:] - g * solution[:-1]]
    t = np.arange(len(y))
    color = y.copy()

    def _markers(ax, values):
        # coloured dots for all samples plus a '+' on the current one
        ax.scatter(t, values, s=40, cmap=plt.cm.Spectral, c=color,
                   clip_on=False, zorder=11)
        ax.scatter([t[current]], [values[current]], s=120, lw=2.5, marker='+',
                   color='b', clip_on=False, zorder=11)

    def _frame(ax, top):
        # shade alternate pools, draw true spikes, fix limits, hide ticks
        for pool in active_set[::2]:
            ax.axvspan(pool[2], pool[2] + pool[3], alpha=0.1, color='k',
                       zorder=-11)
        for sp in np.where(trueSpikes)[0]:
            ax.plot([sp, sp], [0, top], lw=1.5, c='r', zorder=-12)
        ax.set_xlim((0, len(y) - .5))
        ax.set_ylim((0, top))
        simpleaxis(ax)
        ax.set_xticks([])
        ax.set_yticks([])

    ax1.plot(solution, c='k', zorder=-11, lw=1.3, clip_on=False)
    _markers(ax1, solution)
    _frame(ax1, 2.3)
    ax1.set_ylabel('Fluorescence')

    for pos, height in enumerate(spk):
        ax2.plot([pos, pos], [0, height], c='k', zorder=-11, lw=1.4,
                 clip_on=False)
    _markers(ax2, spk)
    _frame(ax2, 1.55)
    ax2.set_xlabel('Time', labelpad=35, x=.5)
    ax2.set_ylabel('Spikes')

    plt.subplots_adjust(left=0.032, right=.995, top=.995, bottom=0.19,
                        hspace=0.22)
    fig.canvas.draw()
    if save_figs:
        plt.savefig('video/%03d.pdf' % counter)
    ax1.clear()
    ax2.clear()
def cb(y, active_set, counter, current):
    """Draw the current reconstruction in a new 3x3 figure and show it.

    Each entry ``(v, w, f, l)`` of ``active_set`` fills
    ``solution[f:f + l] = max(v, 0) / w * g ** arange(l)``. Every other entry
    is shaded, the point ``current`` is marked with a blue '+', and the true
    spike times are drawn in red. The figure is saved to ``fig/<counter>.pdf``
    when ``save_figs`` is set. Uses the module-level ``g``, ``trueSpikes`` and
    ``save_figs``.
    """
    solution = np.empty(len(y))
    for val, weight, start, length in active_set:
        solution[start:start + length] = max(val, 0) / weight * g**np.arange(length)
    plt.figure(figsize=(3, 3))
    color = y.copy()
    t = np.arange(len(y))
    top = 1.65
    plt.plot(solution, c='k', zorder=-11, lw=1.2)
    plt.scatter(t, solution, s=60, cmap=plt.cm.Spectral, c=color,
                clip_on=False, zorder=11)
    plt.scatter([t[current]], [solution[current]], s=200, lw=2.5, marker='+',
                color='b', clip_on=False, zorder=11)
    for pool in active_set[::2]:
        plt.axvspan(pool[2], pool[2] + pool[3], alpha=0.1, color='k',
                    zorder=-11)
    for sp in np.where(trueSpikes)[0]:
        plt.plot([sp, sp], [0, top], lw=1.5, c='r', zorder=-12)
    plt.xlim((0, len(y) - .5))
    plt.ylim((0, top))
    simpleaxis(plt.gca())
    plt.xticks([])
    plt.yticks([])
    if save_figs:
        plt.savefig('fig/%d.pdf' % counter)
    plt.show()
def plotTrace(lg=False):
    """Three-panel comparison of the L1 and thresholded reconstructions.

    Panels:
    - Top: traces ``c`` (L1), ``c_t`` (Thresh.), ``trueC[0]`` (Truth) and
      ``y`` (Data).
    - Middle: amplitudes above 1e-2 within the first 500 samples of ``s``,
      ``s_t`` and ``trueSpikes[0]``, stacked at 2.5, 1.25 and 0.
    - Bottom: the positions in each row ``i`` of ``res`` as ticks at height
      ``.1 * i``, with the true spike times in red below.

    All data are module-level globals. Set ``lg`` to draw a legend on the top
    panel.
    """
    def stems(values, base, color, **kw):
        # one vertical bar per amplitude > 1e-2, then a baseline at `base`
        for pos, amp in enumerate(values):
            if amp > 1e-2:
                plt.plot([pos, pos], [base, base + amp], c=color, zorder=10, **kw)
        plt.plot([0, 450], [base, base], c=color, zorder=10, **kw)

    fig = plt.figure(figsize=(10, 9))

    # top panel: fluorescence
    fig.add_axes([.13, .7, .86, .29])
    plt.plot(c, c=col[0], label='L1')
    plt.plot(c_t, c=col[1], label='Thresh.')
    plt.plot(trueC[0], c=col[2], lw=3, label='Truth', zorder=-5)
    plt.plot(y, c=col[7], alpha=.7, zorder=-10, label='Data')
    if lg:
        plt.legend(frameon=False, ncol=4, loc=(.05, .82))
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    ymax = int(y.max())
    plt.yticks([0, ymax], [0, ymax])
    plt.xticks(range(150, 500, 150), [''] * 3)
    plt.ylabel('Fluor.')
    plt.xlim(0, 452)

    # middle panel: spike amplitudes per method
    fig.add_axes([.13, .39, .86, .29])
    stems(s[:500], 2.5, col[0])
    stems(s_t[:500], 1.25, col[1])
    stems(trueSpikes[0, :500], 0, col[2], clip_on=False)
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.yticks([0, 1.25, 2.5], ['Truth', 'Thresh.', 'L1'])
    for tick in plt.gca().yaxis.get_major_ticks():
        tick.label1.set_verticalalignment('bottom')
    plt.xticks(range(150, 500, 150), [''] * 3)
    plt.ylim(0, 3.5)
    plt.xlim(0, 452)

    # bottom panel: positions in each row of `res`, true spikes in red below
    fig.add_axes([.13, .08, .86, .29])
    for row, positions in enumerate(res):
        for p in positions:
            plt.plot([p, p], [.1 * row - .04, .1 * row + .04], c='k')
    for p in np.where(trueSpikes[0])[0]:
        plt.plot([p, p], [-.08, -.16], c='r')
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.yticks([0, .5, 1], [0, 0.5, 1.0])
    plt.xticks(range(0, 500, 150), [0, 5, 10, ''])
    plt.ylim(-.2, 1.1)
    plt.xlim(0, 452)
    plt.ylabel(r'$s_{\min}$')
    plt.xlabel('Time [s]', labelpad=-10)
    plt.show()
# Mean firing rate per unit vs. time from target, with the smoothed model
# input profile in black behind the traces.
pl.figure()
for c in range(8):
    pl.plot(Ratels.mean(axis=0)[c], color=col[c])
pl.xlabel('Time from target [ms]', labelpad=0)
pl.ylabel('Firing rate [Hz]', labelpad=10)
pl.xticks([0, 250 / step, 500 / step], [0, 250, 500])
pl.yticks([0, 20, 40, 60, 80], [0, 20, 40, 60, 80])
pl.xlim(0, 500 / step)
pl.ylim(0, 64)
# Slice bounds must be integers; `250 / step` is a float whenever step is
# non-integral (and always under Python 3), matching the int(1000 / step) above.
pl.plot(smooth_spikes([step * 100 / 1000. *
                       (.1 + .65 * np.exp(-(t * step - 500) ** 2 / tau_r ** 2))
                       for t in range(int(1000 / step))],
                      40, .2)[int(250 / step):int(750 / step)],
        color='black', zorder=-1, lw=2)
pl.subplots_adjust(.18, .21, .945, .99)
simpleaxis(pl.gca())
if savefig:
    pl.savefig('Sohn.pdf', dpi=600)
else:
    pl.show()

# Unit 3 under three conditions (solid, long dashes, short dashes).
pl.figure()
pl.plot(Ratels.mean(axis=0)[3], color=col[3])
l, = pl.plot(Ratels3.mean(axis=0)[3], '--', color=col[3])
l.set_dashes([10, 10])
l, = pl.plot(Ratels4.mean(axis=0)[3], ':', color=col[3])
l.set_dashes([3, 3])
pl.xlabel('Time from target [ms]', labelpad=0)
pl.ylabel('Firing rate [Hz]', labelpad=10)
pl.xticks([0, 250 / step, 500 / step], [0, 250, 500])
fig1.subplots_adjust(.04, .1, 1, 1)
if savefig:
    pl.savefig('u.pdf', dpi=600, transparent=True)
else:
    pl.show()

# Spikes
spikes, u = cfn.run(net.W, step, 400, 100, 20, 2, 20, 2)
pl.figure()
for i in range(8):
    # One bar per spike of unit i, spanning row i. An explicit loop is used
    # because map() for side effects is lazy under Python 3 and draws nothing.
    for spike_t in np.where(spikes[:, i] == 1)[0]:
        pl.plot([spike_t, spike_t], [.95 + i, .05 + i], c=col[i])
pl.yticks(np.arange(.4, 8), ['0L', '0R', '1L', '1R', '2L', '2R', '3L', '3R'])
pl.gca().yaxis.set_tick_params(width=0)
# NOTE(review): 3 labels assume len(u) spans 80 / step samples — confirm
pl.xticks(np.arange(0, len(u) + 1, 40 / step), [0, 40, 80])
simpleaxis(pl.gca())
pl.tight_layout(pad=.03)
if savefig:
    pl.savefig('spikes.pdf', dpi=600, transparent=True)
else:
    pl.show()


def f(W, x, ref):
    """One update step of the state vector ``x``.

    ``y = (1 - 1/2000) x + (1/2000) * ((1 + ref) W.x+ - ref * x+)``, where
    ``x+ = x * (x > 0)``. The entries ``net.r`` are then set to 1.
    """
    y = (1 - 1. / 2000) * x + 1. / 2000 \
        * ((1 + ref) * np.dot(W, x * (x > 0)) - ref * x * (x > 0))
    y[net.r] = 1
    return y


x = np.zeros((1201, net.K))
np.random.seed(2)
color=col[1], mec=col[1], capthick=2, clip_on=False, zorder=11) pl.xticks(range(len(offers3e)), [ str(a[0]) + ':' + str(a[1]) for i, a in enumerate(offers3e)]) tt = pl.gca().get_xaxis().majorTicks for i in range(0, len(tt), 2): tt[i].set_pad(40) for i in range(1, len(tt), 2): tt[i].set_pad(13) pl.yticks([0, 10], [0, 10]) pl.xlim(-.3, len(offers3e) - .7) pl.ylim(0, 14) pl.xlabel('offers (\#B:\#A)', labelpad=-2) pl.ylabel('Firing rate [Hz]', y=.45) pl.text(.05, .83, 'A=' + str(ratio3e) + 'B', transform=pl.gca().transAxes) pl.subplots_adjust(.17, .28, .98, .99) simpleaxis(pl.gca()) lg = pl.legend( ['Experiment', 'Model'], bbox_to_anchor=(.63, .85), handlelength=2, frameon=False) if savefig: pl.savefig('Padoa3e2.pdf', dpi=600) else: pl.show() x = s['3b'] y = 2 * np.transpose(x[:, :, :int(500 / step), :2].sum(axis=2), (1, 0, 2)) pl.figure() pl.plot(range(1, len(offers3b) - 1), f3b[1:-1], '--', color=col[0], zorder=10, clip_on=False) pl.errorbar(range(len(offers3b)), f3b, yerr=fSEM3b, fmt='o', ms=30, color=col[0], mfc='w', mec=col[0], mew=3, capthick=2, zorder=10, clip_on=False) pl.errorbar(range(len(offers3b)), y.mean(axis=0).sum(axis=1) / pop3b,
zorder=11) pl.xticks(range(len(offers3e)), [str(a[0]) + ':' + str(a[1]) for i, a in enumerate(offers3e)]) tt = pl.gca().get_xaxis().majorTicks for i in range(0, len(tt), 2): tt[i].set_pad(40) for i in range(1, len(tt), 2): tt[i].set_pad(13) pl.yticks([0, 10], [0, 10]) pl.xlim(-.3, len(offers3e) - .7) pl.ylim(0, 14) pl.xlabel('offers (\#B:\#A)', labelpad=-2) pl.ylabel('Firing rate [Hz]', y=.45) pl.text(.05, .83, 'A=' + str(ratio3e) + 'B', transform=pl.gca().transAxes) pl.subplots_adjust(.17, .28, .98, .99) simpleaxis(pl.gca()) lg = pl.legend(['Experiment', 'Model'], bbox_to_anchor=(.63, .85), handlelength=2, frameon=False) if savefig: pl.savefig('Padoa3e2.pdf', dpi=600) else: pl.show() x = s['3b'] y = 2 * np.transpose(x[:, :, :int(500 / step), :2].sum(axis=2), (1, 0, 2)) pl.figure() pl.plot(range(1, len(offers3b) - 1), f3b[1:-1],
fig1.subplots_adjust(.04, .1, 1, 1)
if savefig:
    pl.savefig('u.pdf', dpi=600, transparent=True)
else:
    pl.show()

# Spikes
spikes, u = cfn.run(net.W, step, 400, 100, 20, 2, 20, 2)
pl.figure()
for i in range(8):
    # One bar per spike of unit i, spanning row i. An explicit loop is used
    # because map() for side effects is lazy under Python 3 and draws nothing.
    for spike_t in np.where(spikes[:, i] == 1)[0]:
        pl.plot([spike_t, spike_t], [.95 + i, .05 + i], c=col[i])
pl.yticks(np.arange(.4, 8), ['0L', '0R', '1L', '1R', '2L', '2R', '3L', '3R'])
pl.gca().yaxis.set_tick_params(width=0)
# NOTE(review): 3 labels assume len(u) spans 80 / step samples — confirm
pl.xticks(np.arange(0, len(u) + 1, 40 / step), [0, 40, 80])
simpleaxis(pl.gca())
pl.tight_layout(pad=.03)
if savefig:
    pl.savefig('spikes.pdf', dpi=600, transparent=True)
else:
    pl.show()


def f(W, x, ref):
    """One update step of the state vector ``x``.

    ``y = (1 - 1/2000) x + (1/2000) * ((1 + ref) W.x+ - ref * x+)``, where
    ``x+ = x * (x > 0)``. The entries ``net.r`` are then set to 1.
    """
    y = (1 - 1. / 2000) * x + 1. / 2000 \
        * ((1 + ref) * np.dot(W, x * (x > 0)) - ref * x * (x > 0))
    y[net.r] = 1
    return y


x = np.zeros((1201, net.K))
if i == 1: l.set_dashes([14, 10]) plt.plot(z * ssubLR[4][2][neuronId] / scale - .7 * ub, alpha=1., c=cyan, zorder=-10) plt.ylim(-.7 * ub, ub) plt.legend(frameon=False, ncol=len(dsls), loc=[.073, .775], columnspacing=14.2) plt.xticks(range(0, T, ticksep * fps), ['', '', '']) tmp = [500, 300, 100, 100][k] plt.plot([0, 0], [0, 1], c='k', lw=5, clip_on=False, zorder=11) plt.text(7, 1, '100\%', verticalalignment='center') simpleaxis(plt.gca()) plt.gca().spines['left'].set_visible(False) plt.yticks([]) for i, ds in enumerate(dsls): ax = fig.add_axes([[.17, .467, .78][i], .838 - .244 * k, .035, .2625]) ss = downscale_local_mean(gaussian(shapes[neuronId], 1), (ds, ds)).repeat(ds, 0).repeat(ds, 1) ax.imshow(ss[map(lambda a: slice(*a), boxes[neuronId])], cmap='hot', interpolation='nearest') ax.axis('off') ax.text( 1.25, .5, '%.3f' % (np.corrcoef(ssub[ds][0][neuronId], ssub[1][0][neuronId])[0, 1]),
def plotChoice(w):
    """Plot measured and simulated probability of choosing B vs. log(V(B)/V(A)).

    Parameters
    ----------
    w : str
        Dataset suffix, '3b' or '3e'. Selects the module-level ``offers<w>``,
        ``ratio<w>`` and ``pB<w>`` and the simulation results ``s[w]``.

    Returns
    -------
    choice : ndarray of bool
        ``rts[:, :, 1] > rts[:, :, 2]`` for the simulated trials. Its mean over
        axis 0 is plotted as the model's B-choice fraction.

    The first and last offers are drawn at the outer ticks labelled -inf and
    +inf, separated from the rest by broken-axis marks. The figure is saved to
    ``'Padoa' + w + '.pdf'`` if the global ``savefig`` is true, else shown.
    """
    # Look the dataset up by name rather than eval()-ing a constructed string.
    data = globals()
    offers = data['offers' + w]
    ratio = data['ratio' + w]
    pB = data['pB' + w]
    v = np.array([offers[1:-1, 0], ratio * offers[1:-1, 1]]).T
    logratio = np.log(v[:, 0] / v[:, 1])
    x = np.transpose(s[w], (1, 0, 2, 3))
    rts = np.array([[get_t(a) for a in trials] for trials in x], dtype=int)
    choice = rts[:, :, 1] > rts[:, :, 2]
    pBsim = np.mean(choice, 0)

    # x positions of the outer points, labelled -inf / +inf
    lo = logratio[0] - .5
    hi = logratio[-1] + .5
    pl.figure()
    pl.plot(logratio, pB[1:-1], '--', marker='s', ms=30, mfc='w', mec=col[0],
            mew=3, color=col[0], zorder=10, clip_on=False)
    pl.plot(logratio, pBsim[1:-1], color=col[1], marker='s', ms=22,
            mec=col[1], clip_on=False, zorder=10)
    pl.plot([lo, hi], [pB[0], pB[-1]], 's', ms=30, mfc='w', mec=col[0],
            mew=3, color=col[0], zorder=10, clip_on=False)
    pl.plot([lo, hi], [pBsim[0], pBsim[-1]], 's', color=col[1], ms=22,
            mec=col[1], clip_on=False, zorder=10)
    pl.xticks([lo, -1, 0, 1, hi], [r"$-\infty$", -1, 0, 1, r"$\infty$"])
    for tick in pl.gca().get_xaxis().majorTicks:
        tick.set_pad(17)
    pl.yticks([0, .5, 1], [0, 50, 100])
    pl.xlim(logratio[0] - .65, hi)
    pl.ylim(0, 1)
    pl.xlabel('log(V(B)/V(A))', labelpad=-2)
    # raw string: '\%' is an invalid escape in a plain literal (same value)
    pl.ylabel(r'B choice [\%]', labelpad=-15)
    pl.text(.05, .83, 'A=' + str(ratio) + 'B', transform=pl.gca().transAxes)
    pl.subplots_adjust(.17, .23, .96, .95)
    simpleaxis(pl.gca())
    # broken axis: a white gap on the x-axis plus two black strokes at each end
    pl.plot([logratio[0] - .27, logratio[0] - .23], [0, 0], c='w', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[0] - .29, logratio[0] - .25], [-.02, .02], c='k', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[0] - .25, logratio[0] - .21], [-.02, .02], c='k', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[-1] + .28, logratio[-1] + .24], [0, 0], c='w', lw=2,
            clip_on=False, zorder=11)
    pl.plot([logratio[-1] + .3, logratio[-1] + .26], [.02, -.02], c='k',
            lw=2, clip_on=False, zorder=11)
    pl.plot([logratio[-1] + .26, logratio[-1] + .22], [.02, -.02], c='k',
            lw=2, clip_on=False, zorder=11)
    if savefig:
        pl.savefig('Padoa' + w + '.pdf', dpi=600)
    else:
        pl.show()
    return choice
[[net.R4Pi2(net.get_policy2(step, i, T), net.pstart_state) for i in S] for T in Tls]) perf = (rew - R0) / (Rmax - R0) perf[0] = 0 np.save('results/performance', perf) pl.figure() errorfill(range(len(perf)), np.mean(perf, axis=1), yerr=np.std(perf, axis=1) / np.sqrt(len(perf[0]))) pl.xticks([0, 200, 400], [0, 200, 400]) pl.yticks([0, .5, 1.0], [0, .5, 1.0]) pl.xlim([0, 400]) pl.ylim([0, 1]) pl.xlabel('Time [ms]') pl.ylabel('Performance') simpleaxis(pl.gca()) pl.tight_layout(0) pl.savefig('performance.pdf', dpi=600) # learning via parallel sampling eta, T, r0 = .01, 1000, 0 try: perf = np.load('results/learn_performance.npy') except IOError: try: rew = np.load('results/learn.npz')['R'] except IOError: res = np.array([ net.parallel_sampling_keepU(step, eta, run,
fig = plt.figure(figsize=(6.5, 6)) for ssub, col in ((activityDSlr, cyan), (activityDS, orange)): cor, mx = foo(ssub, activityDS[1], shapes) if col == cyan: mapIdx = mx plt.plot(dsls, np.median(cor, 0), lw=3, c=col, clip_on=False, label='1 phase\n imaging' if col == cyan else '2 phase\n imaging') IQRfill(cor, dsls, col) plt.xlabel('Spatial decimation') ax = plt.gca() simpleaxis(ax) ax.patch.set_visible(False) plt.xticks(dsls, ['1x1', '', '', '4x4', '6x6', '8x8', '12x12']) plt.ylim(.55, 1) plt.yticks(*[[.6, .8, 1.]] * 2) plt.ylabel('Correlation', labelpad=0) plt.legend(loc=(.01, .25), ncol=1) plt.subplots_adjust(.145, .15, .885, .97) ax2 = ax.twinx() ax2.set_zorder(-1) ax2.spines['top'].set_visible(False) z = np.mean((activityDS[1].T.dot(shapes.reshape(len(shapes), -1)) - data.reshape(activity.shape[1], -1))**2) ax2.plot(dsls, [ np.mean((activityDSlr[ds].T.dot(shapesDSlr[ds].repeat(ds, 1).repeat( ds, 2).reshape(len(shapes), -1)) - data.reshape(activity.shape[1], -1))