def run(NPre=100, N=30, t=10, nTrain=5, nEnc=5, nTest=10, dt=0.001,
        neuron_type=LIF(), fPre=DoubleExp(1e-3, 1e-1),
        fS=DoubleExp(2e-2, 2e-1), Tff=0.3, reg=1e-1, tauRiseMax=1e-1,
        tauFallMax=3e-1, load=[], file="data/integrate"):
    """Train the integrator network stage by stage, then test it.

    Every stage either reads its result from the ``.npz`` archive (when the
    stage index is in ``load``) or computes it by simulating with ``go`` and
    then saves everything obtained so far back to the archive:

    0. readout decoders ``dPreA`` / ``dPreB`` (``decodeNoF``); this stage is
       always simulated with a fixed 1 ms step, whatever ``dt`` is;
    1. encoders ``ePreB`` -- trained only when ``neuron_type`` is a ``Bio``
       instance, otherwise left as zeros;
    2. encoders ``ePreA`` -- same rule as stage 1;
    3. readout decoders ``dEns`` plus the synaptic time constants
       ``tauRiseEns`` / ``tauFallEns`` (``decode``);
    4. recurrent encoders ``eBio`` -- same rule as stage 1.

    Afterwards ``nTest`` trials (signal seeds ``200 + test``) are simulated
    and each one is plotted to ``plots/integrate_<neuron_type>_test<k>.pdf``.

    Args:
        NPre: size of the last axis of the stage-0 spike arrays; forwarded
            to ``go``.
        N: size of the last axis of the ``ens`` spike array; forwarded to
            ``go``.
        t: duration of one trial; forwarded to ``go`` and ``makeSignal``.
        nTrain: number of trials used to fit decoders.
        nEnc: number of encoder-training passes per encoder stage.
        nTest: number of test trials; must be at least 1 because the values
            of the final trial are returned.
        dt: time step used from stage 1 onwards.
        neuron_type: forwarded to ``go``; its string form is embedded in
            file names.
        fPre: filter applied to the input signals to build targets.
        fS: forwarded to ``go`` and ``plotActivity``.
        Tff: gain applied to ``inptA`` in the stage-0 target.
        reg: regularisation forwarded to ``decodeNoF`` / ``decode``.
        tauRiseMax: forwarded to ``decode``.
        tauFallMax: forwarded to ``decode``.
        load: collection of stage indices to load instead of recompute. It
            is only read, never mutated.
        file: path prefix; ``"<neuron_type>.npz"`` is appended to it.

    Returns:
        ``(times, X, Y)`` of the last test trial, where ``X`` is the decoded
        estimate and ``Y`` is ``data['inptB']`` filtered by ``fPre``.
    """
    print('\nNeuron Type: %s' % neuron_type)
    file = file + f"{neuron_type}.npz"
    # Keyword arguments shared by every simulation from stage 1 onwards.
    sim = dict(NPre=NPre, N=N, t=t, dt=dt, neuron_type=neuron_type,
               fPre=fPre, fS=fS)

    # ---- stage 0: readout decoders for the two input populations --------
    if 0 in load:
        # Open the archive once and close it again; np.load on an .npz
        # returns a lazy NpzFile that holds a file handle.
        with np.load(file) as archive:
            dPreA = archive['dPreA']  # TODO (original note): fix indexing of decoders
            dPreB = archive['dPreB']
    else:
        print('readout decoders for preInptA and preInptB')
        dtPre = 0.001  # this stage ignores `dt` and always steps at 1 ms
        nSteps = int(t / dtPre)
        spikesInptA = np.zeros((nTrain, nSteps, NPre))
        spikesInptB = np.zeros((nTrain, nSteps, NPre))
        targetsInptA = np.zeros((nTrain, nSteps, 1))
        targetsInptB = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, dt=dtPre, seed=n)
            data = go(NPre=NPre, N=N, t=t, dt=dtPre,
                      neuron_type=neuron_type, fPre=fPre, fS=fS,
                      stimA=stimA, stimB=stimB)
            spikesInptA[n] = data['preInptA']
            spikesInptB[n] = data['preInptB']
            targetsInptA[n] = fPre.filt(Tff * data['inptA'], dt=dtPre)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dtPre)
        dPreA, X1a, Y1a, error1a = decodeNoF(
            spikesInptA, targetsInptA, nTrain, fPre, dt=dtPre, reg=reg)
        dPreB, X1b, Y1b, error1b = decodeNoF(
            spikesInptB, targetsInptB, nTrain, fPre, dt=dtPre, reg=reg)
        np.savez(file, dPreA=dPreA, dPreB=dPreB)
        times = np.arange(0, t * nTrain, dtPre)
        plotState(times, X1a, Y1a, error1a, "integrate",
                  "%s_preInptA" % neuron_type, t * nTrain)
        plotState(times, X1b, Y1b, error1b, "integrate",
                  "%s_preInptB" % neuron_type, t * nTrain)

    # ---- stage 1: encoders preInptB -> ens -------------------------------
    if 1 in load:
        with np.load(file) as archive:
            ePreB = archive['ePreB']
    elif isinstance(neuron_type, Bio):
        print("encoders for preInptB-to-ens")
        ePreB = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, ePreB=ePreB,
                      stimA=stimA, stimB=stimB, stage=1, **sim)
            ePreB = data['ePreB']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrate", "preIntgToEns")
            # Checkpoint after every pass.
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreB=ePreB)
    else:
        ePreB = np.zeros((NPre, N, 1))

    # ---- stage 2: encoders preInptA -> ens -------------------------------
    if 2 in load:
        with np.load(file) as archive:
            ePreA = archive['ePreA']
    elif isinstance(neuron_type, Bio):
        print("encoders for preInptA-to-ens")
        ePreA = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB,
                      stimA=stimA, stimB=stimB, stage=2, **sim)
            ePreA = data['ePreA']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrate", "preInptToEns")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA,
                     ePreB=ePreB)
    else:
        ePreA = np.zeros((NPre, N, 1))

    # ---- stage 3: readout decoders and synapse for ens -------------------
    if 3 in load:
        with np.load(file) as archive:
            dEns = archive['dEns']
            tauRiseEns = archive['tauRiseEns']
            tauFallEns = archive['tauFallEns']
        fEns = DoubleExp(tauRiseEns, tauFallEns)
    else:
        print('readout decoders for ens')
        nSteps = int(t / dt)
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB,
                      stimA=stimA, stimB=stimB, stage=3, **sim)
            spikes[n] = data['ens']
            targets[n] = fPre.filt(data['inptB'], dt=dt)
        dEns, fEns, tauRiseEns, tauFallEns, X2, Y2, error2 = decode(
            spikes, targets, nTrain, dt=dt, reg=reg,
            tauRiseMax=tauRiseMax, tauFallMax=tauFallMax, name="integrate")
        np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB,
                 dEns=dEns, tauRiseEns=tauRiseEns, tauFallEns=tauFallEns)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X2, Y2, error2, "integrate",
                  "%s_ens" % neuron_type, t * nTrain)

    # ---- stage 4: recurrent encoders ens -> ens --------------------------
    if 4 in load:
        with np.load(file) as archive:
            eBio = archive['eBio']
    elif isinstance(neuron_type, Bio):
        print("encoders from ens2 to ens")
        eBio = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, dEns=dEns, ePreA=ePreA,
                      ePreB=ePreB, eBio=eBio, fEns=fEns,
                      stimA=stimA, stimB=stimB, stage=4, **sim)
            eBio = data['eBio']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrate", "Ens2Ens")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA,
                     ePreB=ePreB, dEns=dEns, tauRiseEns=tauRiseEns,
                     tauFallEns=tauFallEns, eBio=eBio)
    else:
        eBio = np.zeros((N, N, 1))

    # ---- testing ----------------------------------------------------------
    print("testing")
    for test in range(nTest):
        # Seeds are offset by 200 so they differ from the training seeds.
        stimA, stimB = makeSignal(t, fPre, dt=dt, seed=200 + test)
        data = go(dPreA=dPreA, dEns=dEns, ePreA=ePreA, eBio=eBio, fEns=fEns,
                  stimA=stimA, stimB=stimB, stage=5, **sim)
        A = fEns.filt(data['ens'], dt=dt)
        X = np.dot(A, dEns)
        Y = fPre.filt(data['inptB'], dt=dt)
        error = rmse(X, Y)

        fig, ax = plt.subplots()
        ax.plot(data['times'], X, label="estimate")
        ax.plot(data['times'], Y, label="target")
        ax.set(xlabel='time', ylabel='state', title='rmse=%.3f' % error,
               xlim=((0, t)), ylim=((-1, 1)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrate_%s_test%s.pdf" % (neuron_type, test))
        plt.close('all')

    return data['times'], X, Y
def run(NPre=100, N=100, t=10, c=None, nTrain=10, nTest=5, dt=0.001,
        neuron_type=LIF(), fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3), reg=1e-2, load=[], file=None,
        neg=True):
    """Fit the three readout decoders of the diffMemory network and test it.

    Each stage reads its result from ``file`` when the stage index is in
    ``load``; otherwise it simulates ``nTrain`` trials with ``go``, fits
    decoders with ``decodeNoF`` and saves everything obtained so far:

    0. ``dPre``  -- from ``data['pre']`` spikes;
    1. ``dDiff`` -- from ``data['diff']`` spikes;
    2. ``dEns``  -- from ``data['ens']`` spikes, plus ``dNeg = -dEns``.

    Afterwards ``nTest`` test trials are simulated and each is plotted to
    ``plots/diffMemory_<neuron_type>_test<k>.pdf``.

    NOTE(review): results are always *saved* to
    ``data/diffMemory_<neuron_type>.npz`` while loading uses ``file``; the
    two coincide only if the caller passes that same path.
    NOTE(review): ``dt`` is used for array sizes and for filtering but is not
    forwarded to ``go``, ``makeSignal`` or ``decodeNoF``, which presumably
    fall back to their own defaults -- confirm they agree with ``dt``.

    Args:
        NPre: size of the last axis of the ``pre`` spike array; forwarded to
            ``go``.
        N: size of the last axis of the ``diff`` / ``ens`` spike arrays;
            forwarded to ``go``.
        t: duration of one trial.
        c: cutoff time of the step input used when ``neg`` is true; any
            falsy value (``None`` or ``0``) is replaced by ``t``.
        nTrain: number of training trials per decoder stage.
        nTest: number of test trials.
        dt: time step used for array sizing and filtering.
        neuron_type: forwarded to ``go``; its string form is embedded in
            file names.
        fPre: filter applied first to every signal.
        fNMDA: filter applied after ``fPre``.
        reg: regularisation forwarded to ``decodeNoF``.
        load: collection of stage indices to load instead of recompute. It
            is only read, never mutated.
        file: ``.npz`` archive to load from.
        neg: if true, test with constant inputs that switch to 0 at ``c``
            and pass ``dNeg`` to ``go``; otherwise test with
            ``makeSignal`` inputs and ``Tff=0.3``.
    """
    if not c:
        c = t
    savePath = "data/diffMemory_%s.npz" % neuron_type
    plotName = "diffMemory_%s" % neuron_type
    nSteps = int(t / dt)
    # Keyword arguments shared by every call to go().
    sim = dict(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA,
               neuron_type=neuron_type)

    # ---- stage 0: readout decoders for pre --------------------------------
    if 0 in load:
        # Open the archive once and close it again; np.load on an .npz
        # returns a lazy NpzFile that holds a file handle.
        with np.load(file) as archive:
            dPre = archive['dPre']
    else:
        print('readout decoders for pre')
        spikesInpt = np.zeros((nTrain, nSteps, NPre))
        targetsInpt = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(stim=stim, **sim)
            spikesInpt[n] = data['pre']
            targetsInpt[n] = fPre.filt(data['inpt'], dt=dt)
        dPre, X, Y, error = decodeNoF(spikesInpt, targetsInpt, nTrain, fPre,
                                      reg=reg)
        np.savez(savePath, dPre=dPre)
        # NOTE(review): hard-coded 1 ms step here; later stages use `dt`.
        times = np.arange(0, t * nTrain, 0.001)
        plotState(times, X, Y, error, plotName, "pre", t * nTrain)

    # ---- stage 1: readout decoders for diff -------------------------------
    if 1 in load:
        with np.load(file) as archive:
            dDiff = archive['dDiff']
    else:
        print('readout decoders for diff')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(stim=stim, dPre=dPre, **sim)
            spikes[n] = data['diff']
            targets[n] = fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt)
        dDiff, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA,
                                       reg=reg)
        np.savez(savePath, dPre=dPre, dDiff=dDiff)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, plotName, "diff", t * nTrain)

    # ---- stage 2: readout decoders for ens --------------------------------
    if 2 in load:
        with np.load(file) as archive:
            dEns = archive['dEns']
            dNeg = archive['dNeg']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(stim=stim, dPre=dPre, dDiff=dDiff, train=True, Tff=0.3,
                      **sim)
            spikes[n] = data['ens']
            targets[n] = fNMDA.filt(fPre.filt(data['intg'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA,
                                      reg=reg)
        dNeg = -np.array(dEns)
        np.savez(savePath, dPre=dPre, dDiff=dDiff, dEns=dEns, dNeg=dNeg)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, plotName, "ens", t * nTrain)

    # ---- testing ----------------------------------------------------------
    print('testing')
    vals = np.linspace(-1, 1, nTest)
    for test in range(nTest):
        if neg:
            # The closure over `test` is consumed by go() within this same
            # iteration, so late binding is not an issue here.
            stim = lambda t: vals[test] if t < c else 0
            data = go(stim=stim, dPre=dPre, dDiff=dDiff, dEns=dEns,
                      dNeg=dNeg, c=c, **sim)
        else:
            stim = makeSignal(t, fPre, fNMDA, seed=test)
            data = go(stim=stim, dPre=dPre, dDiff=dDiff, dEns=dEns,
                      dNeg=None, c=None, Tff=0.3, **sim)

        # fPre followed by fNMDA twice, for both the activity and the target.
        aEns = fNMDA.filt(
            fNMDA.filt(fPre.filt(data['ens'], dt=dt), dt=dt), dt=dt)
        xhatEns = np.dot(aEns, dEns)
        x = fNMDA.filt(
            fNMDA.filt(fPre.filt(data['intg'], dt=dt), dt=dt), dt=dt)
        error = rmse(xhatEns, x)

        fig, ax = plt.subplots()
        if neg:
            u2 = fNMDA.filt(
                fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt), dt=dt)
            ax.plot(data['times'], u2, alpha=0.5, label="input (delayed)")
            ax.axvline(c, label="cutoff")
        else:
            ax.plot(data['times'], x, alpha=0.5, label="integral")
        ax.plot(data['times'], xhatEns, label="ens")
        ax.set(xlabel='time', ylabel='state', title="rmse=%.3f" % error,
               xlim=((0, t)), ylim=((-1, 1)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/diffMemory_%s_test%s.pdf" % (neuron_type, test))
def run(NPre=100, N=100, t=10, nTrain=10, nTest=3, nEnc=10, dt=0.001,
        c=None, fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3), fGABA=DoubleExp(0.5e-3, 1.5e-3),
        fS=DoubleExp(1e-3, 1e-1), Tff=0.3, Tneg=-1, reg=1e-2, load=[],
        file=None, DATest=lambda t: 0):
    """Train the integrateNMDA network stage by stage, then test it.

    Each stage reads its result from ``file`` when the stage index is in
    ``load``; otherwise it is computed by simulating with ``go`` and
    everything obtained so far is saved to ``data/integrateNMDA.npz``:

    0. readout decoders ``dPreA``, ``dPreB``, ``dPreC``;
    1. encoders ``ePreA``, ``ePreB``, ``ePreC``;
    2. readout decoders ``dFdfw`` and ``dInh``;
    3. encoders ``eFdfw``;
    4. readout decoders ``dBio`` and ``dNeg = -dBio``;
    5. encoders ``eBio``;
    6. encoders ``eNeg`` and ``eInh``.

    Testing then runs two simulations per test index: a constant input that
    is switched off at ``c`` (saved as ``integrateNMDA_testFlat<k>.pdf``) and
    a ``makeSignal`` input with seed ``200 + test`` (saved as
    ``integrateNMDA_testWhite<k>.pdf``), both under ``plots/``.

    NOTE(review): loading uses ``file`` but saving always goes to
    ``data/integrateNMDA.npz``.

    Args:
        NPre: size of the last axis of the ``pre*`` spike arrays; forwarded
            to ``go``.
        N: size of the last axis of the ``fdfw`` / ``inh`` / ``ens`` spike
            arrays; forwarded to ``go``.
        t: duration of one trial.
        nTrain: number of trials per decoder stage.
        nTest: number of test indices.
        nEnc: number of passes per encoder stage.
        dt: simulation / filtering time step.
        c: cutoff time of the step stimuli; any falsy value is replaced by
            ``t``.
        fPre: filter applied first to every input signal.
        fNMDA: filter used for the ``fdfw`` and ``ens`` readouts.
        fGABA: filter used for the ``inh`` readout.
        fS: forwarded to ``plotActivity``.
        Tff: gain on the ``fdfw`` target; also forwarded to ``go`` in
            stage 3 and in the second test.
        Tneg: factor applied to ``dNeg`` in the first test.
        reg: regularisation forwarded to ``decodeNoF``.
        load: collection of stage indices to load instead of recompute. It
            is only read, never mutated.
        file: ``.npz`` archive to load from.
        DATest: callable forwarded to ``go`` as ``DA`` during testing.
    """
    if not c:
        c = t
    savePath = "data/integrateNMDA.npz"
    nSteps = int(t / dt)
    # Keyword arguments shared by every call to go().
    sim = dict(NPre=NPre, N=N, t=t, dt=dt)

    # ---- stage 0: readout decoders for pre[A,B,C] -------------------------
    if 0 in load:
        # Open the archive once and close it again; np.load on an .npz
        # returns a lazy NpzFile that holds a file handle.
        with np.load(file) as archive:
            dPreA = archive['dPreA']
            dPreB = archive['dPreB']
            dPreC = archive['dPreC']
    else:
        print('readout decoders for pre[A,B,C]')
        spikesInptA = np.zeros((nTrain, nSteps, NPre))
        spikesInptB = np.zeros((nTrain, nSteps, NPre))
        spikesInptC = np.zeros((nTrain, nSteps, NPre))
        targetsInptA = np.zeros((nTrain, nSteps, 1))
        targetsInptB = np.zeros((nTrain, nSteps, 1))
        targetsInptC = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB, stimC=stimB, stage=0, **sim)
            spikesInptA[n] = data['preA']
            spikesInptB[n] = data['preB']
            spikesInptC[n] = data['preC']
            targetsInptA[n] = fPre.filt(data['inptA'], dt=dt)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dt)
            targetsInptC[n] = fPre.filt(data['inptC'], dt=dt)
        dPreA, XA, YA, errorA = decodeNoF(
            spikesInptA, targetsInptA, nTrain, fPre, dt=dt, reg=reg)
        dPreB, XB, YB, errorB = decodeNoF(
            spikesInptB, targetsInptB, nTrain, fPre, dt=dt, reg=reg)
        dPreC, XC, YC, errorC = decodeNoF(
            spikesInptC, targetsInptC, nTrain, fPre, dt=dt, reg=reg)
        np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, XA, YA, errorA, "integrateNMDA", "preA", t * nTrain)
        plotState(times, XB, YB, errorB, "integrateNMDA", "preB", t * nTrain)
        plotState(times, XC, YC, errorC, "integrateNMDA", "preC", t * nTrain)

    # ---- stage 1: encoders pre[A,B,C] -> [fdfw, ens, inh] -----------------
    if 1 in load:
        with np.load(file) as archive:
            ePreA = archive['ePreA']
            ePreB = archive['ePreB']
            ePreC = archive['ePreC']
    else:
        print("encoders for pre[A,B,C] to [fdfw, ens, inh]")
        ePreA = np.zeros((NPre, N, 1))
        ePreB = np.zeros((NPre, N, 1))
        ePreC = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB, stimC=stimB,
                      dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                      ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                      stage=1, **sim)
            ePreA = data['ePreA']
            ePreB = data['ePreB']
            ePreC = data['ePreC']
            plotActivity(t, dt, fS, data['times'], data['fdfw'],
                         data['tarFdfw'], "integrateNMDA", "preAFdfw")
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateNMDA", "preBEns")
            plotActivity(t, dt, fS, data['times'], data['inh'],
                         data['tarInh'], "integrateNMDA", "preCInh")
            # Checkpoint after every pass.
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC)

    # ---- stage 2: readout decoders for fdfw and inh -----------------------
    if 2 in load:
        with np.load(file) as archive:
            dFdfw = archive['dFdfw']
            dInh = archive['dInh']
    else:
        print('readout decoders for fdfw and inh')
        spikesFdfw = np.zeros((nTrain, nSteps, N))
        spikesInh = np.zeros((nTrain, nSteps, N))
        targetsFdfw = np.zeros((nTrain, nSteps, 1))
        targetsInh = np.zeros((nTrain, nSteps, 1))
        # One full sine period over the trial, shifted into [0, 1].
        stimC = lambda x: 0.5+0.5*np.sin(2*np.pi*x/t)
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB, stimC=stimC,
                      dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                      ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                      stage=2, **sim)
            spikesFdfw[n] = data['fdfw']
            spikesInh[n] = data['inh']
            # The outer filters are called without dt, as in the original.
            targetsFdfw[n] = fNMDA.filt(Tff*fPre.filt(data['inptA'], dt=dt))
            targetsInh[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt))
        dFdfw, X1, Y1, error1 = decodeNoF(
            spikesFdfw, targetsFdfw, nTrain, fNMDA, dt=dt, reg=reg)
        dInh, X2, Y2, error2 = decodeNoF(
            spikesInh, targetsInh, nTrain, fGABA, dt=dt, reg=reg)
        np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dFdfw=dFdfw, dInh=dInh,
                 ePreA=ePreA, ePreB=ePreB, ePreC=ePreC)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X1, Y1, error1, "integrateNMDA", "fdfw", t * nTrain)
        plotState(times, X2, Y2, error2, "integrateNMDA", "inh", t * nTrain)

    # ---- stage 3: encoders fdfw -> ens ------------------------------------
    if 3 in load:
        with np.load(file) as archive:
            eFdfw = archive['eFdfw']
    else:
        print("encoders for fdfw-to-ens")
        eFdfw = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB, Tff=Tff,
                      dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw,
                      ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw,
                      stage=3, **sim)
            eFdfw = data['eFdfw']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateNMDA", "fdfwEns")
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dFdfw=dFdfw, dInh=dInh,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw)

    # ---- stage 4: readout decoders for ens --------------------------------
    if 4 in load:
        with np.load(file) as archive:
            dBio = archive['dBio']
            dNeg = archive['dNeg']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB,
                      dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw,
                      ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw,
                      stage=4, **sim)
            spikes[n] = data['ens']
            targets[n] = fNMDA.filt(fPre.filt(data['inptB']))
        dBio, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, dt=dt,
                                      reg=reg)
        dNeg = -np.array(dBio)
        np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dFdfw=dFdfw, dInh=dInh, dBio=dBio, dNeg=dNeg,
                 ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateNMDA", "ens", t * nTrain)

    # ---- stage 5: recurrent encoders ens -> ens ---------------------------
    if 5 in load:
        with np.load(file) as archive:
            eBio = archive['eBio']
    else:
        print("encoders from ens to ens")
        eBio = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB,
                      dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw, dBio=dBio,
                      ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw, eBio=eBio,
                      stage=5, **sim)
            eBio = data['eBio']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateNMDA", "ensEns")
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dFdfw=dFdfw, dBio=dBio, dInh=dInh, dNeg=dNeg,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                     eFdfw=eFdfw, eBio=eBio)

    # ---- stage 6: encoders [ens, inh] -> fdfw -----------------------------
    if 6 in load:
        with np.load(file) as archive:
            eNeg = archive['eNeg']
            eInh = archive['eInh']
    else:
        print("encoders from [ens, inh] to fdfw")
        eNeg = np.zeros((N, N, 1))
        eInh = np.zeros((N, N, 1))
        stimC = lambda t: 0 if t<c else 1
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB, stimC=stimC,
                      dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                      dFdfw=dFdfw, dBio=dBio, dNeg=dNeg, dInh=dInh,
                      ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                      eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh,
                      stage=6, **sim)
            eNeg = data['eNeg']
            eInh = data['eInh']
            plotActivity(t, dt, fS, data['times'], data['fdfw2'],
                         data['tarFdfw2'], "integrateNMDA", "ensFdfw")
            plotActivity(t, dt, fS, data['times'], data['fdfw4'],
                         data['tarFdfw4'], "integrateNMDA", "inhFdfw")
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dFdfw=dFdfw, dBio=dBio, dNeg=dNeg, dInh=dInh,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                     eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh)

    # ---- testing ----------------------------------------------------------
    print("testing")
    vals = np.linspace(-1, 1, nTest)
    for test in range(nTest):
        # -- test 1: constant input until c, then input off and stimC on ----
        # The closure over `test` is consumed by go() within this iteration.
        stimA = lambda t: vals[test] if t<c else 0
        stimC = lambda t: 0 if t<c else 1
        data = go(stimA=stimA, stimC=stimC,
                  dPreA=dPreA, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio,
                  dNeg=Tneg*dNeg, dInh=dInh,
                  ePreA=ePreA, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio,
                  eNeg=eNeg, eInh=eInh, stage=7, c=c, DA=DATest, **sim)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        aInh = fGABA.filt(data['inh'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)/Tff
        xhatBio = np.dot(aBio, dBio)
        xhatInh = np.dot(aInh, dInh)
        xFdfw = fNMDA.filt(
            fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt), dt=dt)
        # Clamp so that c == t (the default) cannot index past the final
        # sample; for any c inside the trial this is exactly int(c/dt).
        cut = min(int(c/dt), len(data['times']) - 1)
        # Target: the filtered input value at the cutoff, held constant.
        xBio = xFdfw[cut] * np.ones_like(data['times'])
        errorBio = rmse(xhatBio[cut:], xBio[cut:])

        fig, ax = plt.subplots()
        ax.plot(data['times'], xFdfw, alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], xBio, label="target")
        ax.plot(data['times'], xhatBio, alpha=0.5,
                label="ens, rmse=%.3f"%errorBio)
        ax.plot(data['times'], xhatInh, alpha=0.5, label="inh")
        ax.axvline(c, label="cutoff")
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)),
               ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDA_testFlat%s.pdf"%test)

        # -- test 2: makeSignal input, no dNeg / eNeg ------------------------
        stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=200+test)
        stimC = lambda t: 0
        data = go(stimA=stimA, stimB=stimB, stimC=stimC, Tff=Tff,
                  dPreA=dPreA, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio,
                  dNeg=None, dInh=dInh,
                  ePreA=ePreA, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio,
                  eNeg=None, eInh=eInh, stage=7, DA=DATest, **sim)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)
        xhatBio = np.dot(aBio, dBio)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xBio = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        errorBio = rmse(xhatBio, xBio)

        fig, ax = plt.subplots()
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], fNMDA.filt(Tff*xFdfw, dt=dt), alpha=0.5,
                label="input (filtered)")
        ax.plot(data['times'], xBio, label="target")
        ax.plot(data['times'], xhatBio, alpha=0.5,
                label="ens, rmse=%.3f"%errorBio)
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)),
               ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDA_testWhite%s.pdf"%test)
        plt.close('all')
def run(NPre=100, N=100, t=10, nTrain=10, nTest=3, nAttr=5, nEnc=10,
        dt=0.001, c=None, fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3), fGABA=DoubleExp(0.5e-3, 1.5e-3),
        fS=DoubleExp(2e-2, 1e-1), Tff=0.3, reg=1e-2, load=[], file=None):
    """Train the integrateAttractors network stage by stage, then test it.

    Each stage reads its result from ``file`` when its index is in ``load``;
    otherwise it is computed by simulating with ``go`` and everything
    obtained so far is saved to ``data/integrateAttractors.npz``:

    * index 0: readout decoders ``dPreA`` .. ``dPreD``;
    * index 0 (again): encoders ``ePreDEns`` -- both steps are keyed on the
      same index, so they are always loaded or recomputed together;
    * index 1: encoders ``ePreAFdfw``, ``ePreBEns``, ``ePreCOff``;
    * index 2: readout decoders ``dFdfw`` and ``dOff``;
    * index 3: encoders ``eFdfwEns``;
    * index 4: readout decoders ``dEns``;
    * index 5: encoders ``eEnsEns``.

    Testing simulates ``nTest`` constant inputs that are switched off at
    ``c`` and saves one figure per test to
    ``plots/integrateAttractors_testFlat<k>.pdf``.

    NOTE(review): loading uses ``file`` but saving always goes to
    ``data/integrateAttractors.npz``.

    Args:
        NPre: size of the last axis of the ``pre*`` spike arrays; forwarded
            to ``go``.
        N: size of the last axis of the ``fdfw`` / ``off`` / ``ens`` spike
            arrays; forwarded to ``go``.
        t: duration of one trial.
        nTrain: number of trials per decoder stage.
        nTest: number of test trials.
        nAttr: forwarded to ``makeSignal``; also the number of evenly spaced
            values in [-1, 1] passed to ``closestAttractor`` during testing.
        nEnc: number of passes per encoder stage.
        dt: simulation / filtering time step.
        c: cutoff time of the step stimuli; any falsy value is replaced by
            ``t``.
        fPre: filter applied first to every input signal.
        fNMDA: filter used for the ``fdfw`` and ``ens`` readouts.
        fGABA: filter used for the ``off`` readout.
        fS: forwarded to ``plotActivity``.
        Tff: gain on the ``fdfw`` target; also forwarded to ``go`` in the
            ``eFdfwEns`` stage.
        reg: regularisation forwarded to ``decodeNoF``.
        load: collection of stage indices to load instead of recompute. It
            is only read, never mutated.
        file: ``.npz`` archive to load from.
    """
    if not c:
        c = t
    savePath = "data/integrateAttractors.npz"
    nSteps = int(t / dt)
    # Keyword arguments shared by every call to go().
    sim = dict(NPre=NPre, N=N, t=t, dt=dt)
    # Step from 0 to 1 at the cutoff; reused by several stages.
    stepAtCutoff = lambda t: 0 if t < c else 1

    # ---- index 0: readout decoders for pre[A,B,C,D] -----------------------
    if 0 in load:
        # Open the archive once and close it again; np.load on an .npz
        # returns a lazy NpzFile that holds a file handle.
        with np.load(file) as archive:
            dPreA = archive['dPreA']
            dPreB = archive['dPreB']
            dPreC = archive['dPreC']
            dPreD = archive['dPreD']
    else:
        print('readout decoders for pre')
        spikesInptA = np.zeros((nTrain, nSteps, NPre))
        spikesInptB = np.zeros((nTrain, nSteps, NPre))
        spikesInptC = np.zeros((nTrain, nSteps, NPre))
        spikesInptD = np.zeros((nTrain, nSteps, NPre))
        targetsInptA = np.zeros((nTrain, nSteps, 1))
        targetsInptB = np.zeros((nTrain, nSteps, 1))
        targetsInptC = np.zeros((nTrain, nSteps, 1))
        targetsInptD = np.zeros((nTrain, nSteps, 1))
        stimD = lambda t: 1e-1
        for n in range(nTrain):
            stimA, _ = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimB = stimA
            data = go(stimA=stimA, stimB=stimB, stimC=stepAtCutoff,
                      stimD=stimD, stage=None, **sim)
            spikesInptA[n] = data['preA']
            spikesInptB[n] = data['preB']
            spikesInptC[n] = data['preC']
            spikesInptD[n] = data['preD']
            targetsInptA[n] = fPre.filt(data['inptA'], dt=dt)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dt)
            targetsInptC[n] = fPre.filt(data['inptC'], dt=dt)
            targetsInptD[n] = fPre.filt(data['inptD'], dt=dt)
        dPreA, XA, YA, errorA = decodeNoF(
            spikesInptA, targetsInptA, nTrain, fPre, dt=dt, reg=reg)
        dPreB, XB, YB, errorB = decodeNoF(
            spikesInptB, targetsInptB, nTrain, fPre, dt=dt, reg=reg)
        dPreC, XC, YC, errorC = decodeNoF(
            spikesInptC, targetsInptC, nTrain, fPre, dt=dt, reg=reg)
        dPreD, XD, YD, errorD = decodeNoF(
            spikesInptD, targetsInptD, nTrain, fPre, dt=dt, reg=reg)
        np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dPreD=dPreD)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, XA, YA, errorA, "integrateAttractors", "preA",
                  t * nTrain)
        plotState(times, XB, YB, errorB, "integrateAttractors", "preB",
                  t * nTrain)
        plotState(times, XC, YC, errorC, "integrateAttractors", "preC",
                  t * nTrain)
        plotState(times, XD, YD, errorD, "integrateAttractors", "preD",
                  t * nTrain)

    # ---- index 0 (again): encoders preD -> ens ----------------------------
    if 0 in load:
        with np.load(file) as archive:
            ePreDEns = archive['ePreDEns']
    else:
        print("encoders for preD to ens")
        ePreDEns = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            data = go(dPreD=dPreD, ePreDEns=ePreDEns, stage=0, **sim)
            ePreDEns = data['ePreDEns']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateAttractors", "preDEns")
            # Checkpoint after every pass.
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dPreD=dPreD, ePreDEns=ePreDEns)

    # ---- index 1: encoders [preA, preB, preC] -> [fdfw, ens, off] ---------
    if 1 in load:
        with np.load(file) as archive:
            ePreAFdfw = archive['ePreAFdfw']
            ePreBEns = archive['ePreBEns']
            ePreCOff = archive['ePreCOff']
    else:
        print("encoders for [preA, preB, preC] to [fdfw, ens, off]")
        ePreAFdfw = np.zeros((NPre, N, 1))
        ePreBEns = np.zeros((NPre, N, 1))
        ePreCOff = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, _ = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimB = stimA
            # NOTE(review): the original built a step stimC here but never
            # used it; stimC is fed stimB, which is kept unchanged. Confirm
            # that this is the intended training signal.
            data = go(stimA=stimA, stimB=stimB, stimC=stimB,
                      dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD,
                      ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns,
                      ePreCOff=ePreCOff, ePreDEns=ePreDEns,
                      stage=1, **sim)
            ePreAFdfw = data['ePreAFdfw']
            ePreBEns = data['ePreBEns']
            ePreCOff = data['ePreCOff']
            plotActivity(t, dt, fS, data['times'], data['fdfw'],
                         data['tarFdfw'], "integrateAttractors", "preAFdfw")
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateAttractors", "preBEns")
            plotActivity(t, dt, fS, data['times'], data['off'],
                         data['tarOff'], "integrateAttractors", "preCOff")
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dPreD=dPreD, ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw,
                     ePreBEns=ePreBEns, ePreCOff=ePreCOff)

    # ---- index 2: readout decoders for fdfw and off -----------------------
    if 2 in load:
        with np.load(file) as archive:
            dFdfw = archive['dFdfw']
            dOff = archive['dOff']
    else:
        print('readout decoders for fdfw and off')
        spikesFdfw = np.zeros((nTrain, nSteps, N))
        spikesOff = np.zeros((nTrain, nSteps, N))
        targetsFdfw = np.zeros((nTrain, nSteps, 1))
        targetsOff = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB, stimC=stepAtCutoff,
                      dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD,
                      ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns,
                      ePreCOff=ePreCOff, ePreDEns=ePreDEns,
                      stage=2, **sim)
            spikesFdfw[n] = data['fdfw']
            spikesOff[n] = data['off']
            # Filters are called without dt here, as in the original.
            targetsFdfw[n] = fNMDA.filt(Tff * fPre.filt(data['inptA']))
            targetsOff[n] = fGABA.filt(fPre.filt(data['inptC']))
        dFdfw, X1, Y1, error1 = decodeNoF(
            spikesFdfw, targetsFdfw, nTrain, fNMDA, dt=dt, reg=reg)
        dOff, X2, Y2, error2 = decodeNoF(
            spikesOff, targetsOff, nTrain, fGABA, dt=dt, reg=reg)
        np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, ePreDEns=ePreDEns,
                 ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X1, Y1, error1, "integrateAttractors", "fdfw",
                  t * nTrain)
        plotState(times, X2, Y2, error2, "integrateAttractors", "off",
                  t * nTrain)

    # ---- index 3: encoders fdfw -> ens ------------------------------------
    if 3 in load:
        with np.load(file) as archive:
            eFdfwEns = archive['eFdfwEns']
    else:
        print("encoders for fdfw-to-ens")
        eFdfwEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(Tff=Tff, stimA=stimA, stimB=stimB,
                      dPreA=dPreA, dPreB=dPreB, dPreD=dPreD, dFdfw=dFdfw,
                      ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns,
                      ePreDEns=ePreDEns, eFdfwEns=eFdfwEns,
                      stage=3, **sim)
            eFdfwEns = data['eFdfwEns']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateAttractors", "fdfwEns")
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dPreD=dPreD, dFdfw=dFdfw, dOff=dOff,
                     ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw,
                     ePreBEns=ePreBEns, ePreCOff=ePreCOff,
                     eFdfwEns=eFdfwEns)

    # ---- index 4: readout decoders for ens --------------------------------
    if 4 in load:
        with np.load(file) as archive:
            dEns = archive['dEns']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB,
                      dPreA=dPreA, dPreB=dPreB, dPreD=dPreD, dFdfw=dFdfw,
                      ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns,
                      ePreDEns=ePreDEns, eFdfwEns=eFdfwEns,
                      stage=4, **sim)
            spikes[n] = data['ens']
            # Original note: stimA rounded to nearest attractor.
            targets[n] = fNMDA.filt(fPre.filt(data['inptB']))
        dEns, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, dt=dt,
                                      reg=reg)
        np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, dEns=dEns,
                 ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns,
                 ePreCOff=ePreCOff, eFdfwEns=eFdfwEns)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateAttractors", "ens",
                  t * nTrain)

    # ---- index 5: recurrent encoders ens -> ens ---------------------------
    if 5 in load:
        with np.load(file) as archive:
            eEnsEns = archive['eEnsEns']
    else:
        print("encoders from ens to ens")
        eEnsEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(stimA=stimA, stimB=stimB,
                      dPreA=dPreA, dPreB=dPreB, dPreD=dPreD, dFdfw=dFdfw,
                      dEns=dEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns,
                      ePreDEns=ePreDEns, eFdfwEns=eFdfwEns, eEnsEns=eEnsEns,
                      stage=5, **sim)
            eEnsEns = data['eEnsEns']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns'], "integrateAttractors", "ensEns")
            np.savez(savePath, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, dEns=dEns,
                     ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw,
                     ePreBEns=ePreBEns, ePreCOff=ePreCOff,
                     eFdfwEns=eFdfwEns, eEnsEns=eEnsEns)

    # ---- testing ----------------------------------------------------------
    print("testing")
    # Fixed (not trained) off -> fdfw weights: a constant -100 everywhere.
    eOffFdfw = -1e2 * np.ones((N, N, 1))
    vals = np.linspace(-1, 1, nTest)
    att = np.linspace(-1, 1, nAttr)
    DATest = lambda t: 0
    cut = int(c / dt)
    for test in range(nTest):
        # The closures over `test` are consumed by go() within this same
        # iteration, so late binding is not an issue here.
        stimA = lambda t: vals[test] if t < c else 0
        stimB = lambda t: closestAttractor(vals[test], att)
        data = go(stimA=stimA, stimB=stimB, stimC=stepAtCutoff,
                  dPreA=dPreA, dPreC=dPreC, dPreD=dPreD, dFdfw=dFdfw,
                  dEns=dEns, dOff=dOff, ePreAFdfw=ePreAFdfw,
                  ePreCOff=ePreCOff, ePreDEns=ePreDEns, eFdfwEns=eFdfwEns,
                  eEnsEns=eEnsEns, eOffFdfw=eOffFdfw,
                  stage=9, c=c, DA=DATest, **sim)
        aFdfw = fNMDA.filt(data['fdfw'])
        aEns = fNMDA.filt(data['ens'])
        aOff = fGABA.filt(data['off'])
        xhatFdfw = np.dot(aFdfw, dFdfw) / Tff
        xhatEns = np.dot(aEns, dEns)
        xhatOff = np.dot(aOff, dOff)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA']))
        xEns = fNMDA.filt(fPre.filt(data['inptB']))
        # The error is measured only from the cutoff onwards.
        errorEns = rmse(xhatEns[cut:], xEns[cut:])

        fig, ax = plt.subplots()
        ax.plot(data['times'], xFdfw, alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], xEns, label="target")
        ax.plot(data['times'], xhatEns, alpha=0.5,
                label="ens, rmse=%.3f" % errorEns)
        ax.plot(data['times'], xhatOff, alpha=0.5, label="off")
        ax.axvline(c, label="cutoff")
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)),
               ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateAttractors_testFlat%s.pdf" % test)
def run(NPre=300, NBias=100, N=30, t=10, nTrain=10, nTest=5, nEnc=10, dt=0.001,
        fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3),
        fGABA=DoubleExp(0.5e-3, 1.5e-3), fS=DoubleExp(2e-2, 1e-1),
        DATrain=lambda t: 0, DATest=lambda t: 0, Tff=0.3, reg=1e-2,
        load=(), file=None):
    """Train the "integrateNMDAbias2" network stage by stage, then test it.

    Training is split into eight numbered stages (0-7).  For each stage, if its
    number is in ``load`` the stage's arrays are read from ``file``; otherwise
    they are computed by repeatedly calling ``go`` and the cumulative set of
    arrays is written to ``data/integrateNMDAbias2.npz``:

        0  readout decoders for the pre population       (dPre)
        1  encoders pre -> fdfw and pre -> bias          (ePreFdfw, ePreBias)
        2  readout decoders for fdfw and bias            (dFdfw, dBias)
        3  encoders bias -> ens                          (eBiasEns)
        4  encoders pre -> ens                           (ePreEns)
        5  encoders fdfw -> ens                          (eFdfwEns)
        6  readout decoders for ens                      (dEns)
        7  encoders ens -> ens                           (eEnsEns)

    Afterwards ``nTest`` test simulations (stage 8) are run and one PDF per
    test is saved as ``plots/integrateNMDAbias2_testWhite<i>.pdf``.

    Parameters:
        NPre, NBias, N: population sizes; forwarded to ``go`` and used to size
            the spike and encoder arrays.
        t: duration handed to ``makeSignal``/``makeCarrier``/``go``.
        nTrain: number of simulations per decoder stage (seeds 0..nTrain-1).
        nTest: number of test simulations (seeds 200..200+nTest-1).
        nEnc: number of passes per encoder stage (seeds 0..nEnc-1).
        dt: timestep used for simulation, filtering and array sizing.
        fPre, fNMDA, fGABA: filters used to build decoding targets and passed
            to ``decodeNoF``.
        fS: filter forwarded to ``plotActivity``.
        DATrain, DATest: callables passed to ``go`` as ``DA`` during training
            and testing respectively.
        Tff: gain applied to the fdfw decoding target and to the plotted input;
            also forwarded to ``go`` in stages 5 and 8.
        reg: regularisation value passed to ``decodeNoF``.
        load: iterable of stage numbers to load instead of recompute
            (``None`` is treated as empty).
        file: path of the ``.npz`` archive to load from; when ``None`` the
            archive written by this function is used.

    Returns:
        None.
    """
    savePath = "data/integrateNMDAbias2.npz"
    label = "integrateNMDAbias2"

    load = () if load is None else load
    if file is None:
        # Loading used to crash on np.load(None); fall back to our own archive.
        file = savePath

    def loadArrays(*keys):
        # Open the archive once per stage and close it again (NpzFile holds an
        # open file handle until closed).
        with np.load(file) as archive:
            return [archive[key] for key in keys]

    nSteps = int(t / dt)
    trainTimes = np.arange(0, t * nTrain, dt)
    # Keyword arguments shared by every call to go().
    sim = dict(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt)

    # ---- stage 0: readout decoders for pre -------------------------------
    if 0 in load:
        (dPre,) = loadArrays('dPre')
    else:
        print('readout decoders for [preA, preB, preC]')
        spikesInpt = np.zeros((nTrain, nSteps, NPre))
        targetsInpt = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, stimB=stimB, stage=0, **sim)
            spikesInpt[n] = data['preA']
            targetsInpt[n] = fPre.filt(data['inptA'], dt=dt)
        dPre, X, Y, error = decodeNoF(spikesInpt, targetsInpt, nTrain, fPre, dt=dt, reg=reg)
        np.savez(savePath, dPre=dPre)
        plotState(trainTimes, X, Y, error, label, "pre", t * nTrain)

    # ---- stage 1: encoders pre -> fdfw, pre -> bias ----------------------
    if 1 in load:
        ePreFdfw, ePreBias = loadArrays('ePreFdfw', 'ePreBias')
    else:
        print("encoders for [preA, preC] to [fdfw, bias]")
        ePreFdfw = np.zeros((NPre, N, 1))
        ePreBias = np.zeros((NPre, NBias, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, stimB=stimB, dPre=dPre,
                      ePreFdfw=ePreFdfw, ePreBias=ePreBias, stage=1, **sim)
            # Each pass starts from the encoders returned by the previous one.
            ePreFdfw = data['ePreFdfw']
            ePreBias = data['ePreBias']
            plotActivity(t, dt, fS, data['times'], data['fdfw'], data['tarFdfw'],
                         label, "pre_fdfw")
            plotActivity(t, dt, fS, data['times'], data['bias'], data['tarBias'],
                         label, "pre_bias")
            # Checkpoint after every pass so an interrupted run keeps its progress.
            np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias)

    # ---- stage 2: readout decoders for fdfw and bias ---------------------
    if 2 in load:
        dFdfw, dBias = loadArrays('dFdfw', 'dBias')
    else:
        print('readout decoders for fdfw and bias')
        spikesFdfw = np.zeros((nTrain, nSteps, N))
        spikesBias = np.zeros((nTrain, nSteps, NBias))
        targetsFdfw = np.zeros((nTrain, nSteps, 1))
        targetsBias = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, _ = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, dPre=dPre,
                      ePreFdfw=ePreFdfw, ePreBias=ePreBias, stage=2, **sim)
            spikesFdfw[n] = data['fdfw']
            spikesBias[n] = data['bias']
            # dt is passed to every filt() call so the targets are filtered at
            # the simulation timestep rather than the filter's default one.
            targetsFdfw[n] = fNMDA.filt(Tff * fPre.filt(data['inptA'], dt=dt), dt=dt)
            targetsBias[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt), dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(spikesFdfw, targetsFdfw, nTrain, fNMDA,
                                          dt=dt, reg=reg)
        dBias, X2, Y2, error2 = decodeNoF(spikesBias, targetsBias, nTrain, fGABA,
                                          dt=dt, reg=reg)
        np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias,
                 dFdfw=dFdfw, dBias=dBias)
        plotState(trainTimes, X1, Y1, error1, label, "fdfw", t * nTrain)
        plotState(trainTimes, X2, Y2, error2, label, "bias", t * nTrain)

    # ---- stage 3: encoders bias -> ens -----------------------------------
    if 3 in load:
        (eBiasEns,) = loadArrays('eBiasEns')
    else:
        print("encoders for [bias] to ens")
        eBiasEns = np.zeros((NBias, N, 1))
        for n in range(nEnc):
            # Only stimA is handed to go() in this stage.
            stimA, _ = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, dPre=dPre, dBias=dBias,
                      ePreBias=ePreBias, eBiasEns=eBiasEns, stage=3, **sim)
            eBiasEns = data['eBiasEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         label, "bias_ens")
            np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias,
                     dFdfw=dFdfw, dBias=dBias, eBiasEns=eBiasEns)

    # ---- stage 4: encoders pre -> ens ------------------------------------
    if 4 in load:
        (ePreEns,) = loadArrays('ePreEns')
    else:
        print("encoders for [preB] to [ens]")
        ePreEns = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, stimB=stimB, dPre=dPre, dBias=dBias,
                      ePreBias=ePreBias, eBiasEns=eBiasEns, ePreEns=ePreEns,
                      stage=4, **sim)
            ePreEns = data['ePreEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         label, "pre_ens")
            np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias,
                     dFdfw=dFdfw, dBias=dBias, eBiasEns=eBiasEns, ePreEns=ePreEns)

    # ---- stage 5: encoders fdfw -> ens -----------------------------------
    if 5 in load:
        (eFdfwEns,) = loadArrays('eFdfwEns')
    else:
        print("encoders for [fdfw] to ens")
        eFdfwEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            # From this stage on training stimuli come from makeCarrier rather
            # than makeSignal; only stimA is handed to go() here.
            stimA, _ = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, Tff=Tff, dPre=dPre, dFdfw=dFdfw,
                      dBias=dBias, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, stage=5, **sim)
            eFdfwEns = data['eFdfwEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         label, "fdfw_ens")
            np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                     ePreBias=ePreBias, dFdfw=dFdfw, dBias=dBias,
                     eFdfwEns=eFdfwEns, eBiasEns=eBiasEns)

    # ---- stage 6: readout decoders for ens -------------------------------
    if 6 in load:
        (dEns,) = loadArrays('dEns')
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, stimB=stimB, dPre=dPre, dFdfw=dFdfw,
                      dBias=dBias, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, stage=6, **sim)
            spikes[n] = data['ens']
            # Same double filtering (and same dt) as the test-stage target below.
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, dt=dt, reg=reg)
        np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                 ePreBias=ePreBias, dFdfw=dFdfw, dBias=dBias,
                 eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, dEns=dEns)
        plotState(trainTimes, X, Y, error, label, "ens", t * nTrain)

    # ---- stage 7: encoders ens -> ens ------------------------------------
    if 7 in load:
        (eEnsEns,) = loadArrays('eEnsEns')
    else:
        print("encoders from ens to ens")
        eEnsEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(DA=DATrain, stimA=stimA, stimB=stimB, dPre=dPre, dFdfw=dFdfw,
                      dBias=dBias, dEns=dEns, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, eEnsEns=eEnsEns,
                      stage=7, **sim)
            eEnsEns = data['eEnsEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         label, "ens_ens")
            np.savez(savePath, dPre=dPre, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                     ePreBias=ePreBias, dFdfw=dFdfw, dBias=dBias,
                     eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, dEns=dEns,
                     eEnsEns=eEnsEns)

    # ---- stage 8: testing ------------------------------------------------
    print("testing")
    for test in range(nTest):
        # Seeds are offset by 200 so test stimuli differ from the training seeds.
        stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=200 + test)
        data = go(DA=DATest, stimA=stimA, stimB=stimB, Tff=Tff, dPre=dPre,
                  dFdfw=dFdfw, dBias=dBias, dEns=dEns, ePreFdfw=ePreFdfw,
                  ePreEns=ePreEns, eFdfwEns=eFdfwEns, eBiasEns=eBiasEns,
                  eEnsEns=eEnsEns, stage=8, **sim)

        # Decode state estimates from the filtered activities.
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)
        xhatEns = np.dot(aBio, dEns)

        # Reference signals: the inputs passed through fPre and then fNMDA.
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xEns = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        errorBio = rmse(xhatEns, xEns)

        fig, ax = plt.subplots()
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], Tff * xFdfw, alpha=0.5, label="input")
        ax.plot(data['times'], xEns, label="target")
        ax.plot(data['times'], xhatEns, alpha=0.5, label="ens, rmse=%.3f" % errorBio)
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDAbias2_testWhite%s.pdf" % test)
        # Release figures after each test so they do not accumulate in memory.
        plt.close('all')