示例#1
0
def plotCorr(ssubs,
             compare,
             rng=slice(0, T),
             ca_or_spikes='ca',
             labels=None,
             title='Shapes obtained on',
             colors=[orange, blue, vermillon],
             ls='-'):
    """Plot mean correlation with `compare` against spatial decimation.

    Map based on spatial distance; removed neurons are set to corr=0.

    For every entry of `ssubs` (a mapping ``ds -> results`` over the
    module-level decimation factors `dsls`), each estimated trace is
    correlated (Pearson, restricted to the time window `rng`) with its row in
    `compare`. The mean over neurons is drawn as a line plus standard-error
    bars. Entry i > 0 is matched to `compare` through the module-level index
    maps ``idxH`` / ``idxQ``; a ``None`` in a map leaves that neuron at
    correlation 0. The first entry has no index map, so it must contain
    exactly ``len(compare)`` components. NaN correlations are counted as 0.

    Parameters
    ----------
    ssubs : sequence of at most 3 mappings ``ds -> results``;
        ``results[0]`` is used when ``ca_or_spikes == 'ca'``, otherwise
        ``results[2]``.
    compare : 2-D array, one reference trace per row.
    rng : slice of time points used for the correlation.
    ca_or_spikes : 'ca' to select ``results[0]``, anything else for
        ``results[2]``.
    labels : optional legend labels, one per entry of `ssubs`.
    title : legend title.
    colors : line colors, one per entry of `ssubs`.
    ls : line style of the mean curves.

    Returns
    -------
    The legend when `labels` is given, otherwise None.
    """
    which = 0 if ca_or_spikes == 'ca' else 2  # position of the traces in ssub[ds]

    def foo(ssub, comp, idx=None):
        # Correlation matrix (neurons x decimation factors) of ssub vs comp.
        N, _ = comp.shape
        cor = np.zeros((N, len(dsls)))
        for i, ds in enumerate(dsls):
            if len(ssub[ds][0]) == len(comp):
                cor[:, i] = np.nan_to_num([
                    np.corrcoef(ssub[ds][which][n, rng], comp[n, rng])[0, 1]
                    for n in range(N)])
            else:  # necessary if component has been removed on small batch
                for n, k in enumerate(idx):
                    if k is not None:  # component removed -> keep corr=0
                        cor[n, i] = np.nan_to_num(np.corrcoef(
                            ssub[ds][which][k, rng], comp[n, rng])[0, 1])
        return cor

    for i, ssub in enumerate(ssubs):
        cor = foo(ssub, compare, [None, idxH, idxQ][i])
        mean_cor = np.mean(cor, 0)
        plt.plot(dsls, mean_cor, ls=ls, lw=4, c=colors[i],
                 label=None if labels is None else labels[i],
                 clip_on=False)
        noclip(plt.errorbar(dsls, mean_cor,
                            yerr=np.std(cor, 0) / np.sqrt(len(cor)),
                            lw=3, capthick=2, fmt='o', c=colors[i],
                            clip_on=False))
    plt.xlabel('Spatial decimation')
    simpleaxis(plt.gca())
    plt.xticks(dsls,
               ['1', '', '', '', '', '8x8', '', '16x16', '24x24', '32x32'])
    plt.xlim(dsls[0], dsls[-1])
    # Fix the y-range first so the ticks are derived from the final limits
    # (not from the autoscaled ones, which ylim(.25, 1) would then cut off),
    # and round them so labels read 1.0, 0.8, ... instead of float artefacts
    # such as 0.6000000000000001.
    plt.ylim(.25, 1)
    yt = np.round(np.arange(np.round(plt.ylim()[1], 1), plt.ylim()[0], -.2), 1)
    plt.yticks(yt, yt)
    plt.ylabel('Correlation w/ undecimated $C_1$/$S_1$', y=.42, labelpad=1)
    plt.subplots_adjust(.13, .15, .94, .96)
    if labels is None:
        return None
    return plt.legend(loc=(-.01, -.01), title=title)
示例#2
0
def adjust():
    """Apply the shared axis styling to the current axes.

    Sets ticks, x-range, the y-range from the module-level bounds `l` and
    `u`, the x-label, the simplified axis style and the figure margins.
    """
    x_positions = [0, 5, 10]
    plt.xticks(x_positions, x_positions)
    plt.yticks([1, 1.01], ['1.00', 1.01])
    plt.xlim(0, 8.5)
    plt.ylim(l, u)
    plt.xlabel('Wall time [s]')
    simpleaxis(plt.gca())
    plt.subplots_adjust(.1, .155, .99, .99)
示例#3
0
def plotCorr(ssubs,
             compare,
             dsls,
             labels=None,
             colors=[cyan, orange, green],
             ca_or_spikes='ca',
             ls='-'):
    """Plot the mean correlation with `compare` against spatial decimation.

    For each entry of `ssubs` (a mapping ``ds -> results``) and each ``ds``
    in `dsls`, every estimated trace is correlated with its reference row in
    `compare`. The mean over neurons is drawn as a line plus standard-error
    bars in ``colors[i]``. NaN correlations count as 0. When the number of
    estimated components differs from ``len(compare)``, each estimate is
    paired with the reference row it correlates best with, and reference rows
    that receive no estimate keep correlation 0.

    ``ca_or_spikes == 'ca'`` selects ``results[0]``, anything else
    ``results[2]``. Returns the legend when `labels` is given, else None.
    """
    pick = 0 if ca_or_spikes == 'ca' else 2

    def foo(ssub, comp):
        n_ref, _ = comp.shape
        cor = np.zeros((n_ref, len(dsls)))
        for j, ds in enumerate(dsls):
            if len(ssub[ds][0]) != len(comp):
                # a component was removed (update_spatial_components): pair
                # every estimate with its best-correlated reference row
                best = [np.argmax([np.corrcoef(est, ref)[0, 1] for ref in comp])
                        for est in ssub[ds][pick]]
                for k in range(len(ssub[ds][0])):
                    cor[best[k], j] = np.nan_to_num(
                        np.corrcoef(ssub[ds][pick][k], comp[best[k]])[0, 1])
                continue
            cor[:, j] = np.nan_to_num(
                [np.corrcoef(ssub[ds][pick][k], comp[k])[0, 1]
                 for k in range(n_ref)])
        return cor

    for i, ssub in enumerate(ssubs):
        cor = foo(ssub, compare)
        color = colors[i]
        mean_cor = np.mean(cor, 0)
        sem = np.std(cor, 0) / np.sqrt(len(cor))
        plt.plot(dsls, mean_cor, ls=ls, lw=4, c=color,
                 label=labels[i] if labels is not None else None,
                 clip_on=False)
        noclip(plt.errorbar(dsls, mean_cor, yerr=sem, lw=3, capthick=2,
                            fmt='o', c=color, clip_on=False))
    plt.xlabel('Spatial decimation')
    simpleaxis(plt.gca())
    plt.xticks(dsls,
               ['1', '', '', '', '', '8x8', '', '16x16', '24x24', '32x32'])
    y_positions = [.4, .6, .8, 1.0]
    plt.yticks(y_positions, y_positions)
    plt.xlim(dsls[0], dsls[-1])
    plt.ylabel('Correlation w/ undecimated $C_1$/$S_1$', y=.42, labelpad=1)
    if labels is None:
        return None
    return plt.legend(loc=(0, 0))
def plotChoice(w):  # arg which 3b or 3e
    """Plot the percentage of B choices against log(V(B)/V(A)) for condition `w`.

    Reads the module-level ``offers<w>``, ``ratio<w>``, ``pB<w>`` and the
    simulation results ``s[w]``. Experimental values (open squares, ``col[0]``)
    and simulated ones (filled squares, ``col[1]``) are drawn. The first and
    last offers are placed at -inf / +inf, beyond a broken x-axis. The figure
    is saved to 'Padoa<w>.pdf' when `savefig` is set, otherwise shown.

    Returns
    -------
    choice : boolean array ``rts[:, :, 1] > rts[:, :, 2]``, where ``rts``
        comes from ``get_t``. Its mean over axis 0 is compared with ``pB``,
        so True presumably marks a B choice (TODO confirm).
    """
    def _named(prefix):
        # Module-level lookup replacing eval(prefix + w): same value for the
        # intended names (e.g. 'offers3b'), but `w` cannot inject code.
        name = prefix + w
        try:
            return globals()[name]
        except KeyError:
            raise NameError("name '%s' is not defined" % name)

    offers = _named('offers')
    ratio = _named('ratio')
    pB = _named('pB')
    v = np.array([offers[1:-1, 0], ratio * offers[1:-1, 1]]).T
    logratio = np.log(v[:, 0] / v[:, 1])
    x = np.transpose(s[w], (1, 0, 2, 3))
    rts = np.array([[get_t(a) for a in x[r]] for r in range(len(x))],
                   dtype=int)
    choice = rts[:, :, 1] > rts[:, :, 2]
    pBsim = np.mean(choice, 0)
    lo, hi = logratio[0] - .5, logratio[-1] + .5  # x of the -inf / +inf offers
    # plot
    pl.figure()
    exp_style = dict(ms=30, mfc='w', mec=col[0], mew=3, color=col[0],
                     zorder=10, clip_on=False)
    sim_style = dict(color=col[1], ms=22, mec=col[1], clip_on=False,
                     zorder=10)
    pl.plot(logratio, pB[1:-1], '--', marker='s', **exp_style)
    pl.plot(logratio, pBsim[1:-1], marker='s', **sim_style)
    pl.plot([lo, hi], [pB[0], pB[-1]], 's', **exp_style)
    pl.plot([lo, hi], [pBsim[0], pBsim[-1]], 's', **sim_style)
    pl.xticks([lo, -1, 0, 1, hi], [r"$-\infty$", -1, 0, 1, r"$\infty$"])
    for tick in pl.gca().get_xaxis().majorTicks:
        tick.set_pad(17)
    pl.yticks([0, .5, 1], [0, 50, 100])
    pl.xlim(logratio[0] - .65, hi)
    pl.ylim(0, 1)
    pl.xlabel('log(V(B)/V(A))', labelpad=-2)
    # raw string: '\%' is not a valid escape sequence (same string value)
    pl.ylabel(r'B choice [\%]', labelpad=-15)
    pl.text(.05, .83, 'A=' + str(ratio) + 'B', transform=pl.gca().transAxes)
    pl.subplots_adjust(.17, .23, .96, .95)
    simpleaxis(pl.gca())
    # broken axis: on each side a white gap in the x-spine plus two black
    # slashes, between the finite offers and the +-inf ones
    breaks = (
        ([logratio[0] - .27, logratio[0] - .23], [0, 0], 'w'),
        ([logratio[0] - .29, logratio[0] - .25], [-.02, .02], 'k'),
        ([logratio[0] - .25, logratio[0] - .21], [-.02, .02], 'k'),
        ([logratio[-1] + .28, logratio[-1] + .24], [0, 0], 'w'),
        ([logratio[-1] + .3, logratio[-1] + .26], [.02, -.02], 'k'),
        ([logratio[-1] + .26, logratio[-1] + .22], [.02, -.02], 'k'),
    )
    for xs, ys, c in breaks:
        pl.plot(xs, ys, c=c, lw=2, clip_on=False, zorder=11)
    if savefig:
        pl.savefig('Padoa' + w + '.pdf', dpi=600)
    else:
        pl.show()
    return choice
示例#5
0
def plotCorr(ssub, ssub0, r=pearsonr, ds1phase=[1, 2, 3, 4], loc=(.1, .01)):
    """Compare 1-phase (`ssub0`) and 2-phase (`ssub`) results vs. decimation.

    For the denoised traces (vs. ``trueC``, solid lines) and the deconvolved
    spikes (vs. ``trueSpikes``, dashed lines), the median per-neuron
    correlation ``r(estimate, truth)[0]`` is plotted, together with an
    ``IQRfill`` band (presumably the inter-quartile range). 1-phase results
    are evaluated at the decimation factors `ds1phase`, 2-phase results at
    the module-level `dsls`. Correlations that cannot be assigned (NaN)
    count as 0.

    Returns the legend handles ``(l1, l2, l3, l4)``.
    """
    def foo(ssub, comp, dsls=dsls, ca_or_spikes='ca'):
        # Per-neuron correlation with the rows of `comp` for every factor in
        # `dsls`; reference rows that receive no estimate end up as 0.
        which = 0 if ca_or_spikes == 'ca' else 2
        N, _ = comp.shape
        cor = np.full((N, len(dsls)), np.nan)
        for i, ds in enumerate(dsls):
            if len(ssub[ds][0]) == len(comp):
                cor[:, i] = np.array(
                    [r(ssub[ds][which][n], comp[n])[0] for n in range(N)])
            else:  # necessary if update_spatial_components removed a component
                # pair every estimate with its best-correlated reference row
                mapIdx = [np.argmax([np.corrcoef(s, tC)[0, 1] for tC in comp])
                          for s in ssub[ds][which]]
                for n in range(len(ssub[ds][0])):
                    cor[mapIdx[n], i] = r(ssub[ds][which][n],
                                          comp[mapIdx[n]])[0]
        return np.nan_to_num(cor)

    # 1-phase results only have len(ds1phase) columns, so their bands must
    # use ds1phase as x-values (passing dsls gave mismatched lengths).
    cor = foo(ssub0, trueC, ds1phase)
    l1, = plt.plot(ds1phase, np.median(cor, 0), lw=4, c=cyan,
                   label='1 phase imaging')
    IQRfill(cor, ds1phase, cyan)

    cor = foo(ssub, trueC)
    l2, = plt.plot(dsls, np.median(cor, 0), lw=4, c=orange,
                   label='2 phase imaging')
    IQRfill(cor, dsls, orange)

    cor = foo(ssub0, trueSpikes, ds1phase, ca_or_spikes='spikes')
    plt.plot(ds1phase, np.median(cor, 0), lw=4, c=cyan, ls='--')
    IQRfill(cor, ds1phase, cyan, ls='--', hatch='///')

    cor = foo(ssub, trueSpikes, ca_or_spikes='spikes')
    plt.plot(dsls, np.median(cor, 0), lw=4, c=orange, ls='--')
    IQRfill(cor, dsls, orange, ls='--', hatch='\\\\\\')

    # black lines at y=-1 (below the y-range set below) that only serve as
    # legend entries for the solid / dashed line styles
    l3, = plt.plot([0, 1], [-1, -1], lw=4, c='k', label='denoised')
    l4, = plt.plot([0, 1], [-1, -1], lw=4, c='k', ls='--', label='deconvolved')

    plt.xlabel('Spatial decimation')
    simpleaxis(plt.gca())
    plt.xticks(dsls, ['1', '', '', '', '', '8x8', '', '16x16', '24x24', '32x32'])
    plt.ylim(.3, 1)
    plt.yticks(
        *[np.round(np.arange(np.round(plt.ylim()[1], 1), plt.ylim()[0], -.2), 1)] * 2)
    plt.xlim(dsls[0], dsls[-1])
    plt.ylabel('Correlation w/ undecimated $C_1$/$S_1$', y=.42, labelpad=1)
    plt.legend(handles=[l3, l4, l1, l2], loc=loc, ncol=1)
    plt.subplots_adjust(.13, .15, .94, .96)
    return l1, l2, l3, l4
示例#6
0
def plot_trace(n=0, lg=False):
    """Plot truth, estimate and data traces and annotate the spike correlation.

    Draws ``trueC[n]`` (truth), the module-level `solution` (estimate) and
    `y` (data). Then writes the correlation between ``trueSpikes[n]`` and the
    spikes implied by `solution`, ``s_t = c_t - g * c_{t-1}`` (``s_0 = 0``,
    `g` taken from module scope).

    Parameters
    ----------
    n : int
        Row of `trueC` / `trueSpikes` to compare against.
    lg : bool
        Draw the legend.
    """
    plt.plot(trueC[n], c=col[2], clip_on=False, zorder=5, label='Truth')
    plt.plot(solution, c=col[0], clip_on=False, zorder=7, label='Estimate')
    plt.plot(y, c=col[7], alpha=.7, clip_on=False, zorder=-10, label='Data')
    if lg:
        plt.legend(frameon=False, ncol=3, loc=(.1, .62), columnspacing=.8)
    spks = np.append(0, solution[1:] - g * solution[:-1])
    plt.text(800,
             2.2,
             'Correlation: %.3f' % (np.corrcoef(trueSpikes[n], spks)[0, 1]),
             size=24)
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.ylim(0, 2.85)
    plt.xlim(0, 1500)
    plt.yticks([0, 2], [0, 2])
    xt = [300, 600, 900, 1200]
    # One (empty) label per tick: recent Matplotlib versions reject a label
    # list whose length differs from the number of tick locations (was 2 vs 4).
    plt.xticks(xt, [''] * len(xt))
示例#7
0
def cb(y, active_set, counter, current):
    """Draw one animation frame of the solver state on module-level ax1/ax2.

    `active_set` holds (v, w, f, l) tuples. Entry i fills ``solution[f:f + l]``
    with ``v / w * g**k`` for k = 0..l-1 (module-level `g`); only the first
    entry's v is clipped at 0. `ax1` shows the solution and `ax2` the implied
    spikes ``s_t = c_t - g * c_{t-1}`` (``s_0 = 0``). Points are coloured by
    `y`, the sample at index `current` is marked with a blue '+', every other
    active-set entry is shaded, and the true spike times (module-level
    `trueSpikes`) are drawn in red. The frame is saved to
    'video/<counter:03d>.pdf' when `save_figs` is set, and both axes are
    cleared afterwards for the next call.
    """
    solution = np.empty(len(y))
    for i, (v, w, f, l) in enumerate(active_set):
        # only the first entry is clipped at 0
        solution[f:f + l] = (v if i else max(v, 0)) / w * g**np.arange(l)
    # Implied spike train, computed once (it used to be rebuilt for each of
    # the three ax2 plotting calls).
    spikes = np.r_[[0], solution[1:] - g * solution[:-1]]
    t = np.arange(len(y))
    color = y.copy()

    ax1.plot(solution, c='k', zorder=-11, lw=1.3, clip_on=False)
    ax1.scatter(t, solution, s=40, cmap=plt.cm.Spectral,
                c=color, clip_on=False, zorder=11)
    ax1.scatter([t[current]], [solution[current]],
                s=120, lw=2.5, marker='+', color='b', clip_on=False, zorder=11)
    for a in active_set[::2]:
        ax1.axvspan(a[2], a[2] + a[3], alpha=0.1, color='k', zorder=-11)
    for x in np.where(trueSpikes)[0]:
        ax1.plot([x, x], [0, 2.3], lw=1.5, c='r', zorder=-12)
    ax1.set_xlim((0, len(y) - .5))
    ax1.set_ylim((0, 2.3))
    simpleaxis(ax1)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_ylabel('Fluorescence')

    for i, s in enumerate(spikes):
        ax2.plot([i, i], [0, s], c='k', zorder=-11, lw=1.4, clip_on=False)
    ax2.scatter(t, spikes, s=40, cmap=plt.cm.Spectral, c=color,
                clip_on=False, zorder=11)
    ax2.scatter([t[current]], [spikes[current]],
                s=120, lw=2.5, marker='+', color='b', clip_on=False, zorder=11)
    for a in active_set[::2]:
        ax2.axvspan(a[2], a[2] + a[3], alpha=0.1, color='k', zorder=-11)
    for x in np.where(trueSpikes)[0]:
        ax2.plot([x, x], [0, 1.55], lw=1.5, c='r', zorder=-12)
    ax2.set_xlim((0, len(y) - .5))
    ax2.set_ylim((0, 1.55))
    simpleaxis(ax2)
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_xlabel('Time', labelpad=35, x=.5)
    ax2.set_ylabel('Spikes')

    plt.subplots_adjust(left=0.032, right=.995, top=.995, bottom=0.19, hspace=0.22)
    fig.canvas.draw()
    if save_figs:
        plt.savefig('video/%03d.pdf' % counter)
    ax1.clear()
    ax2.clear()
示例#8
0
def cb(y, active_set, counter, current):
    """Render the current solver state in a new 3x3 inch figure.

    `active_set` holds (v, w, f, l) tuples. Each one fills
    ``solution[f:f + l]`` with ``max(v, 0) / w * g**k`` for k = 0..l-1
    (module-level `g`). Points are coloured by `y`, index `current` is marked
    with a blue '+', every other active-set entry is shaded, and the true
    spike times (module-level `trueSpikes`) are drawn as red vertical lines.
    The figure is saved to 'fig/<counter>.pdf' when `save_figs` is set, then
    shown.
    """
    solution = np.empty(len(y))
    for value, weight, start, length in active_set:
        decay = g**np.arange(length)
        solution[start:start + length] = max(value, 0) / weight * decay
    idx = np.arange(len(y))

    plt.figure(figsize=(3, 3))
    color = y.copy()
    plt.plot(solution, c='k', zorder=-11, lw=1.2)
    plt.scatter(idx, solution, s=60, cmap=plt.cm.Spectral, c=color,
                clip_on=False, zorder=11)
    plt.scatter([idx[current]], [solution[current]], s=200, lw=2.5,
                marker='+', color='b', clip_on=False, zorder=11)
    for _, _, start, length in active_set[::2]:
        plt.axvspan(start, start + length, alpha=0.1, color='k', zorder=-11)
    for spike_t in np.where(trueSpikes)[0]:
        plt.plot([spike_t, spike_t], [0, 1.65], lw=1.5, c='r', zorder=-12)

    plt.xlim((0, len(y) - .5))
    plt.ylim((0, 1.65))
    simpleaxis(plt.gca())
    plt.xticks([])
    plt.yticks([])
    if save_figs:
        plt.savefig('fig/%d.pdf' % counter)
    plt.show()
示例#9
0
def plotTrace(lg=False):
    """Three-panel comparison of L1 vs. thresholded estimates (module globals).

    Top: traces `c` (L1), `c_t` (Thresh.), ``trueC[0]`` (Truth) and `y`
    (Data). Middle: spike values from `s`, `s_t` and ``trueSpikes[0]`` (first
    500 samples) drawn as vertical stems on rows offset to 2.5 / 1.25 / 0;
    only values > 1e-2 get a stem. Bottom: for each entry ``i`` of `res`, one
    tick mark per element at height ``0.1 * i``, with the true spike times in
    red below.

    Parameters
    ----------
    lg : bool
        Draw the legend of the top panel.
    """
    fig = plt.figure(figsize=(10, 9))

    # top panel: traces
    fig.add_axes([.13, .7, .86, .29])
    plt.plot(c, c=col[0], label='L1')
    plt.plot(c_t, c=col[1], label='Thresh.')
    plt.plot(trueC[0], c=col[2], lw=3, label='Truth', zorder=-5)
    plt.plot(y, c=col[7], alpha=.7, zorder=-10, label='Data')
    if lg:
        plt.legend(frameon=False, ncol=4, loc=(.05, .82))
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.yticks([0, int(y.max())], [0, int(y.max())])
    plt.xticks(range(150, 500, 150), [''] * 3)
    plt.ylabel('Fluor.')
    plt.xlim(0, 452)

    # middle panel: one stem row per spike estimate
    fig.add_axes([.13, .39, .86, .29])
    rows = ((s, 2.5, col[0], {}),
            (s_t, 1.25, col[1], {}),
            (trueSpikes[0], 0, col[2], {'clip_on': False}))
    for spikes, base, color, extra in rows:
        for i, ss in enumerate(spikes[:500]):
            if ss > 1e-2:
                plt.plot([i, i], [base, base + ss], c=color, zorder=10,
                         **extra)
        # Baseline drawn once per row; it used to sit inside the loop above
        # and was re-drawn identically for every sample.
        plt.plot([0, 450], [base, base], c=color, zorder=10, **extra)
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.yticks([0, 1.25, 2.5], ['Truth', 'Thresh.', 'L1'])
    for tick in plt.gca().yaxis.get_major_ticks():
        tick.label1.set_verticalalignment('bottom')
    plt.xticks(range(150, 500, 150), [''] * 3)
    plt.ylim(0, 3.5)
    plt.xlim(0, 452)

    # bottom panel: detected event times per row of `res`, truth in red
    fig.add_axes([.13, .08, .86, .29])
    for i, r in enumerate(res):
        for rr in r:
            plt.plot([rr, rr], [.1 * i - .04, .1 * i + .04], c='k')
    for rr in np.where(trueSpikes[0])[0]:
        plt.plot([rr, rr], [-.08, -.16], c='r')
    plt.gca().set_xticklabels([])
    simpleaxis(plt.gca())
    plt.yticks([0, .5, 1], [0, 0.5, 1.0])
    plt.xticks(range(0, 500, 150), [0, 5, 10, ''])
    plt.ylim(-.2, 1.1)
    plt.xlim(0, 452)
    plt.ylabel(r'$s_{\min}$')
    plt.xlabel('Time [s]', labelpad=-10)
    plt.show()
# Firing rate of each of the 8 units (rows of Ratels.mean(axis=0)) plus a
# black reference curve; saved to Sohn.pdf when `savefig` is set.
pl.figure()
for c in range(8):
    pl.plot(Ratels.mean(axis=0)[c], color=col[c])
pl.xlabel('Time from target [ms]', labelpad=0)
pl.ylabel('Firing rate [Hz]', labelpad=10)
pl.xticks([0, 250 / step, 500 / step], [0, 250, 500])
pl.yticks([0, 20, 40, 60, 80], [0, 20, 40, 60, 80])
pl.xlim(0, 500 / step)
pl.ylim(0, 64)
# Black curve: smooth_spikes of a Gaussian-shaped profile centred at
# t * step = 500, cropped to samples 250 / step .. 750 / step. Slice bounds
# must be ints: `250 / step` is a float under Python 3 (or for a float step).
pl.plot(smooth_spikes([step * 100 / 1000. *
                       (.1 + .65 * np.exp(-(t * step - 500) ** 2 / tau_r ** 2))
                       for t in range(int(1000 / step))], 40, .2)[int(250 / step):int(750 / step)],
        color='black', zorder=-1, lw=2)
pl.subplots_adjust(.18, .21, .945, .99)
simpleaxis(pl.gca())
if savefig:
    pl.savefig('Sohn.pdf', dpi=600)
else:
    pl.show()


# Row 3 of the axis-0 means of Ratels (solid), Ratels3 (dashed, 10/10 dash
# pattern) and Ratels4 (dotted, 3/3 dash pattern), all in col[3].
pl.figure()
pl.plot(Ratels.mean(axis=0)[3], color=col[3])
l, = pl.plot(Ratels3.mean(axis=0)[3], '--', color=col[3])
l.set_dashes([10, 10])
l, = pl.plot(Ratels4.mean(axis=0)[3], ':', color=col[3])
l.set_dashes([3, 3])
pl.xlabel('Time from target [ms]', labelpad=0)
pl.ylabel('Firing rate [Hz]', labelpad=10)
pl.xticks([0, 250 / step, 500 / step], [0, 250, 500])
# NOTE(review): `fig1` is not created in this section. If it is an earlier
# figure, this adjusts that figure rather than the one opened by pl.figure()
# above (pl.subplots_adjust would target the current one) -- confirm intent.
fig1.subplots_adjust(.04, .1, 1, 1)
if savefig:
    pl.savefig('u.pdf', dpi=600, transparent=True)
else:
    pl.show()

# Spikes: raster of the 8 units from a cfn.run simulation, one row per unit
# (labelled 0L..3R), saved to spikes.pdf when `savefig` is set.
spikes, u = cfn.run(net.W, step, 400, 100, 20, 2, 20, 2)
pl.figure()
for i in range(8):
    # One vertical tick per spike of unit i. (The previous side-effect-only
    # `map(lambda ...)` is lazy under Python 3 and drew nothing.)
    pl.vlines(np.where(spikes[:, i] == 1)[0], .05 + i, .95 + i, colors=col[i])
pl.yticks(np.arange(.4, 8), ['0L', '0R', '1L', '1R', '2L', '2R', '3L', '3R'])
pl.gca().yaxis.set_tick_params(width=0)
pl.xticks(np.arange(0, len(u) + 1, 40 / step), [0, 40, 80])
simpleaxis(pl.gca())
pl.tight_layout(pad=.03)
if savefig:
    pl.savefig('spikes.pdf', dpi=600, transparent=True)
else:
    pl.show()


def f(W, x, ref):
    """Return the next state for the current state `x`.

    Has the form of an explicit Euler update ``x + h * (-x + drive)`` with
    ``h = 1/2000``. The drive is ``(1 + ref) * W @ relu(x) - ref * relu(x)``,
    where ``relu(x) = x * (x > 0)``. Entries ``net.r`` (module-level `net`)
    are then clamped to 1.
    """
    h = 1. / 2000
    rect = x * (x > 0)  # rectified state
    drive = (1 + ref) * np.dot(W, rect) - ref * rect
    y = (1 - h) * x + h * drive
    y[net.r] = 1
    return y

# Zero-initialised (1201, net.K) array (presumably one row per time step --
# TODO confirm) and a fixed NumPy RNG seed so runs are reproducible.
x = np.zeros((1201, net.K))
np.random.seed(2)
            color=col[1], mec=col[1], capthick=2, clip_on=False, zorder=11)
# Label each offer position as '<a[0]>:<a[1]>'.
pl.xticks(range(len(offers3e)), [
          str(a[0]) + ':' + str(a[1]) for i, a in enumerate(offers3e)])
# Alternate the label padding (40 / 13) so neighbouring labels are staggered.
tt = pl.gca().get_xaxis().majorTicks
for i in range(0, len(tt), 2):
    tt[i].set_pad(40)
for i in range(1, len(tt), 2):
    tt[i].set_pad(13)
pl.yticks([0, 10], [0, 10])
pl.xlim(-.3, len(offers3e) - .7)
pl.ylim(0, 14)
# Raw string: '\#' is not a valid escape sequence (SyntaxWarning on newer
# Python); the string value is unchanged.
pl.xlabel(r'offers (\#B:\#A)', labelpad=-2)
pl.ylabel('Firing rate [Hz]', y=.45)
pl.text(.05, .83, 'A=' + str(ratio3e) + 'B', transform=pl.gca().transAxes)
pl.subplots_adjust(.17, .28, .98, .99)
simpleaxis(pl.gca())
lg = pl.legend(
    ['Experiment', 'Model'], bbox_to_anchor=(.63, .85), handlelength=2, frameon=False)
if savefig:
    pl.savefig('Padoa3e2.pdf', dpi=600)
else:
    pl.show()

# Condition 3b: y = 2 * (s['3b'] summed over the first int(500 / step)
# entries of axis 2, keeping entries 0 and 1 of the last axis), with the
# first two axes swapped.
x = s['3b']
y = 2 * np.transpose(x[:, :, :int(500 / step), :2].sum(axis=2), (1, 0, 2))
pl.figure()
# Experiment: dashed line through the interior offers, then open markers
# with error bars fSEM3b at every offer.
pl.plot(range(1, len(offers3b) - 1),
        f3b[1:-1], '--', color=col[0], zorder=10, clip_on=False)
pl.errorbar(range(len(offers3b)), f3b, yerr=fSEM3b, fmt='o', ms=30, color=col[0],
            mfc='w', mec=col[0], mew=3, capthick=2, zorder=10, clip_on=False)
pl.errorbar(range(len(offers3b)), y.mean(axis=0).sum(axis=1) / pop3b,
示例#13
0
            zorder=11)
# Label each offer position as '<a[0]>:<a[1]>'.
pl.xticks(range(len(offers3e)),
          [str(a[0]) + ':' + str(a[1]) for i, a in enumerate(offers3e)])
# Alternate the label padding (40 / 13) so neighbouring labels are staggered.
tt = pl.gca().get_xaxis().majorTicks
for i in range(0, len(tt), 2):
    tt[i].set_pad(40)
for i in range(1, len(tt), 2):
    tt[i].set_pad(13)
pl.yticks([0, 10], [0, 10])
pl.xlim(-.3, len(offers3e) - .7)
pl.ylim(0, 14)
# Raw string: '\#' is not a valid escape sequence (SyntaxWarning on newer
# Python); the string value is unchanged.
pl.xlabel(r'offers (\#B:\#A)', labelpad=-2)
pl.ylabel('Firing rate [Hz]', y=.45)
pl.text(.05, .83, 'A=' + str(ratio3e) + 'B', transform=pl.gca().transAxes)
pl.subplots_adjust(.17, .28, .98, .99)
simpleaxis(pl.gca())
lg = pl.legend(['Experiment', 'Model'],
               bbox_to_anchor=(.63, .85),
               handlelength=2,
               frameon=False)
if savefig:
    pl.savefig('Padoa3e2.pdf', dpi=600)
else:
    pl.show()

# Condition 3b: y = 2 * (s['3b'] summed over the first int(500 / step)
# entries of axis 2, keeping entries 0 and 1 of the last axis), with the
# first two axes swapped.
x = s['3b']
y = 2 * np.transpose(x[:, :, :int(500 / step), :2].sum(axis=2), (1, 0, 2))
pl.figure()
pl.plot(range(1,
              len(offers3b) - 1),
        f3b[1:-1],
# NOTE(review): `fig1` is not created in this section. If it is an earlier
# figure, this adjusts that figure rather than the current one -- confirm.
fig1.subplots_adjust(.04, .1, 1, 1)
# Save to u.pdf (transparent background) when `savefig` is set, else display.
if savefig:
    pl.savefig('u.pdf', dpi=600, transparent=True)
else:
    pl.show()

# Spikes: raster of the 8 units from a cfn.run simulation, one row per unit
# (labelled 0L..3R), saved to spikes.pdf when `savefig` is set.
spikes, u = cfn.run(net.W, step, 400, 100, 20, 2, 20, 2)
pl.figure()
for i in range(8):
    # One vertical tick per spike of unit i. (The previous side-effect-only
    # `map(lambda ...)` is lazy under Python 3 and drew nothing.)
    pl.vlines(np.where(spikes[:, i] == 1)[0], .05 + i, .95 + i, colors=col[i])
pl.yticks(np.arange(.4, 8), ['0L', '0R', '1L', '1R', '2L', '2R', '3L', '3R'])
pl.gca().yaxis.set_tick_params(width=0)
pl.xticks(np.arange(0, len(u) + 1, 40 / step), [0, 40, 80])
simpleaxis(pl.gca())
pl.tight_layout(pad=.03)
if savefig:
    pl.savefig('spikes.pdf', dpi=600, transparent=True)
else:
    pl.show()


def f(W, x, ref):
    """Return the next state for the current state `x`.

    Has the form of an explicit Euler update ``x + h * (-x + drive)`` with
    ``h = 1/2000``. The drive is ``(1 + ref) * W @ relu(x) - ref * relu(x)``,
    where ``relu(x) = x * (x > 0)``. Entries ``net.r`` (module-level `net`)
    are then clamped to 1.
    """
    h = 1. / 2000
    rect = x * (x > 0)  # rectified state
    drive = (1 + ref) * np.dot(W, rect) - ref * rect
    y = (1 - h) * x + h * drive
    y[net.r] = 1
    return y


# Zero-initialised (1201, net.K) array (presumably one row per time step --
# TODO confirm).
x = np.zeros((1201, net.K))
示例#15
0
 if i == 1:
     l.set_dashes([14, 10])
 plt.plot(z * ssubLR[4][2][neuronId] / scale - .7 * ub,
          alpha=1.,
          c=cyan,
          zorder=-10)
 plt.ylim(-.7 * ub, ub)
 plt.legend(frameon=False,
            ncol=len(dsls),
            loc=[.073, .775],
            columnspacing=14.2)
 plt.xticks(range(0, T, ticksep * fps), ['', '', ''])
 tmp = [500, 300, 100, 100][k]
 plt.plot([0, 0], [0, 1], c='k', lw=5, clip_on=False, zorder=11)
 plt.text(7, 1, '100\%', verticalalignment='center')
 simpleaxis(plt.gca())
 plt.gca().spines['left'].set_visible(False)
 plt.yticks([])
 for i, ds in enumerate(dsls):
     ax = fig.add_axes([[.17, .467, .78][i], .838 - .244 * k, .035, .2625])
     ss = downscale_local_mean(gaussian(shapes[neuronId], 1),
                               (ds, ds)).repeat(ds, 0).repeat(ds, 1)
     ax.imshow(ss[map(lambda a: slice(*a), boxes[neuronId])],
               cmap='hot',
               interpolation='nearest')
     ax.axis('off')
     ax.text(
         1.25,
         .5,
         '%.3f' %
         (np.corrcoef(ssub[ds][0][neuronId], ssub[1][0][neuronId])[0, 1]),
示例#16
0
def plotChoice(w):  # arg which 3b or 3e
    """Plot the percentage of B choices against log(V(B)/V(A)) for condition `w`.

    Reads the module-level ``offers<w>``, ``ratio<w>``, ``pB<w>`` and the
    simulation results ``s[w]``. Experimental values (open squares, ``col[0]``)
    and simulated ones (filled squares, ``col[1]``) are drawn. The first and
    last offers are placed at -inf / +inf, beyond a broken x-axis. The figure
    is saved to 'Padoa<w>.pdf' when `savefig` is set, otherwise shown.

    Returns
    -------
    choice : boolean array ``rts[:, :, 1] > rts[:, :, 2]``, where ``rts``
        comes from ``get_t``. Its mean over axis 0 is compared with ``pB``,
        so True presumably marks a B choice (TODO confirm).
    """
    def _named(prefix):
        # Module-level lookup replacing eval(prefix + w): same value for the
        # intended names (e.g. 'offers3b'), but `w` cannot inject code.
        name = prefix + w
        try:
            return globals()[name]
        except KeyError:
            raise NameError("name '%s' is not defined" % name)

    offers = _named('offers')
    ratio = _named('ratio')
    pB = _named('pB')
    v = np.array([offers[1:-1, 0], ratio * offers[1:-1, 1]]).T
    logratio = np.log(v[:, 0] / v[:, 1])
    x = np.transpose(s[w], (1, 0, 2, 3))
    rts = np.array([[get_t(a) for a in x[r]] for r in range(len(x))],
                   dtype=int)
    choice = rts[:, :, 1] > rts[:, :, 2]
    pBsim = np.mean(choice, 0)
    lo, hi = logratio[0] - .5, logratio[-1] + .5  # x of the -inf / +inf offers
    # plot
    pl.figure()
    exp_style = dict(ms=30, mfc='w', mec=col[0], mew=3, color=col[0],
                     zorder=10, clip_on=False)
    sim_style = dict(color=col[1], ms=22, mec=col[1], clip_on=False,
                     zorder=10)
    pl.plot(logratio, pB[1:-1], '--', marker='s', **exp_style)
    pl.plot(logratio, pBsim[1:-1], marker='s', **sim_style)
    pl.plot([lo, hi], [pB[0], pB[-1]], 's', **exp_style)
    pl.plot([lo, hi], [pBsim[0], pBsim[-1]], 's', **sim_style)
    pl.xticks([lo, -1, 0, 1, hi], [r"$-\infty$", -1, 0, 1, r"$\infty$"])
    for tick in pl.gca().get_xaxis().majorTicks:
        tick.set_pad(17)
    pl.yticks([0, .5, 1], [0, 50, 100])
    pl.xlim(logratio[0] - .65, hi)
    pl.ylim(0, 1)
    pl.xlabel('log(V(B)/V(A))', labelpad=-2)
    # raw string: '\%' is not a valid escape sequence (same string value)
    pl.ylabel(r'B choice [\%]', labelpad=-15)
    pl.text(.05, .83, 'A=' + str(ratio) + 'B', transform=pl.gca().transAxes)
    pl.subplots_adjust(.17, .23, .96, .95)
    simpleaxis(pl.gca())
    # broken axis: on each side a white gap in the x-spine plus two black
    # slashes, between the finite offers and the +-inf ones
    breaks = (
        ([logratio[0] - .27, logratio[0] - .23], [0, 0], 'w'),
        ([logratio[0] - .29, logratio[0] - .25], [-.02, .02], 'k'),
        ([logratio[0] - .25, logratio[0] - .21], [-.02, .02], 'k'),
        ([logratio[-1] + .28, logratio[-1] + .24], [0, 0], 'w'),
        ([logratio[-1] + .3, logratio[-1] + .26], [.02, -.02], 'k'),
        ([logratio[-1] + .26, logratio[-1] + .22], [.02, -.02], 'k'),
    )
    for xs, ys, c in breaks:
        pl.plot(xs, ys, c=c, lw=2, clip_on=False, zorder=11)
    if savefig:
        pl.savefig('Padoa' + w + '.pdf', dpi=600)
    else:
        pl.show()
    return choice
        [[net.R4Pi2(net.get_policy2(step, i, T), net.pstart_state) for i in S]
         for T in Tls])
    perf = (rew - R0) / (Rmax - R0)
    perf[0] = 0
    np.save('results/performance', perf)
# Performance figure: mean of `perf` over axis 1 with error std / sqrt(n)
# (n = len(perf[0])), drawn via the errorfill helper (presumably a shaded
# band), x-range [0, 400]; saved to performance.pdf.
pl.figure()
errorfill(range(len(perf)),
          np.mean(perf, axis=1),
          yerr=np.std(perf, axis=1) / np.sqrt(len(perf[0])))
pl.xticks([0, 200, 400], [0, 200, 400])
pl.yticks([0, .5, 1.0], [0, .5, 1.0])
pl.xlim([0, 400])
pl.ylim([0, 1])
pl.xlabel('Time [ms]')
pl.ylabel('Performance')
simpleaxis(pl.gca())
pl.tight_layout(0)
pl.savefig('performance.pdf', dpi=600)

# Learning via parallel sampling.
# NOTE(review): eta / T / r0 presumably are the learning rate, run length and
# initial baseline -- confirm. This rebinds the module-level name T.
eta, T, r0 = .01, 1000, 0
try:
    perf = np.load('results/learn_performance.npy')
except IOError:
    try:
        rew = np.load('results/learn.npz')['R']
    except IOError:
        res = np.array([
            net.parallel_sampling_keepU(step,
                                        eta,
                                        run,
示例#18
0
# Median correlation (with IQRfill band) vs. spatial decimation for
# activityDSlr (cyan, '1 phase') and activityDS (orange, '2 phase'), both
# compared with activityDS[1] via `foo`.
fig = plt.figure(figsize=(6.5, 6))
# NOTE(review): the loop variable `col` rebinds the module-level name `col`
# (indexed elsewhere in this file as a color list, e.g. col[0]) to a single
# color -- confirm nothing later relies on the palette.
for ssub, col in ((activityDSlr, cyan), (activityDS, orange)):
    cor, mx = foo(ssub, activityDS[1], shapes)
    if col == cyan:
        mapIdx = mx  # keep the second value foo returned for activityDSlr
    plt.plot(dsls,
             np.median(cor, 0),
             lw=3,
             c=col,
             clip_on=False,
             label='1 phase\n imaging' if col == cyan else '2 phase\n imaging')
    IQRfill(cor, dsls, col)
plt.xlabel('Spatial decimation')
ax = plt.gca()
simpleaxis(ax)
# transparent background so the twin axis below (zorder -1) shows through
ax.patch.set_visible(False)
plt.xticks(dsls, ['1x1', '', '', '4x4', '6x6', '8x8', '12x12'])
plt.ylim(.55, 1)
plt.yticks(*[[.6, .8, 1.]] * 2)
plt.ylabel('Correlation', labelpad=0)
plt.legend(loc=(.01, .25), ncol=1)
plt.subplots_adjust(.145, .15, .885, .97)
# Secondary y-axis placed behind the main axes.
ax2 = ax.twinx()
ax2.set_zorder(-1)
ax2.spines['top'].set_visible(False)
# z: mean squared difference between activityDS[1].T @ flattened shapes and
# the reshaped data (presumably the undecimated reconstruction error).
z = np.mean((activityDS[1].T.dot(shapes.reshape(len(shapes), -1)) -
             data.reshape(activity.shape[1], -1))**2)
ax2.plot(dsls, [
    np.mean((activityDSlr[ds].T.dot(shapesDSlr[ds].repeat(ds, 1).repeat(
        ds, 2).reshape(len(shapes), -1)) - data.reshape(activity.shape[1], -1))