s1 = base1+mag1*np.sin(times*2*np.pi*freq1+phase1) # freqError0 = np.abs(freq-freq0) # freqError1 = np.abs(freq-freq1) rmseError0 = rmse(X[int(tTrans/dt):,0], s0[int(tTrans/dt):]) rmseError1 = rmse(X[int(tTrans/dt):,1], s1[int(tTrans/dt):]) # error0 = (1+freqError0) * rmseError0 # error1 = (1+freqError1) * rmseError1 error0 = rmseError0 error1 = rmseError1 errors[test] = (error0 + error1)/2 fig, (ax, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True) ax.plot(times, X[:,0], label="estimate (dim 0)") ax.plot(times, s0, label="target (dim 0)") ax.set(ylabel='state', title='freq=%.3f, rmse=%.3f'%(freq0, error0), xlim=((0, t)), ylim=((-1.2, 1.2))) ax.legend(loc='upper left') ax2.plot(times, X[:,1], label="estimate (dim 1)") ax2.plot(times, s1, label="target (dim 1)") ax2.set(xlabel='time', ylabel='state', title='freq=%.3f, rmse=%.3f'%(freq1, error1), xlim=((0, t)), ylim=((-1.2, 1.2))) ax2.legend(loc='upper left') sns.despine() fig.savefig("plots/oscillateNew_%s_test%s.pdf"%(neuron_type, test)) plt.close('all') print('%s errors:'%neuron_type, errors) np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2, errors=errors) return errors # errorsLIF = run(neuron_type=LIF(), load=False, file="data/oscillateNew_LIF().npz") # errorsALIF = run(neuron_type=ALIF(), load=False, file="data/oscillateNew_ALIF().npz") # errorsWilson = run(neuron_type=Wilson(), dt=1e-4, load=False, file="data/oscillateNew_Wilson().npz") errorsBio = run(neuron_type=Bio("Pyramidal"), load=True, muFreq=4, base=True, tTrans=5.0, file="data/oscillateNew_Bio().npz")
def go(t=10, m=Uniform(30, 30), i=Uniform(0, 0), seed=0, dt=0.001,
       f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1),
       d1=None, f1=None, e1=None, l1=False, stim=lambda t: np.sin(t)):
    """Drive one LIF, one Wilson and one Bio neuron from a shared pre population.

    ``stim`` feeds a 100-neuron ``pre`` ensemble. ``pre`` projects, through synapse
    ``f1`` (or ``f`` when ``f1`` is falsy) and fixed decoders ``d1`` (``NoSolver``),
    onto three one-neuron, 1-D ensembles. Each of these has rate ``m``, intercept
    ``i`` and encoder +1, and they differ only in neuron type (LIF, Wilson,
    Bio("Pyramidal")).

    When ``l1`` is true, ``learnEncoders`` is attached to the pre->bio connection,
    with the LIF neuron as the target and ``fS`` as the learning filter.
    ``setWeights(cBio, d1, e1)`` is applied once the simulator is built.

    Returns:
        dict with ``times``, the raw input (``inpt``), unfiltered spike data for
        ``pre``/``lif``/``wilson``/``bio``, and ``e1``. ``e1`` is the
        connection's ``.e`` after the run when ``l1`` is set; otherwise it is the
        ``e1`` that was passed in.
    """
    # A truthy f1 overrides the default synapse f.
    synapse = f1 if f1 else f

    def oneNeuron(neuronType):
        # Single-neuron, 1-D ensemble with a fixed +1 encoder.
        return nengo.Ensemble(1, 1, max_rates=m, intercepts=i, encoders=Choice([[1]]),
                              neuron_type=neuronType, seed=seed)

    with nengo.Network(seed=seed) as model:
        # Creation order is kept as-is: Nengo derives per-object seeds in order.
        inpt = nengo.Node(stim)
        # Direct-mode ensemble; it is never connected or probed below.
        tar = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
        pre = nengo.Ensemble(100, 1, max_rates=m, seed=seed, neuron_type=LIF())
        lif = oneNeuron(LIF())
        wilson = oneNeuron(Wilson())
        bio = oneNeuron(Bio("Pyramidal"))

        nengo.Connection(inpt, pre, synapse=None, seed=seed)
        cLif, cWilson, cBio = [
            nengo.Connection(pre, post, synapse=synapse, seed=seed, solver=NoSolver(d1))
            for post in (lif, wilson, bio)
        ]

        pInpt = nengo.Probe(inpt, synapse=None)
        pPre, pLif, pWilson, pBio = [
            nengo.Probe(population.neurons, synapse=None)
            for population in (pre, lif, wilson, bio)
        ]

        if l1:
            # Encoder learning on pre->bio, using the LIF neuron as the teacher.
            learnEncoders(cBio, lif, fS, alpha=3e-7)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        setWeights(cBio, d1, e1)
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model)

    encoders = cBio.e if l1 else e1
    return {
        "times": sim.trange(),
        "inpt": sim.data[pInpt],
        "pre": sim.data[pPre],
        "lif": sim.data[pLif],
        "wilson": sim.data[pWilson],
        "bio": sim.data[pBio],
        "e1": encoders,
    }
"%s_test%s" % (neuron_type, test), t) A = f2.filt(data['ens'], dt=dt) X = np.dot(A, d2) Y = f.filt(data['tar'][:, 0] * data['tar'][:, 1], dt=dt).reshape(-1, 1) plotState(data['times'], X, Y, rmse(X, Y), "multiplyNew", "%s_pretest%s" % (neuron_type, test), t) print('%s errors:' % neuron_type, errors) np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2, d3=d3, tauRise3=tauRise3, tauFall3=tauFall3, errors=errors) return errors # errorsLIF = run(neuron_type=LIF(), load=False, file="data/multiplyNew_LIF().npz") # errorsALIF = run(neuron_type=ALIF(), load=False, file="data/multiplyNew_ALIF().npz") # errorsWilson = run(neuron_type=Wilson(), dt=1e-4, load=False, file="data/multiplyNew_Wilson().npz") errorsBio = run(neuron_type=Bio("Pyramidal"), load=False, file="data/multiplyNew_Bio().npz")
# print('errors:', errors) # np.savez(file, errors=errors, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, dEns=dEns, tauRiseEns=tauRiseEns, tauFallEns=tauFallEns, eBio=eBio) return data['times'], X, Y times, XLIF, Y = run(neuron_type=LIF(), dt=1e-4, load=[0, 1, 2, 3, 4], nTest=1) times, XALIF, Y = run(neuron_type=ALIF(), dt=1e-4, load=[0, 1, 2, 3, 4], nTest=1) times, XWilson, Y = run(neuron_type=Wilson(), dt=1e-4, load=[0, 1, 2, 3, 4], nTest=1) times, XBio, Y = run(neuron_type=Bio("Pyramidal"), dt=1e-4, load=[0, 1, 2, 3, 4], nTest=1) eLIF = rmse(XLIF, Y) eALIF = rmse(XALIF, Y) eWilson = rmse(XWilson, Y) eBio = rmse(XBio, Y) yticks = np.array([-1, 0, 1]) xticks = np.array([0, 2, 4, 6, 8, 10]) fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(nrows=2, ncols=2, figsize=((6, 4)), sharex=True, sharey=True)
def go(NPre=100, N=100, t=10, c=None, seed=1, dt=0.001, tTrans=0.01, stage=None, alpha=3e-7, eMax=1e-1, Tff=0.3,
       fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3), fGABA=DoubleExp(0.5e-3, 1.5e-3),
       fS=DoubleExp(1e-3, 1e-1),
       dPreA=None, dPreB=None, dPreC=None, dPreD=None, dFdfw=None, dEns=None, dOff=None,
       ePreAFdfw=None, ePreBEns=None, ePreCOff=None, eFdfwEns=None, eEnsEns=None, ePreDEns=None, eOffFdfw=None,
       stimA=lambda t: 0, stimB=lambda t: 0, stimC=lambda t: 0, stimD=lambda t: 0, DA=lambda t: 0):
    """Build and simulate one training/testing stage of the fdfw/ens/off Bio network.

    Shared by every stage:
      * four input nodes (stimA..stimD), each feeding its own pre ensemble
        (preA..preD, NPre neurons, max rate 30 Hz, default neuron type);
      * Bio populations ``fdfw`` and ``ens`` ("Pyramidal") and ``off``
        ("Interneuron"), all given ``DA``. NOTE(review): presumably a
        time-varying dopamine level; confirm against Bio.
      * non-Bio target ensembles ``tarFdfw``, ``tarEns`` and ``tarOff``, which
        reuse the seeds of fdfw (seed), ens (seed + 1) and off (seed + 3);
      * unfiltered probes on every input and population.

    ``stage`` selects the extra connections:
      0: learn preD -> ens encoders; the teacher tarEns is driven by preD.
      1: learn preA -> fdfw, preB -> ens and preC -> off encoders (teachers
         driven by inptA/B/C), with preD -> ens at fixed weights.
      2: preA -> fdfw and preC -> off at fixed weights; no learning.
         NOTE(review): presumably used to collect activity for decoder fitting.
      3: inptB -> preB switches to fNMDA. Learns fdfw -> ens (NMDA) encoders.
         The teacher tarEns receives inptB (fNMDA then fPre) plus Tff * inptA
         (fPre then fNMDA).
      4: same fixed-weight connections as stage 3, with no teacher or learning.
      5: learn ens2 -> ens (NMDA) encoders. The teacher is the Bio copy ens3,
         driven by fdfw and by preB2 (which sees inptB through fNMDA). The
         returned "tarEns" holds ens3's spikes in this stage.
      9: full network: preD -> ens, preA -> fdfw -> ens, recurrent ens -> ens,
         and preC -> off -> fdfw (GABA).
      anything else: no stage-specific connections.

    Stage connections use ``NoSolver`` with the given decoders. Inside the
    simulator, ``setWeights(conn, d, e)`` installs the matching encoders
    before ``neuron.h.init()`` and the run.

    Returns:
        dict with ``times``, probe data for inputs, pre, Bio and target
        populations, and all encoder sets. An encoder set is replaced by the
        learning connection's ``.e`` only in the stage that learns it; otherwise
        it is passed through unchanged.

    NOTE(review): ``c`` and ``fGABA`` are accepted but not otherwise used here,
    and ``eOffFdfw`` is never learned by any stage of this function.
    """
    if not c: c = t  # NOTE(review): c is not read anywhere else in this function
    with nengo.Network(seed=seed) as model:
        # Inputs and their spiking "pre" encoders.
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        inptC = nengo.Node(stimC)
        inptD = nengo.Node(stimD)
        preA = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preB = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preC = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preD = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        # Bio populations being trained.
        fdfw = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
        ens = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed + 1)
        off = nengo.Ensemble(N, 1, neuron_type=Bio("Interneuron", DA=DA), seed=seed + 3)
        # Teacher ensembles; each shares its seed with the Bio population it teaches.
        tarFdfw = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), seed=seed)
        tarEns = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), seed=seed + 1)
        tarOff = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(0.2, 0.8), encoders=Choice([[1]]), seed=seed + 3)
        # cB is kept so stages 3/4 can change the inptB -> preB synapse.
        cA = nengo.Connection(inptA, preA, synapse=None, seed=seed)
        cB = nengo.Connection(inptB, preB, synapse=None, seed=seed)
        cC = nengo.Connection(inptC, preC, synapse=None, seed=seed)
        cD = nengo.Connection(inptD, preD, synapse=None, seed=seed)
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pInptC = nengo.Probe(inptC, synapse=None)
        pInptD = nengo.Probe(inptD, synapse=None)
        pPreA = nengo.Probe(preA.neurons, synapse=None)
        pPreB = nengo.Probe(preB.neurons, synapse=None)
        pPreC = nengo.Probe(preC.neurons, synapse=None)
        pPreD = nengo.Probe(preD.neurons, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pTarFdfw = nengo.Probe(tarFdfw.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        pOff = nengo.Probe(off.neurons, synapse=None)
        pTarOff = nengo.Probe(tarOff.neurons, synapse=None)
        if stage == 0:
            # Learn preD -> ens against tarEns driven by the same preD.
            nengo.Connection(preD, tarEns, synapse=fPre, seed=seed)
            c0 = nengo.Connection(preD, ens, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            learnEncoders(c0, tarEns, fS, alpha=10 * alpha, eMax=10 * eMax, tTrans=tTrans)
        if stage == 1:
            # Learn the three pre -> Bio feedforward encoders in parallel.
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarOff, synapse=fPre, seed=seed)
            c0 = nengo.Connection(preD, ens, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(preC, off, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            learnEncoders(c1, tarFdfw, fS, alpha=3 * alpha, eMax=3 * eMax, tTrans=tTrans)
            learnEncoders(c2, tarEns, fS, alpha=10 * alpha, eMax=10 * eMax, tTrans=tTrans)
            learnEncoders(c3, tarOff, fS, alpha=3 * alpha, eMax=3 * eMax, tTrans=tTrans)
        if stage == 2:
            # Fixed-weight run of fdfw and off (no learning).
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preC, off, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
        if stage == 3:
            # Learn fdfw -> ens (NMDA). ff/fb are Direct-mode helpers that build the
            # teacher signal: fb carries inptB via fNMDA, ff carries Tff * inptA.
            cB.synapse = fNMDA
            ff = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            fb = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            nengo.Connection(inptA, ff, synapse=fPre, seed=seed)
            nengo.Connection(inptB, fb, synapse=fNMDA, seed=seed)
            nengo.Connection(fb, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(ff, tarEns, synapse=fNMDA, transform=Tff, seed=seed)
            c0 = nengo.Connection(preD, ens, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            learnEncoders(c3, tarEns, fS, alpha=3 * alpha, eMax=3 * eMax, tTrans=tTrans)
        if stage == 4:
            # Stage-3 connections at fixed weights, without teacher or learning.
            cB.synapse = fNMDA
            c0 = nengo.Connection(preD, ens, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
        if stage == 5:
            # Learn ens2 -> ens (NMDA). ens2 stands in for the recurrent input, and
            # ens3 (same seed as ens) is the spiking teacher.
            preB2 = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
            ens2 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed + 1)
            ens3 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed + 1)
            nengo.Connection(inptB, preB2, synapse=fNMDA, seed=seed)
            c0a = nengo.Connection(preD, ens, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c0b = nengo.Connection(preD, ens2, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c0c = nengo.Connection(preD, ens3, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c3 = nengo.Connection(preB, ens2, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c4 = nengo.Connection(ens2, ens, synapse=NMDA(), solver=NoSolver(dEns), seed=seed)
            c5 = nengo.Connection(fdfw, ens3, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c6 = nengo.Connection(preB2, ens3, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            learnEncoders(c4, ens3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            # Overrides the tarEns probe: the returned "tarEns" is ens3's spikes here.
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
        if stage == 9:
            # Full network with recurrence and off -> fdfw GABA inhibition.
            c0 = nengo.Connection(preD, ens, synapse=fPre, solver=NoSolver(dPreD), seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c3 = nengo.Connection(ens, ens, synapse=NMDA(), solver=NoSolver(dEns), seed=seed)
            c6 = nengo.Connection(preC, off, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            c7 = nengo.Connection(off, fdfw, synapse=GABA(), solver=NoSolver(dOff), seed=seed)
    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        # Install decoders/encoders on the built connections; the pairing per
        # connection mirrors the stage blocks above.
        if stage == 0:
            setWeights(c0, dPreD, ePreDEns)
        if stage == 1:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreB, ePreBEns)
            setWeights(c3, dPreC, ePreCOff)
        if stage == 2:
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreC, ePreCOff)
        if stage == 3:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreB, ePreBEns)
            setWeights(c3, dFdfw, eFdfwEns)
        if stage == 4:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreB, ePreBEns)
            setWeights(c3, dFdfw, eFdfwEns)
        if stage == 5:
            setWeights(c0a, dPreD, ePreDEns)
            setWeights(c0b, dPreD, ePreDEns)
            setWeights(c0c, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dFdfw, eFdfwEns)
            setWeights(c3, dPreB, ePreBEns)
            setWeights(c4, dEns, eEnsEns)
            setWeights(c5, dFdfw, eFdfwEns)
            setWeights(c6, dPreB, ePreBEns)
        if stage == 9:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dFdfw, eFdfwEns)
            setWeights(c3, dEns, eEnsEns)
            setWeights(c6, dPreC, ePreCOff)
            setWeights(c7, dOff, eOffFdfw)
        # NOTE(review): presumably (re)initializes NEURON state for the Bio cells; confirm.
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model)
    # Read back learned encoders only for the stage that learned them; the
    # conditional expression never touches cN when its stage did not create it.
    ePreDEns = c0.e if stage == 0 else ePreDEns
    ePreAFdfw = c1.e if stage == 1 else ePreAFdfw
    ePreBEns = c2.e if stage == 1 else ePreBEns
    ePreCOff = c3.e if stage == 1 else ePreCOff
    eFdfwEns = c3.e if stage == 3 else eFdfwEns
    eEnsEns = c4.e if stage == 5 else eEnsEns
    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        inptC=sim.data[pInptC],
        inptD=sim.data[pInptD],
        preA=sim.data[pPreA],
        preB=sim.data[pPreB],
        preC=sim.data[pPreC],
        preD=sim.data[pPreD],
        fdfw=sim.data[pFdfw],
        ens=sim.data[pEns],
        off=sim.data[pOff],
        tarFdfw=sim.data[pTarFdfw],
        tarEns=sim.data[pTarEns],
        tarOff=sim.data[pTarOff],
        ePreAFdfw=ePreAFdfw,
        ePreBEns=ePreBEns,
        ePreCOff=ePreCOff,
        ePreDEns=ePreDEns,
        eFdfwEns=eFdfwEns,
        eEnsEns=eEnsEns,
        eOffFdfw=eOffFdfw,
    )
def go(NPre=100, N=100, t=10, c=None, seed=0, dt=0.001, Tff=0.3, tTrans=0.01, stage=None, alpha=3e-7, eMax=1e-1,
       fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3), fS=DoubleExp(1e-3, 1e-1),
       dPreA=None, dPreB=None, dPreC=None, dFdfw=None, dBio=None, dNeg=None, dInh=None,
       ePreA=None, ePreB=None, ePreC=None, eFdfw=None, eBio=None, eNeg=None, eInh=None,
       stimA=lambda t: 0, stimB=lambda t: 0, stimC=lambda t: 0, DA=lambda t: 0):
    """Build and simulate one training/testing stage of the fdfw/ens/inh Bio network.

    Shared by every stage:
      * three input nodes (stimA..stimC), each feeding a pre ensemble
        (preA..preC, NPre neurons, max rate 30 Hz, default neuron type);
      * Bio populations ``fdfw`` and ``ens`` ("Pyramidal") and ``inh``
        ("Interneuron"), all given ``DA``. NOTE(review): presumably a
        time-varying dopamine level; confirm against Bio.
      * LIF target ensembles ``tarFdfw``, ``tarEns`` and ``tarInh``, which reuse
        the seeds of fdfw (seed), ens (seed+1) and inh (seed+2);
      * unfiltered probes on every input and population.

    ``stage`` selects the extra connections:
      1: learn preA -> fdfw, preB -> ens and preC -> inh encoders against
         tarFdfw/tarEns/tarInh, which are driven by inptA/B/C.
      2: preA -> fdfw and preC -> inh at fixed weights; no learning.
      3: inptB -> preB switches to fNMDA. Learns fdfw -> ens (NMDA) encoders.
         The teacher tarEns receives inptB (fPre) plus Tff * tarFdfw's decoded
         output (fNMDA).
      4: stage-3 connections at fixed weights, with no teacher or learning.
      5: learn ens2 -> ens (NMDA, decoders dBio). The teacher is the Bio copy
         ens3 (driven by fdfw and by preB2, which sees inptB via fNMDA). The
         returned "tarEns" holds ens3's spikes.
      6: two learning problems in one run:
         - ens -> fdfw2 (NMDA, decoders dNeg), taught by fdfw3, which is driven
           by preA2 (inptB via fNMDA);
         - inh -> fdfw4 (GABA, decoders dInh, ``inh=True``), taught by
           tarFdfw4. tarFdfw4 gets inptA and is inhibited by tarInh's decoded
           output times -100 on every neuron.
         Only this stage returns fdfw2, fdfw4, tarFdfw2 (fdfw3's spikes) and
         tarFdfw4; other stages set them to None.
      7: full network: preA -> fdfw -> ens, ens -> ens (dBio),
         ens -> fdfw (dNeg), and preC -> inh -> fdfw (GABA).
      anything else: no stage-specific connections.

    Stage connections use ``NoSolver`` with the given decoders. Inside the
    simulator, ``setWeights(conn, d, e)`` installs the matching encoders
    before ``neuron.h.init()`` and the run.

    Returns:
        dict with ``times``, probe data, and all encoder sets. An encoder set is
        replaced by the learning connection's ``.e`` only in the stage that
        learns it; otherwise it is passed through unchanged.

    NOTE(review): ``c`` is accepted but not otherwise used here.
    """
    if not c: c = t  # NOTE(review): c is not read anywhere else in this function
    with nengo.Network(seed=seed) as model:
        # Inputs and their spiking "pre" encoders.
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        inptC = nengo.Node(stimC)
        preA = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preB = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preC = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        # Bio populations being trained.
        fdfw = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
        ens = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed+1)
        inh = nengo.Ensemble(N, 1, neuron_type=Bio("Interneuron", DA=DA), seed=seed+2)
        # LIF teacher ensembles; each shares its seed with the Bio population it teaches.
        tarFdfw = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), neuron_type=nengo.LIF(), seed=seed)
        tarEns = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), neuron_type=nengo.LIF(), seed=seed+1)
        tarInh = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(0.2, 0.8), encoders=Choice([[1]]), neuron_type=nengo.LIF(), seed=seed+2)
        # cB is kept so stages 3/4 can change the inptB -> preB synapse.
        cA = nengo.Connection(inptA, preA, synapse=None, seed=seed)
        cB = nengo.Connection(inptB, preB, synapse=None, seed=seed)
        cC = nengo.Connection(inptC, preC, synapse=None, seed=seed)
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pInptC = nengo.Probe(inptC, synapse=None)
        pPreA = nengo.Probe(preA.neurons, synapse=None)
        pPreB = nengo.Probe(preB.neurons, synapse=None)
        pPreC = nengo.Probe(preC.neurons, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pTarFdfw = nengo.Probe(tarFdfw.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        pInh = nengo.Probe(inh.neurons, synapse=None)
        pTarInh = nengo.Probe(tarInh.neurons, synapse=None)
        if stage==1:
            # Learn the three pre -> Bio feedforward encoders in parallel.
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarInh, synapse=fPre, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            learnEncoders(c1, tarFdfw, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            learnEncoders(c2, tarEns, fS, alpha=3*alpha, eMax=10*eMax, tTrans=tTrans)
            learnEncoders(c3, tarInh, fS, alpha=alpha/3, eMax=eMax, tTrans=tTrans)
        if stage==2:
            # Fixed-weight run of fdfw and inh (no learning).
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
        if stage==3:
            # Learn fdfw -> ens (NMDA); the teacher adds Tff * tarFdfw onto inptB.
            cB.synapse = fNMDA
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(tarFdfw, tarEns, synapse=fNMDA, transform=Tff, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            learnEncoders(c3, tarEns, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
        if stage==4:
            # Stage-3 connections at fixed weights, without teacher or learning.
            cB.synapse = fNMDA
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
        if stage==5:
            # Learn ens2 -> ens (NMDA). ens2 stands in for the recurrent input, and
            # ens3 (same seed as ens) is the spiking teacher.
            preB2 = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
            ens2 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed+1)
            ens3 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed+1)
            nengo.Connection(inptB, preB2, synapse=fNMDA, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c3 = nengo.Connection(preB, ens2, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c4 = nengo.Connection(ens2, ens, synapse=NMDA(), solver=NoSolver(dBio), seed=seed)
            c5 = nengo.Connection(fdfw, ens3, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c6 = nengo.Connection(preB2, ens3, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            learnEncoders(c4, ens3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            # Overrides the tarEns probe: the returned "tarEns" is ens3's spikes here.
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
        if stage==6:
            # fdfw2..fdfw4 share fdfw's seed. fdfw2 learns ens -> fdfw (dNeg) with
            # fdfw3 as teacher. fdfw4 learns inh -> fdfw (GABA) with tarFdfw4 as teacher.
            preA2 = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
            fdfw2 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
            fdfw3 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
            fdfw4 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
            tarFdfw4 = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), neuron_type=nengo.LIF(), seed=seed)
            nengo.Connection(inptA, tarFdfw4, synapse=fPre, seed=seed)
            nengo.Connection(inptB, preA2, synapse=fNMDA, seed=seed)
            nengo.Connection(inptC, tarInh, synapse=fPre, seed=seed)
            # Teacher-side inhibition: tarInh's decoded value times -100 is injected
            # into every tarFdfw4 neuron.
            nengo.Connection(tarInh, tarFdfw4.neurons, synapse=None, transform=-1e2*np.ones((N, 1)), seed=seed)
            c1 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c2 = nengo.Connection(ens, fdfw2, synapse=NMDA(), solver=NoSolver(dNeg), seed=seed)
            c3 = nengo.Connection(preA2, fdfw3, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c4 = nengo.Connection(preA, fdfw4, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c5 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            c6 = nengo.Connection(inh, fdfw4, synapse=GABA(), solver=NoSolver(dInh), seed=seed)
            learnEncoders(c2, fdfw3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            learnEncoders(c6, tarFdfw4, fS, alpha=1e3*alpha, eMax=1e3*eMax, tTrans=tTrans, inh=True)
            pFdfw2 = nengo.Probe(fdfw2.neurons, synapse=None)
            pFdfw4 = nengo.Probe(fdfw4.neurons, synapse=None)
            # fdfw3 is fdfw2's teacher, so its spikes are returned as "tarFdfw2".
            pTarFdfw2 = nengo.Probe(fdfw3.neurons, synapse=None)
            pTarFdfw4 = nengo.Probe(tarFdfw4.neurons, synapse=None)
        if stage==7:
            # Full network: feedforward, recurrent, negative-feedback and inhibitory paths.
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c3 = nengo.Connection(ens, ens, synapse=NMDA(), solver=NoSolver(dBio), seed=seed)
            c4 = nengo.Connection(ens, fdfw, synapse=NMDA(), solver=NoSolver(dNeg), seed=seed)
            c5 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            c6 = nengo.Connection(inh, fdfw, synapse=GABA(), solver=NoSolver(dInh), seed=seed)
    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        # Install decoders/encoders on the built connections; the pairing per
        # connection mirrors the stage blocks above.
        if stage==1:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreB, ePreB)
            setWeights(c3, dPreC, ePreC)
        if stage==2:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreC, ePreC)
        if stage==3:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreB, ePreB)
            setWeights(c3, dFdfw, eFdfw)
        if stage==4:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreB, ePreB)
            setWeights(c3, dFdfw, eFdfw)
        if stage==5:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dFdfw, eFdfw)
            setWeights(c3, dPreB, ePreB)
            setWeights(c4, dBio, eBio)
            setWeights(c5, dFdfw, eFdfw)
            setWeights(c6, dPreB, ePreB)
        if stage==6:
            setWeights(c1, dPreB, ePreB)
            setWeights(c2, dNeg, eNeg)
            # fdfw3/fdfw4 share fdfw's seed, so they reuse the preA -> fdfw encoders.
            setWeights(c3, dPreA, ePreA)
            setWeights(c4, dPreA, ePreA)
            setWeights(c5, dPreC, ePreC)
            setWeights(c6, dInh, eInh)
        if stage==7:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dFdfw, eFdfw)
            setWeights(c3, dBio, eBio)
            setWeights(c4, dNeg, eNeg)
            setWeights(c5, dPreC, ePreC)
            setWeights(c6, dInh, eInh)
        # NOTE(review): presumably (re)initializes NEURON state for the Bio cells; confirm.
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model)
    # Read back learned encoders only for the stage that learned them.
    ePreA = c1.e if stage==1 else ePreA
    ePreB = c2.e if stage==1 else ePreB
    ePreC = c3.e if stage==1 else ePreC
    eFdfw = c3.e if stage==3 else eFdfw
    eBio = c4.e if stage==5 else eBio
    eNeg = c2.e if stage==6 else eNeg
    eInh = c6.e if stage==6 else eInh
    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        inptC=sim.data[pInptC],
        preA=sim.data[pPreA],
        preB=sim.data[pPreB],
        preC=sim.data[pPreC],
        fdfw=sim.data[pFdfw],
        ens=sim.data[pEns],
        inh=sim.data[pInh],
        tarFdfw=sim.data[pTarFdfw],
        tarEns=sim.data[pTarEns],
        tarInh=sim.data[pTarInh],
        # The stage-6 probes exist only in stage 6.
        fdfw2=sim.data[pFdfw2] if stage==6 else None,
        fdfw4=sim.data[pFdfw4] if stage==6 else None,
        tarFdfw2=sim.data[pTarFdfw2] if stage==6 else None,
        tarFdfw4=sim.data[pTarFdfw4] if stage==6 else None,
        ePreA=ePreA,
        ePreB=ePreB,
        ePreC=ePreC,
        eFdfw=eFdfw,
        eBio=eBio,
        eNeg=eNeg,
        eInh=eInh,
    )
def go(NPre=300, NBias=100, N=30, t=10, seed=1, dt=0.001, Tff=0.3, tTrans=0.01, stage=None, alpha=3e-7, eMax=1e-1,
       fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3), fGABA=DoubleExp(0.5e-3, 1.5e-3),
       fS=DoubleExp(1e-3, 1e-1),
       dPre=None, dFdfw=None, dEns=None, dBias=None,
       ePreFdfw=None, ePreEns=None, ePreBias=None, eFdfwEns=None, eBiasEns=None, eEnsEns=None,
       stimA=lambda t: 0, stimB=lambda t: 0, stimC=lambda t: 0.01, DA=lambda t: 0):
    """Build and simulate one training/testing stage of the fdfw/ens/bias Bio network.

    Shared by every stage:
      * three input nodes (stimA..stimC; stimC defaults to a constant 0.01),
        each feeding a pre ensemble (preA..preC, NPre neurons, max rate 30 Hz).
        All three share ``seed``, so one decoder set ``dPre`` is used for all.
      * Bio populations ``fdfw`` and ``ens`` ("Pyramidal", N neurons) and
        ``bias`` ("Interneuron", NBias neurons), all given ``DA``.
        NOTE(review): presumably a time-varying dopamine level; confirm.
      * target ensembles ``tarFdfw``, ``tarEns`` and ``tarBias``, which reuse
        the seeds of fdfw, ens and bias;
      * unfiltered probes on every input and population.

    ``stage`` selects the extra connections:
      0: none; readout decoders for [preA, preB, preC].
      1: learn preA -> fdfw and preC -> bias encoders.
      2: fixed preA -> fdfw and preC -> bias; readout decoders for fdfw/bias.
      3: learn bias -> ens (GABA) encoders; teacher tarEns gets inptC via fGABA.
      4: learn preB -> ens encoders with bias -> ens present; the teacher also
         gets inptB via fPre.
      5: inptB -> preB uses fNMDA. Learns fdfw -> ens (NMDA) encoders against a
         teacher combining Tff * inptA, inptB and inptC.
      6: stage-5 connections at fixed weights; readout decoders for ens.
      7: learn ens2 -> ens (NMDA) encoders. ens2 stands in for preB's input to
         ens, and ens3 acts as tarEns (the returned "tarEns" is ens3's spikes).
      8: test: preA -> fdfw -> ens, preC -> bias -> ens, recurrent ens -> ens.
      anything else: no stage-specific connections.

    Stage connections use ``NoSolver``. Inside the simulator,
    ``setWeights(conn, d, e)`` installs the matching encoders before
    ``neuron.h.init()`` and the run.

    Returns:
        dict with ``times``, input/pre/Bio/target probe data, and all encoder
        sets. An encoder set is replaced by the learning connection's ``.e``
        only in the stage that learns it; otherwise it is passed through
        unchanged.
    """
    with nengo.Network(seed=seed) as model:
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        inptC = nengo.Node(stimC)
        preA = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preB = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preC = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        fdfw = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
        ens = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed + 1)
        bias = nengo.Ensemble(NBias, 1, neuron_type=Bio("Interneuron", DA=DA), seed=seed + 2)
        tarFdfw = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), seed=seed)
        tarEns = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), seed=seed + 1)
        tarBias = nengo.Ensemble(NBias, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, -0.2), encoders=Choice([[1]]), seed=seed + 2)
        cA = nengo.Connection(inptA, preA, synapse=None, seed=seed)
        # cB is kept so stages 5/6 can change the inptB -> preB synapse.
        cB = nengo.Connection(inptB, preB, synapse=None, seed=seed)
        cC = nengo.Connection(inptC, preC, synapse=None, seed=seed)
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pInptC = nengo.Probe(inptC, synapse=None)
        pPreA = nengo.Probe(preA.neurons, synapse=None)
        pPreB = nengo.Probe(preB.neurons, synapse=None)
        pPreC = nengo.Probe(preC.neurons, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pTarFdfw = nengo.Probe(tarFdfw.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        pBias = nengo.Probe(bias.neurons, synapse=None)
        pTarBias = nengo.Probe(tarBias.neurons, synapse=None)
        if stage == 0:
            # readout decoders for [preA, preB, preC]
            pass
        if stage == 1:
            # encoders for [preA, preC] to [fdfw, bias]
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarBias, synapse=fPre, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            learnEncoders(c1, tarFdfw, fS, alpha=3 * alpha, eMax=3 * eMax, tTrans=tTrans)
            learnEncoders(c2, tarBias, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
        if stage == 2:
            # readout decoders for fdfw and bias
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
        if stage == 3:
            # encoders for [bias] to ens
            nengo.Connection(inptC, tarEns, synapse=fGABA, seed=seed)
            c1 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(bias, ens, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            learnEncoders(c2, tarEns, fS, alpha=1e3 * alpha, eMax=1e3 * eMax, tTrans=tTrans)
        if stage == 4:
            # encoders for [preB] to [ens]
            nengo.Connection(inptC, tarEns, synapse=fGABA, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            c1 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(bias, ens, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            c3 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            learnEncoders(c3, tarEns, fS, alpha=3 * alpha, eMax=3 * eMax, tTrans=tTrans)
        if stage == 5:
            # encoders for [fdfw] to ens; Direct-mode helpers build the teacher signal.
            cB.synapse = fNMDA
            tarPreA = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            tarPreB = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            tarPreC = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            nengo.Connection(inptA, tarPreA, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarPreB, synapse=fNMDA, seed=seed)
            nengo.Connection(inptC, tarPreC, synapse=fPre, seed=seed)
            nengo.Connection(tarPreA, tarEns, synapse=fNMDA, transform=Tff, seed=seed)
            nengo.Connection(tarPreB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(tarPreC, tarEns, synapse=fGABA, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c3 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c4 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c5 = nengo.Connection(bias, ens, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            learnEncoders(c4, tarEns, fS, alpha=3 * alpha, eMax=3 * eMax, tTrans=tTrans)
        if stage == 6:
            # readout decoders for ens
            cB.synapse = fNMDA
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c3 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c4 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c5 = nengo.Connection(bias, ens, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
        if stage == 7:
            # encoders from ens to ens
            preB2 = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
            ens2 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed + 1)  # acts as preB input to ens
            ens3 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed + 1)  # acts as tarEns
            nengo.Connection(inptB, preB2, synapse=fNMDA, seed=seed)
            # Overrides the tarEns probe: the returned "tarEns" is ens3's spikes here.
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(preB, ens2, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c3 = nengo.Connection(preB2, ens3, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c4 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c5 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c6 = nengo.Connection(fdfw, ens3, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c7 = nengo.Connection(bias, ens, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            c8 = nengo.Connection(bias, ens2, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            c9 = nengo.Connection(bias, ens3, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            c10 = nengo.Connection(ens2, ens, synapse=NMDA(), solver=NoSolver(dEns), seed=seed)
            learnEncoders(c10, ens3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
        if stage == 8:
            # test
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c2 = nengo.Connection(preC, bias, synapse=fPre, solver=NoSolver(dPre), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c4 = nengo.Connection(bias, ens, synapse=GABA(), solver=NoSolver(dBias), seed=seed)
            c5 = nengo.Connection(ens, ens, synapse=NMDA(), solver=NoSolver(dEns), seed=seed)
    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        # Install decoders/encoders on the built connections; the pairing per
        # connection mirrors the stage blocks above.
        if stage == 0:
            pass
        if stage == 1:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreBias)
        if stage == 2:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreBias)
        if stage == 3:
            setWeights(c1, dPre, ePreBias)
            setWeights(c2, dBias, eBiasEns)
        if stage == 4:
            setWeights(c1, dPre, ePreBias)
            setWeights(c2, dBias, eBiasEns)
            setWeights(c3, dPre, ePreEns)
        if stage == 5:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreEns)
            setWeights(c3, dPre, ePreBias)
            setWeights(c4, dFdfw, eFdfwEns)
            setWeights(c5, dBias, eBiasEns)
        if stage == 6:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreEns)
            setWeights(c3, dPre, ePreBias)
            setWeights(c4, dFdfw, eFdfwEns)
            setWeights(c5, dBias, eBiasEns)
        if stage == 7:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreEns)
            setWeights(c3, dPre, ePreEns)
            setWeights(c4, dPre, ePreBias)
            setWeights(c5, dFdfw, eFdfwEns)
            setWeights(c6, dFdfw, eFdfwEns)
            setWeights(c7, dBias, eBiasEns)
            setWeights(c8, dBias, eBiasEns)
            setWeights(c9, dBias, eBiasEns)
            setWeights(c10, dEns, eEnsEns)
        if stage == 8:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreBias)
            setWeights(c3, dFdfw, eFdfwEns)
            setWeights(c4, dBias, eBiasEns)
            setWeights(c5, dEns, eEnsEns)
        # NOTE(review): presumably (re)initializes NEURON state for the Bio cells; confirm.
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model)
    # Read back learned encoders only for the stage that learned them.
    ePreFdfw = c1.e if stage == 1 else ePreFdfw
    ePreBias = c2.e if stage == 1 else ePreBias
    eBiasEns = c2.e if stage == 3 else eBiasEns
    ePreEns = c3.e if stage == 4 else ePreEns
    eFdfwEns = c4.e if stage == 5 else eFdfwEns
    eEnsEns = c10.e if stage == 7 else eEnsEns
    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        inptC=sim.data[pInptC],
        preA=sim.data[pPreA],
        # preB/preC were probed but previously dropped. Stage 0 exists to fit
        # readout decoders for all three pre populations, and the sibling
        # variants of this function return every pre population too.
        preB=sim.data[pPreB],
        preC=sim.data[pPreC],
        fdfw=sim.data[pFdfw],
        ens=sim.data[pEns],
        bias=sim.data[pBias],
        tarFdfw=sim.data[pTarFdfw],
        tarEns=sim.data[pTarEns],
        tarBias=sim.data[pTarBias],
        ePreFdfw=ePreFdfw,
        ePreEns=ePreEns,
        ePreBias=ePreBias,
        eFdfwEns=eFdfwEns,
        eBiasEns=eBiasEns,
        eEnsEns=eEnsEns,
    )