def run(NPre=300, N=100, t=20, tTrans=2, nTrain=1, nEnc=10, nTest=10, neuron_type=LIF(),
        dt=0.001, f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1), freq=1, muFreq=1.0,
        sigmaFreq=0.1, reg=1e-1, tauRiseMax=5e-2, tDrive=0.2, base=False, load=False, file=None):
    """Train (or load) the 'oscillateNew' network stage by stage, then test it.

    Stages, each read from ``file`` when ``load`` is truthy and trained otherwise:

    1. ``d1``/``f1``: readout decoders and fitted synapse for ``pre``
       (always simulated with ``LIF()`` at dt=0.001).
    2. ``e1``: pre-to-ens encoders, trained only when ``neuron_type`` is a
       ``Bio`` instance; all zeros otherwise.
    3. ``d2``/``f2``: readout decoders and fitted synapse for ``ens``; the first
       ``tTrans`` seconds of every training run are discarded.
    4. ``e2``: encoders trained with ``l2=True`` (see the NOTE at that stage).

    Testing runs ``nTest`` simulations of length ``t + tTrans`` whose phases are
    ``2*pi*test/nTest``. Each dimension of the decoded 2-D state is fitted with
    ``fitSinusoid``. The test error is the mean over both dimensions of the RMSE
    between the decode and its fitted sinusoid, taken after ``tTrans``.

    Parameters (selected):
        NPre, N: population sizes forwarded to ``go``.
        t: length (s) of training runs.
        tTrans: initial transient (s) dropped from ens decoding and test RMSE.
        nTrain, nEnc, nTest: number of decoder, encoder and test runs.
        freq: frequency forwarded to ``go`` and used as the ``fitSinusoid`` guess.
        muFreq, sigmaFreq, base: forwarded to ``fitSinusoid``.
        reg, tauRiseMax: forwarded to ``decode``.
        tDrive: forwarded to ``go`` for the test runs only.
        load: if truthy, read each stage from ``file`` instead of training it.
            The e2 stage additionally keeps training after loading.
        file: path of a previously saved ``.npz`` archive.

    Returns:
        ndarray of shape (nTest,) with the per-test error described above.

    Side effects: repeatedly writes ``data/oscillateNew_<neuron_type>.npz``,
    saves ``plots/oscillateNew_<neuron_type>_test<k>.pdf`` for each test, and
    calls the ``plotState``/``plotActivity`` helpers.

    NOTE(review): reconstructed from a whitespace-collapsed source. Whether the
    ``np.savez``/``plotActivity`` calls in the encoder loops ran on every
    iteration or once after the loop could not be recovered -- confirm.
    """
    print('\nNeuron Type: %s'%neuron_type)
    rng = np.random.RandomState(seed=0)  # NOTE(review): rng is never used below
    # Stage 1: readout decoders (plus fitted synapse f1) for the pre population.
    if load:
        d1 = np.load(file)['d1']
        tauRise1 = np.load(file)['tauRise1']
        tauFall1 = np.load(file)['tauFall1']
        f1 = DoubleExp(tauRise1, tauFall1)
    else:
        print('readout decoders for pre')
        spikes = np.zeros((nTrain, int(t/0.001), NPre))
        targets = np.zeros((nTrain, int(t/0.001), 2))
        for n in range(nTrain):
            data = go(NPre=NPre, N=N, t=t, dt=0.001, f=f, fS=fS, neuron_type=LIF(), freq=freq, phase=2*np.pi*(n/nTrain))
            spikes[n] = data['pre']
            targets[n] = f.filt(data['inpt'], dt=0.001)
        d1, f1, tauRise1, tauFall1, X, Y, error = decode(spikes, targets, nTrain, dt=0.001, tauRiseMax=tauRiseMax, name="oscillateNew")
        np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1)
        times = np.arange(0, t*nTrain, 0.001)
        plotState(times, X, Y, error, "oscillateNew", "%s_pre"%neuron_type, t*nTrain)

    # Stage 2: pre-to-ens encoders; only trained for Bio neurons.
    if load:
        e1 = np.load(file)['e1']
    elif isinstance(neuron_type, Bio):
        print("ens1 encoders")
        e1 = np.zeros((NPre, N, 2))
        for n in range(nEnc):
            data = go(d1=d1, e1=e1, f1=f1, NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(n/nEnc), l1=True)
            e1 = data['e1']  # encoders returned by one run seed the next run
            np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1)
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "oscillateNew", "ens")
    else:
        e1 = np.zeros((NPre, N, 2))

    # Stage 3: readout decoders (plus fitted synapse f2) for ens.
    if load:
        d2 = np.load(file)['d2']
        tauRise2 = np.load(file)['tauRise2']
        tauFall2 = np.load(file)['tauFall2']
        f2 = DoubleExp(tauRise2, tauFall2)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int((t-tTrans)/dt), N))
        targets = np.zeros((nTrain, int((t-tTrans)/dt), 2))
        for n in range(nTrain):
            data = go(d1=d1, e1=e1, f1=f1, NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(n/nTrain))
            # Drop the first tTrans seconds; the target is data['tar'] filtered twice by f.
            spikes[n] = data['ens'][int(tTrans/dt):]
            targets[n] = f.filt(f.filt(data['tar'], dt=dt), dt=dt)[int(tTrans/dt):]
        d2, f2, tauRise2, tauFall2, X, Y, error = decode(spikes, targets, nTrain, dt=dt, name="oscillateNew", reg=reg, tauRiseMax=tauRiseMax)
        np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2)
        times = np.arange(0, t*nTrain, dt)[:len(X)]
        plotState(times, X, Y, error, "oscillateNew", "%s_ens"%neuron_type, (t-tTrans)*nTrain)

    # Stage 4: encoders trained with l2=True.
    # NOTE(review): the `elif` and the zero-initialisation below are commented
    # out. As reconstructed here, the training loop therefore belongs to the
    # `if load:` branch: e2 is loaded from `file` and then trained further.
    # Without `load`, e2 stays all zeros. Also, unlike the test call, f2 is
    # not passed to go() here -- confirm both are intended.
    if load:
        e2 = np.load(file)['e2']
    #elif isinstance(neuron_type, Bio):
        print("ens2 encoders")
        #e2 = np.zeros((N, N, 2))
        for n in range(nEnc):
            data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(n/nEnc), l2=True)
            e2 = data['e2']
            np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)
            plotActivity(t, dt, fS, data['times'], data['ens2'], data['tarEns2'], "oscillateNew", "ens2")
    else:
        e2 = np.zeros((N, N, 2))
    # Persist every stage (including e2) before testing.
    np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)

    print("testing")
    errors = np.zeros((nTest))
    for test in range(nTest):
        data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, f2=f2, NPre=NPre, N=N, t=t+tTrans, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(test/nTest), tDrive=tDrive, test=True)
        # curve fit to a sinusoid of arbitrary frequency, phase, magnitude
        times = data['times']
        A = f2.filt(data['ens'], dt=dt)
        X = np.dot(A, d2)
        freq0, phase0, mag0, base0 = fitSinusoid(times, X[:,0], freq, int(tTrans/dt), muFreq=muFreq, sigmaFreq=sigmaFreq, base=base)
        freq1, phase1, mag1, base1 = fitSinusoid(times, X[:,1], freq, int(tTrans/dt), muFreq=muFreq, sigmaFreq=sigmaFreq, base=base)
        # The plotted "target" is the fitted sinusoid, not a ground-truth signal.
        s0 = base0+mag0*np.sin(times*2*np.pi*freq0+phase0)
        s1 = base1+mag1*np.sin(times*2*np.pi*freq1+phase1)
        # freqError0 = np.abs(freq-freq0)
        # freqError1 = np.abs(freq-freq1)
        rmseError0 = rmse(X[int(tTrans/dt):,0], s0[int(tTrans/dt):])
        rmseError1 = rmse(X[int(tTrans/dt):,1], s1[int(tTrans/dt):])
        # error0 = (1+freqError0) * rmseError0
        # error1 = (1+freqError1) * rmseError1
        error0 = rmseError0
        error1 = rmseError1
        errors[test] = (error0 + error1)/2
        fig, (ax, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
        ax.plot(times, X[:,0], label="estimate (dim 0)")
        ax.plot(times, s0, label="target (dim 0)")
        ax.set(ylabel='state', title='freq=%.3f, rmse=%.3f'%(freq0, error0), xlim=((0, t)), ylim=((-1.2, 1.2)))
        ax.legend(loc='upper left')
        ax2.plot(times, X[:,1], label="estimate (dim 1)")
        ax2.plot(times, s1, label="target (dim 1)")
        ax2.set(xlabel='time', ylabel='state', title='freq=%.3f, rmse=%.3f'%(freq1, error1), xlim=((0, t)), ylim=((-1.2, 1.2)))
        ax2.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/oscillateNew_%s_test%s.pdf"%(neuron_type, test))
        plt.close('all')
    print('%s errors:'%neuron_type, errors)
    np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2, errors=errors)
    return errors
def run(NPre=100, N=30, t=10, nTrain=5, nEnc=5, nTest=10, dt=0.001, neuron_type=LIF(),
        fPre=DoubleExp(1e-3, 1e-1), fS=DoubleExp(2e-2, 2e-1), Tff=0.3, reg=1e-1,
        tauRiseMax=1e-1, tauFallMax=3e-1, load=(), file="data/integrate"):
    """Train (or load) the 'integrate' network stage by stage, then test it.

    Stage ``k`` is read from the archive when ``k in load`` and trained otherwise:

    * 0: ``dPreA``/``dPreB``, readout decoders for preInptA/preInptB (always at
      dt=0.001; the preInptA target is pre-scaled by ``Tff``).
    * 1: ``ePreB``, preInptB-to-ens encoders (Bio neurons only; zeros otherwise).
    * 2: ``ePreA``, preInptA-to-ens encoders (Bio neurons only; zeros otherwise).
    * 3: ``dEns`` and the fitted synapse ``fEns``, readout decoders for ens whose
      target is ``inptB`` filtered by ``fPre``.
    * 4: ``eBio``, recurrent encoders (Bio neurons only; zeros otherwise).

    Testing runs ``nTest`` stage-5 simulations (``makeSignal`` seeds 200, 201, ...).
    Each one measures the RMSE between the decoded ens state and
    ``fPre``-filtered ``inptB``, and is plotted to
    ``plots/integrate_<neuron_type>_test<k>.pdf``.

    Args (selected):
        load: container of stage indices to read instead of training. Any
            object supporting ``in`` works (list, tuple, set). By default
            every stage is trained.
        file: archive path prefix; ``f"{neuron_type}.npz"`` is appended and the
            resulting path is used for every load and save.

    Returns:
        ``(times, X, Y)`` from the *last* test run: the time axis, the decoded
        ens state and its target. Per-test errors are computed into
        ``errors`` but are neither returned nor saved.

    NOTE(review): every ``np.savez`` rewrites the archive with only the arrays
    known at that point. Retraining an early stage while loading a later one
    (e.g. ``load=[0, 1, 3]`` with Bio neurons) drops the later arrays, and the
    subsequent load raises ``KeyError``. Also, with ``nTest=0`` the final
    ``return`` fails because ``data`` was never bound. The function was
    reconstructed from a whitespace-collapsed source; confirm that the encoder
    loops save and plot on every iteration.
    """
    print('\nNeuron Type: %s' % neuron_type)
    file = file + f"{neuron_type}.npz"

    # Stage 0: readout decoders for preInptA and preInptB.
    if 0 in load:
        # Open the archive once and close it deterministically.
        with np.load(file) as saved:
            dPreA = saved['dPreA']  # fix indexing of decoders
            dPreB = saved['dPreB']
    else:
        print('readout decoders for preInptA and preInptB')
        spikesInptA = np.zeros((nTrain, int(t / 0.001), NPre))
        spikesInptB = np.zeros((nTrain, int(t / 0.001), NPre))
        targetsInptA = np.zeros((nTrain, int(t / 0.001), 1))
        targetsInptB = np.zeros((nTrain, int(t / 0.001), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, dt=0.001, seed=n)
            data = go(NPre=NPre, N=N, t=t, dt=0.001, neuron_type=neuron_type, fPre=fPre, fS=fS, stimA=stimA, stimB=stimB)
            spikesInptA[n] = data['preInptA']
            spikesInptB[n] = data['preInptB']
            # preInptA decoders are trained to output the input already scaled by Tff.
            targetsInptA[n] = fPre.filt(Tff * data['inptA'], dt=0.001)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=0.001)
        dPreA, X1a, Y1a, error1a = decodeNoF(spikesInptA, targetsInptA, nTrain, fPre, dt=0.001, reg=reg)
        dPreB, X1b, Y1b, error1b = decodeNoF(spikesInptB, targetsInptB, nTrain, fPre, dt=0.001, reg=reg)
        np.savez(file, dPreA=dPreA, dPreB=dPreB)
        times = np.arange(0, t * nTrain, 0.001)
        plotState(times, X1a, Y1a, error1a, "integrate", "%s_preInptA" % neuron_type, t * nTrain)
        plotState(times, X1b, Y1b, error1b, "integrate", "%s_preInptB" % neuron_type, t * nTrain)

    # Stage 1: preInptB-to-ens encoders (only trained for Bio neurons).
    if 1 in load:
        with np.load(file) as saved:
            ePreB = saved['ePreB']
    elif isinstance(neuron_type, Bio):
        print("encoders for preInptB-to-ens")
        ePreB = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, ePreB=ePreB, NPre=NPre, N=N, t=t, dt=dt, neuron_type=neuron_type, fPre=fPre, fS=fS, stimA=stimA, stimB=stimB, stage=1)
            ePreB = data['ePreB']  # encoders from one run seed the next run
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrate", "preIntgToEns")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreB=ePreB)
    else:
        ePreB = np.zeros((NPre, N, 1))

    # Stage 2: preInptA-to-ens encoders (only trained for Bio neurons).
    if 2 in load:
        with np.load(file) as saved:
            ePreA = saved['ePreA']
    elif isinstance(neuron_type, Bio):
        print("encoders for preInptA-to-ens")
        ePreA = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, fPre=fPre, NPre=NPre, N=N, t=t, dt=dt, neuron_type=neuron_type, fS=fS, stimA=stimA, stimB=stimB, stage=2)
            ePreA = data['ePreA']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrate", "preInptToEns")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB)
    else:
        ePreA = np.zeros((NPre, N, 1))

    # Stage 3: readout decoders (plus fitted synapse fEns) for ens.
    if 3 in load:
        with np.load(file) as saved:
            dEns = saved['dEns']
            tauRiseEns = saved['tauRiseEns']
            tauFallEns = saved['tauFallEns']
        fEns = DoubleExp(tauRiseEns, tauFallEns)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, NPre=NPre, N=N, t=t, dt=dt, neuron_type=neuron_type, fPre=fPre, fS=fS, stimA=stimA, stimB=stimB, stage=3)
            spikes[n] = data['ens']
            targets[n] = fPre.filt(data['inptB'], dt=dt)
        dEns, fEns, tauRiseEns, tauFallEns, X2, Y2, error2 = decode(
            spikes, targets, nTrain, dt=dt, reg=reg, tauRiseMax=tauRiseMax, tauFallMax=tauFallMax, name="integrate")
        np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, dEns=dEns, tauRiseEns=tauRiseEns, tauFallEns=tauFallEns)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X2, Y2, error2, "integrate", "%s_ens" % neuron_type, t * nTrain)

    # Stage 4: recurrent encoders (only trained for Bio neurons).
    if 4 in load:
        with np.load(file) as saved:
            eBio = saved['eBio']
    elif isinstance(neuron_type, Bio):
        print("encoders from ens2 to ens")
        eBio = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA, dPreB=dPreB, dEns=dEns, ePreA=ePreA, ePreB=ePreB, eBio=eBio, NPre=NPre, N=N, t=t, dt=dt, neuron_type=neuron_type, fPre=fPre, fEns=fEns, fS=fS, stimA=stimA, stimB=stimB, stage=4)
            eBio = data['eBio']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrate", "Ens2Ens")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, dEns=dEns, tauRiseEns=tauRiseEns, tauFallEns=tauFallEns, eBio=eBio)
    else:
        eBio = np.zeros((N, N, 1))

    print("testing")
    errors = np.zeros((nTest))
    for test in range(nTest):
        # Test seeds start at 200 so they never coincide with training seeds.
        stimA, stimB = makeSignal(t, fPre, dt=dt, seed=200 + test)
        data = go(dPreA=dPreA, dEns=dEns, ePreA=ePreA, eBio=eBio, NPre=NPre, N=N, t=t, dt=dt, neuron_type=neuron_type, fPre=fPre, fEns=fEns, fS=fS, stimA=stimA, stimB=stimB, stage=5)
        A = fEns.filt(data['ens'], dt=dt)
        X = np.dot(A, dEns)
        Y = fPre.filt(data['inptB'], dt=dt)
        error = rmse(X, Y)
        errors[test] = error
        fig, ax = plt.subplots()
        ax.plot(data['times'], X, label="estimate")
        ax.plot(data['times'], Y, label="target")
        ax.set(xlabel='time', ylabel='state', title='rmse=%.3f' % error, xlim=((0, t)), ylim=((-1, 1)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrate_%s_test%s.pdf" % (neuron_type, test))
        plt.close('all')
    # print('errors:', errors)
    # np.savez(file, errors=errors, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, dEns=dEns, tauRiseEns=tauRiseEns, tauFallEns=tauFallEns, eBio=eBio)
    return data['times'], X, Y
def run(NPre=200, N=100, N2=100, t=10, nTrain=10, nEnc=20, nTest=10, neuron_type=LIF(),
        dt=0.001, f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1), tauRiseMax=1e-2,
        load=False, file=None):
    """Train (or load) the 'multiplyNew' network stage by stage, then test it.

    Stages, each read from ``file`` when ``load`` is truthy and trained otherwise:

    1. ``d1``/``f1``: readout decoders and fitted synapse for ``pre`` (always
       simulated with ``LIF()`` at dt=0.001); the 2-D target is ``inpt``
       filtered by ``f``.
    2. ``e1``: pre-to-ens encoders (Bio neurons only; zeros otherwise).
    3. ``d2``/``f2``: readout decoders for ``ens``; the target is the product of
       the two columns of ``data['tar']``, filtered by ``f``.
    4. ``e2``: encoders of shape (N, N2, 1), presumably for the ens-to-ens2
       connection (Bio neurons only; zeros otherwise).
    5. ``d3``/``f3``: readout decoders for ``ens2``; the target is
       ``data['tar2']`` filtered by ``f``.

    Testing runs ``nTest`` simulations (``makeSignal`` seeds 100, 101, ...). For
    each, the RMSE between the ens2 decode and ``f``-filtered ``tar2`` is
    recorded. Both that decode and the ens (product) decode are plotted via
    ``plotState``.

    Returns:
        ndarray of shape (nTest,) with the per-test ens2 RMSE.

    Side effects: writes ``data/multiplyNew_<neuron_type>.npz`` after each
    trained stage and once more, with ``errors``, after testing.

    NOTE(review): reconstructed from a whitespace-collapsed source. Whether the
    ``np.savez``/``plotActivity`` calls in the encoder loops ran on every
    iteration or once after the loop could not be recovered -- confirm.
    Unlike the d1/d2 fits, the d3 ``decode`` call receives no ``tauRiseMax``;
    confirm this is intended.
    """
    print('\nNeuron Type: %s' % neuron_type)
    # Stage 1: readout decoders (plus fitted synapse f1) for pre.
    if load:
        d1 = np.load(file)['d1']
        tauRise1 = np.load(file)['tauRise1']
        tauFall1 = np.load(file)['tauFall1']
        f1 = DoubleExp(tauRise1, tauFall1)
    else:
        print('readout decoders for pre')
        spikes = np.zeros((nTrain, int(t / 0.001), NPre))
        targets = np.zeros((nTrain, int(t / 0.001), 2))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=0.001, seed=n)
            data = go(NPre=NPre, N=N, N2=N2, t=t, dt=0.001, f=f, fS=fS, neuron_type=LIF(), stim=stim)
            spikes[n] = data['pre']
            targets[n] = f.filt(data['inpt'], dt=0.001)
        d1, f1, tauRise1, tauFall1, X, Y, error = decode(spikes, targets, nTrain, dt=0.001, tauRiseMax=tauRiseMax, name="multiplyNew")
        np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1)
        times = np.arange(0, t * nTrain, 0.001)
        plotState(times, X, Y, error, "multiplyNew", "%s_pre" % neuron_type, t * nTrain)

    # Stage 2: pre-to-ens encoders (only trained for Bio neurons).
    if load:
        e1 = np.load(file)['e1']
    elif isinstance(neuron_type, Bio):
        print("ens1 encoders")
        e1 = np.zeros((NPre, N, 2))
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1, e1=e1, f1=f1, NPre=NPre, N=N, N2=N2, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, stim=stim, l1=True)
            e1 = data['e1']  # encoders returned by one run seed the next run
            np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1)
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "multiplyNew", "ens")
    else:
        e1 = np.zeros((NPre, N, 2))

    # Stage 3: readout decoders for ens, targeting the product of the two inputs.
    if load:
        d2 = np.load(file)['d2']
        tauRise2 = np.load(file)['tauRise2']
        tauFall2 = np.load(file)['tauFall2']
        f2 = DoubleExp(tauRise2, tauFall2)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1, e1=e1, f1=f1, NPre=NPre, N=N, N2=N2, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, stim=stim)
            spikes[n] = data['ens']
            targets[n] = f.filt(data['tar'][:, 0] * data['tar'][:, 1], dt=dt).reshape(-1, 1)
        d2, f2, tauRise2, tauFall2, X, Y, error = decode(spikes, targets, nTrain, dt=dt, tauRiseMax=tauRiseMax, name="multiplyNew")
        np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "multiplyNew", "%s_ens" % neuron_type, t * nTrain)

    # Stage 4: e2 encoders (only trained for Bio neurons).
    if load:
        e2 = np.load(file)['e2']
    elif isinstance(neuron_type, Bio):
        print("ens2 encoders")
        e2 = np.zeros((N, N2, 1))
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, f2=f2, NPre=NPre, N=N, N2=N2, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, stim=stim, l2=True)
            e2 = data['e2']
            np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)
            plotActivity(t, dt, fS, data['times'], data['ens2'], data['tarEns2'], "multiplyNew", "ens2")
    else:
        e2 = np.zeros((N, N2, 1))

    # Stage 5: readout decoders (plus fitted synapse f3) for ens2.
    if load:
        d3 = np.load(file)['d3']
        tauRise3 = np.load(file)['tauRise3']
        tauFall3 = np.load(file)['tauFall3']
        f3 = DoubleExp(tauRise3, tauFall3)
    else:
        print('readout decoders for ens2')
        spikes = np.zeros((nTrain, int(t / dt), N2))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, f2=f2, NPre=NPre, N=N, N2=N2, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, stim=stim)
            spikes[n] = data['ens2']
            targets[n] = f.filt(data['tar2'], dt=dt)
        d3, f3, tauRise3, tauFall3, X, Y, error = decode(spikes, targets, nTrain, dt=dt, name="multiplyNew")
        np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2, d3=d3, tauRise3=tauRise3, tauFall3=tauFall3)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "multiplyNew", "%s_ens2" % neuron_type, t * nTrain)

    errors = np.zeros((nTest))
    print("testing")
    for test in range(nTest):
        # Test seeds start at 100 so they differ from the training seeds 0..n-1.
        stim = makeSignal(t, f, dt=dt, seed=100 + test)
        data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, f2=f2, NPre=NPre, N=N, N2=N2, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, stim=stim)
        # Scored output: the ens2 decode.
        A = f3.filt(data['ens2'], dt=dt)
        X = np.dot(A, d3)
        Y = f.filt(data['tar2'], dt=dt)
        error = rmse(X, Y)
        errors[test] = error
        plotState(data['times'], X, Y, error, "multiplyNew", "%s_test%s" % (neuron_type, test), t)
        # Diagnostic only: the intermediate ens (product) decode, not stored in errors.
        A = f2.filt(data['ens'], dt=dt)
        X = np.dot(A, d2)
        Y = f.filt(data['tar'][:, 0] * data['tar'][:, 1], dt=dt).reshape(-1, 1)
        plotState(data['times'], X, Y, rmse(X, Y), "multiplyNew", "%s_pretest%s" % (neuron_type, test), t)
    print('%s errors:' % neuron_type, errors)
    np.savez("data/multiplyNew_%s.npz" % neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2, d3=d3, tauRise3=tauRise3, tauFall3=tauFall3, errors=errors)
    return errors
def run(NPre=100, N=100, t=10, c=None, nTrain=10, nTest=5, dt=0.001, neuron_type=LIF(),
        fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3), reg=1e-2, load=(),
        file=None, neg=True):
    """Train (or load) the 'diffMemory' network, then plot a series of test runs.

    Stage ``k`` is read from ``file`` when ``k in load`` and trained otherwise:

    * 0: ``dPre``, readout decoders for pre (target: ``inpt`` filtered by ``fPre``).
    * 1: ``dDiff``, readout decoders for diff (target: ``inpt`` filtered by
      ``fPre`` then ``fNMDA``).
    * 2: ``dEns`` (readout decoders for ens, trained with ``train=True`` and
      ``Tff=0.3``; target: ``intg`` filtered by ``fPre`` then ``fNMDA``) and
      ``dNeg = -dEns``.

    Testing:

    * If ``neg``: ``nTest`` runs with a constant input ``vals[test]`` (evenly
      spaced in [-1, 1]) until the cutoff ``c``, then 0.
    * Otherwise: ``makeSignal``-driven runs with seeds ``0..nTest-1``.

    Each run reports the RMSE between the ens decode and ``intg`` filtered by
    ``fPre`` then twice by ``fNMDA``, and is saved to
    ``plots/diffMemory_<neuron_type>_test<k>.pdf``.

    Args (selected):
        c: input cutoff time; a falsy value means ``t``.
        load: container of stage indices to read instead of training (default:
            none, i.e. train everything).
        file: archive read for loaded stages. Saves always go to
            ``data/diffMemory_<neuron_type>.npz``.
        neg: choose the step-input test (True) or the makeSignal test (False).

    Returns:
        None.

    NOTE(review): ``dt`` sizes the arrays and filters here but is not forwarded
    to ``go``/``makeSignal``/``decodeNoF``; confirm those use the same step.
    """
    if not c:
        c = t

    # Stage 0: readout decoders for pre.
    if 0 in load:
        # Open the archive once and close it deterministically.
        with np.load(file) as saved:
            dPre = saved['dPre']
    else:
        print('readout decoders for pre')
        spikesInpt = np.zeros((nTrain, int(t / dt), NPre))
        targetsInpt = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA, neuron_type=neuron_type, stim=stim)
            spikesInpt[n] = data['pre']
            targetsInpt[n] = fPre.filt(data['inpt'], dt=dt)
        dPre, X, Y, error = decodeNoF(spikesInpt, targetsInpt, nTrain, fPre, reg=reg)
        np.savez("data/diffMemory_%s.npz" % neuron_type, dPre=dPre)
        # Use dt (as the later stages do) so the time axis matches the arrays above.
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "diffMemory_%s" % neuron_type, "pre", t * nTrain)

    # Stage 1: readout decoders for diff.
    if 1 in load:
        with np.load(file) as saved:
            dDiff = saved['dDiff']
    else:
        print('readout decoders for diff')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA, neuron_type=neuron_type, stim=stim, dPre=dPre)
            spikes[n] = data['diff']
            targets[n] = fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt)
        dDiff, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, reg=reg)
        np.savez("data/diffMemory_%s.npz" % neuron_type, dPre=dPre, dDiff=dDiff)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "diffMemory_%s" % neuron_type, "diff", t * nTrain)

    # Stage 2: readout decoders for ens, plus their negation dNeg.
    if 2 in load:
        with np.load(file) as saved:
            dEns = saved['dEns']
            dNeg = saved['dNeg']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            # data = go(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA, neuron_type=neuron_type, stim=stim, dPre=dPre, dDiff=dDiff)
            data = go(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA, neuron_type=neuron_type, stim=stim, dPre=dPre, dDiff=dDiff, train=True, Tff=0.3)
            spikes[n] = data['ens']
            # targets[n] = fNMDA.filt(fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt), dt=dt)
            targets[n] = fNMDA.filt(fPre.filt(data['intg'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, reg=reg)
        dNeg = -np.array(dEns)
        np.savez("data/diffMemory_%s.npz" % neuron_type, dPre=dPre, dDiff=dDiff, dEns=dEns, dNeg=dNeg)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "diffMemory_%s" % neuron_type, "ens", t * nTrain)

    print('testing')
    vals = np.linspace(-1, 1, nTest)
    for test in range(nTest):
        if neg:
            # Hold a constant value until the cutoff c, then drop the input to 0.
            stim = lambda t: vals[test] if t < c else 0
            data = go(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA, neuron_type=neuron_type, stim=stim, dPre=dPre, dDiff=dDiff, dEns=dEns, dNeg=dNeg, c=c)
        else:
            stim = makeSignal(t, fPre, fNMDA, seed=test)
            data = go(NPre=NPre, N=N, t=t, fPre=fPre, fNMDA=fNMDA, neuron_type=neuron_type, stim=stim, dPre=dPre, dDiff=dDiff, dEns=dEns, dNeg=None, c=None, Tff=0.3)
        aDiff = fNMDA.filt(fPre.filt(data['diff'], dt=dt), dt=dt)
        aEns = fNMDA.filt(fNMDA.filt(fPre.filt(data['ens'], dt=dt), dt=dt), dt=dt)
        xhatDiff = np.dot(aDiff, dDiff)  # only used by the commented-out "diff" plot below
        xhatEns = np.dot(aEns, dEns)
        u2 = fNMDA.filt(fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt), dt=dt)
        x = fNMDA.filt(fNMDA.filt(fPre.filt(data['intg'], dt=dt), dt=dt), dt=dt)
        error = rmse(xhatEns, x)
        fig, ax = plt.subplots()
        if neg:
            ax.plot(data['times'], u2, alpha=0.5, label="input (delayed)")
            ax.axvline(c, label="cutoff")
        else:
            # ax.plot(data['times'], 0.3*u2, alpha=0.5, label="input")
            ax.plot(data['times'], x, alpha=0.5, label="integral")
        # ax.plot(data['times'], xhatDiff, label="diff")
        ax.plot(data['times'], xhatEns, label="ens")
        ax.set(xlabel='time', ylabel='state', title="rmse=%.3f" % error, xlim=((0, t)), ylim=((-1, 1)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/diffMemory_%s_test%s.pdf" % (neuron_type, test))
        # Release the figure; otherwise one figure per test accumulates in pyplot.
        plt.close('all')
def run(NPre=100, N=100, t=10, nTrain=10, nTest=3, nAttr=5, nEnc=10, dt=0.001, c=None,
        fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3),
        fGABA=DoubleExp(0.5e-3, 1.5e-3), fS=DoubleExp(2e-2, 1e-1), Tff=0.3, reg=1e-2,
        load=(), file=None):
    """Train (or load) the 'integrateAttractors' network, then plot test runs.

    Stage ``k`` is read from ``file`` when ``k in load`` and trained otherwise:

    * 0: ``dPreA``..``dPreD``, readout decoders for preA..preD. The same key
      also controls ``ePreDEns``, the preD-to-ens encoders.
    * 1: ``ePreAFdfw``, ``ePreBEns``, ``ePreCOff`` encoders.
    * 2: ``dFdfw`` (target: ``Tff`` * fPre-filtered ``inptA``, then fNMDA) and
      ``dOff`` (target: ``inptC`` filtered by fPre then fGABA).
    * 3: ``eFdfwEns`` encoders.
    * 4: ``dEns`` (target: ``inptB`` filtered by fPre then fNMDA).
    * 5: ``eEnsEns`` encoders.

    Testing runs ``nTest`` stage-9 simulations:

    * ``stimA`` holds ``vals[test]`` (evenly spaced in [-1, 1]) until the cutoff ``c``.
    * ``stimB`` is ``closestAttractor(vals[test], att)``, where ``att`` holds
      ``nAttr`` evenly spaced values in [-1, 1].
    * ``stimC`` steps from 0 to 1 at ``c``.

    Each run is plotted to ``plots/integrateAttractors_testFlat<k>.pdf``,
    including the ens RMSE measured after ``c``.

    Args (selected):
        c: cutoff time; a falsy value means ``t``.
        Tff: scale of the feedforward target. The test divides the fdfw decode
            by it.
        load: container of stage indices to read instead of training (default:
            none, i.e. train everything).
        file: archive read for loaded stages. Saves always go to
            ``data/integrateAttractors.npz``.

    Returns:
        None.

    NOTE(review): reconstructed from a whitespace-collapsed source. Whether the
    ``np.savez``/``plotActivity`` calls in the encoder loops ran on every
    iteration or once after the loop could not be recovered -- confirm.
    """
    if not c:
        c = t

    # Stage 0a: readout decoders for the four pre populations.
    if 0 in load:
        # Open the archive once and close it deterministically.
        with np.load(file) as saved:
            dPreA = saved['dPreA']
            dPreB = saved['dPreB']
            dPreC = saved['dPreC']
            dPreD = saved['dPreD']
    else:
        print('readout decoders for pre')
        spikesInptA = np.zeros((nTrain, int(t / dt), NPre))
        spikesInptB = np.zeros((nTrain, int(t / dt), NPre))
        spikesInptC = np.zeros((nTrain, int(t / dt), NPre))
        spikesInptD = np.zeros((nTrain, int(t / dt), NPre))
        targetsInptA = np.zeros((nTrain, int(t / dt), 1))
        targetsInptB = np.zeros((nTrain, int(t / dt), 1))
        targetsInptC = np.zeros((nTrain, int(t / dt), 1))
        targetsInptD = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, _ = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimB = stimA
            stimC = lambda t: 0 if t < c else 1
            stimD = lambda t: 1e-1
            data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, stimC=stimC, stimD=stimD, stage=None)
            spikesInptA[n] = data['preA']
            spikesInptB[n] = data['preB']
            spikesInptC[n] = data['preC']
            spikesInptD[n] = data['preD']
            targetsInptA[n] = fPre.filt(data['inptA'], dt=dt)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dt)
            targetsInptC[n] = fPre.filt(data['inptC'], dt=dt)
            targetsInptD[n] = fPre.filt(data['inptD'], dt=dt)
        dPreA, XA, YA, errorA = decodeNoF(spikesInptA, targetsInptA, nTrain, fPre, dt=dt, reg=reg)
        dPreB, XB, YB, errorB = decodeNoF(spikesInptB, targetsInptB, nTrain, fPre, dt=dt, reg=reg)
        dPreC, XC, YC, errorC = decodeNoF(spikesInptC, targetsInptC, nTrain, fPre, dt=dt, reg=reg)
        dPreD, XD, YD, errorD = decodeNoF(spikesInptD, targetsInptD, nTrain, fPre, dt=dt, reg=reg)
        np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, XA, YA, errorA, "integrateAttractors", "preA", t * nTrain)
        plotState(times, XB, YB, errorB, "integrateAttractors", "preB", t * nTrain)
        plotState(times, XC, YC, errorC, "integrateAttractors", "preC", t * nTrain)
        plotState(times, XD, YD, errorD, "integrateAttractors", "preD", t * nTrain)

    # Stage 0b: preD-to-ens encoders.
    # NOTE(review): this is gated on 0 as well (not a separate index) -- confirm.
    if 0 in load:
        with np.load(file) as saved:
            ePreDEns = saved['ePreDEns']
    else:
        print("encoders for preD to ens")
        ePreDEns = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            data = go(NPre=NPre, N=N, t=t, dt=dt, dPreD=dPreD, ePreDEns=ePreDEns, stage=0)
            ePreDEns = data['ePreDEns']  # encoders from one run seed the next run
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateAttractors", "preDEns")
            np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, ePreDEns=ePreDEns)

    # Stage 1: encoders from preA/preB/preC into fdfw/ens/off.
    if 1 in load:
        with np.load(file) as saved:
            ePreAFdfw = saved['ePreAFdfw']
            ePreBEns = saved['ePreBEns']
            ePreCOff = saved['ePreCOff']
    else:
        print("encoders for [preA, preB, preC] to [fdfw, ens, off]")
        ePreAFdfw = np.zeros((NPre, N, 1))
        ePreBEns = np.zeros((NPre, N, 1))
        ePreCOff = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, _ = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimB = stimA
            # NOTE(review): stimC is built here but go() receives stimB as stimC -- confirm intended.
            stimC = lambda t: 0 if t < c else 1
            data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, stimC=stimB, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff, ePreDEns=ePreDEns, stage=1)
            ePreAFdfw = data['ePreAFdfw']
            ePreBEns = data['ePreBEns']
            ePreCOff = data['ePreCOff']
            plotActivity(t, dt, fS, data['times'], data['fdfw'], data['tarFdfw'], "integrateAttractors", "preAFdfw")
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateAttractors", "preBEns")
            plotActivity(t, dt, fS, data['times'], data['off'], data['tarOff'], "integrateAttractors", "preCOff")
            np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff)

    # Stage 2: readout decoders for fdfw and off.
    # dt is passed to every filt() call. Without it, filt() uses the synapse's
    # default step, which only matches when dt == 0.001.
    if 2 in load:
        with np.load(file) as saved:
            dFdfw = saved['dFdfw']
            dOff = saved['dOff']
    else:
        print('readout decoders for fdfw and off')
        spikesFdfw = np.zeros((nTrain, int(t / dt), N))
        spikesOff = np.zeros((nTrain, int(t / dt), N))
        targetsFdfw = np.zeros((nTrain, int(t / dt), 1))
        targetsOff = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimC = lambda t: 0 if t < c else 1
            data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, stimC=stimC, dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff, ePreDEns=ePreDEns, stage=2)
            spikesFdfw[n] = data['fdfw']
            spikesOff[n] = data['off']
            targetsFdfw[n] = fNMDA.filt(Tff * fPre.filt(data['inptA'], dt=dt), dt=dt)
            # targetsFdfw[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
            targetsOff[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt), dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(spikesFdfw, targetsFdfw, nTrain, fNMDA, dt=dt, reg=reg)
        dOff, X2, Y2, error2 = decodeNoF(spikesOff, targetsOff, nTrain, fGABA, dt=dt, reg=reg)
        np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X1, Y1, error1, "integrateAttractors", "fdfw", t * nTrain)
        plotState(times, X2, Y2, error2, "integrateAttractors", "off", t * nTrain)

    # Stage 3: fdfw-to-ens encoders.
    if 3 in load:
        with np.load(file) as saved:
            eFdfwEns = saved['eFdfwEns']
    else:
        print("encoders for fdfw-to-ens")
        eFdfwEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(NPre=NPre, N=N, t=t, dt=dt, Tff=Tff, stimA=stimA, stimB=stimB, dPreA=dPreA, dPreB=dPreB, dPreD=dPreD, dFdfw=dFdfw, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreDEns=ePreDEns, eFdfwEns=eFdfwEns, stage=3)
            eFdfwEns = data['eFdfwEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateAttractors", "fdfwEns")
            np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff, eFdfwEns=eFdfwEns)

    # Stage 4: readout decoders for ens.
    if 4 in load:
        with np.load(file) as saved:
            dEns = saved['dEns']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, dPreA=dPreA, dPreB=dPreB, dPreD=dPreD, dFdfw=dFdfw, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreDEns=ePreDEns, eFdfwEns=eFdfwEns, stage=4)
            spikes[n] = data['ens']
            # targets[n] = data['inptB']  # stimA rounded to nearest attractor
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)  # stimA rounded to nearest attractor
        dEns, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, dt=dt, reg=reg)
        np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, dEns=dEns, ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff, eFdfwEns=eFdfwEns)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateAttractors", "ens", t * nTrain)

    # Stage 5: ens-to-ens (recurrent) encoders.
    if 5 in load:
        with np.load(file) as saved:
            eEnsEns = saved['eEnsEns']
    else:
        print("encoders from ens to ens")
        eEnsEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, dPreA=dPreA, dPreB=dPreB, dPreD=dPreD, dFdfw=dFdfw, dEns=dEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreDEns=ePreDEns, eFdfwEns=eFdfwEns, eEnsEns=eEnsEns, stage=5)
            eEnsEns = data['eEnsEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateAttractors", "ensEns")
            np.savez("data/integrateAttractors.npz", dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dPreD=dPreD, dFdfw=dFdfw, dOff=dOff, dEns=dEns, ePreDEns=ePreDEns, ePreAFdfw=ePreAFdfw, ePreBEns=ePreBEns, ePreCOff=ePreCOff, eFdfwEns=eFdfwEns, eEnsEns=eEnsEns)

    print("testing")
    # Fixed strongly negative off-to-fdfw encoders used only in the test runs.
    eOffFdfw = -1e2 * np.ones((N, N, 1))
    vals = np.linspace(-1, 1, nTest)
    att = np.linspace(-1, 1, nAttr)
    for test in range(nTest):
        stimA = lambda t: vals[test] if t < c else 0
        stimB = lambda t: closestAttractor(vals[test], att)
        stimC = lambda t: 0 if t < c else 1
        DATest = lambda t: 0
        data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, stimC=stimC, dPreA=dPreA, dPreC=dPreC, dPreD=dPreD, dFdfw=dFdfw, dEns=dEns, dOff=dOff, ePreAFdfw=ePreAFdfw, ePreCOff=ePreCOff, ePreDEns=ePreDEns, eFdfwEns=eFdfwEns, eEnsEns=eEnsEns, eOffFdfw=eOffFdfw, stage=9, c=c, DA=DATest)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aEns = fNMDA.filt(data['ens'], dt=dt)
        aOff = fGABA.filt(data['off'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw) / Tff  # undo the Tff scaling baked into dFdfw
        xhatEns = np.dot(aEns, dEns)
        xhatOff = np.dot(aOff, dOff)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        # xEns = xFdfw[int(c/dt)] * np.ones_like(data['times'])
        xEns = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        # xEns = data['inptB']
        # NOTE(review): with the default c == t this slice is empty -- confirm how rmse handles that.
        errorEns = rmse(xhatEns[int(c / dt):], xEns[int(c / dt):])
        fig, ax = plt.subplots()
        ax.plot(data['times'], xFdfw, alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], xEns, label="target")
        ax.plot(data['times'], xhatEns, alpha=0.5, label="ens, rmse=%.3f" % errorEns)
        ax.plot(data['times'], xhatOff, alpha=0.5, label="off")
        ax.axvline(c, label="cutoff")
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateAttractors_testFlat%s.pdf" % test)
        # Release the figure; otherwise one figure per test accumulates in pyplot.
        plt.close('all')
def run(NPre=100, N=100, t=10, nTrain=10, nTest=3, nEnc=10, dt=0.001, c=None,
        fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3),
        fGABA=DoubleExp(0.5e-3, 1.5e-3), fS=DoubleExp(1e-3, 1e-1),
        Tff=0.3, Tneg=-1, reg=1e-2, load=(), file=None,
        DATest=lambda t: 0):
    """Train the integrateNMDA network stage by stage, then test it.

    Each training stage k (0-6) is skipped when ``k in load``; its arrays are
    then read from ``file``. Otherwise the stage is trained and everything
    learned so far is written to "data/integrateNMDA.npz".

    Stages:
        0: readout decoders for preA/preB/preC (dPreA, dPreB, dPreC)
        1: encoders from pre[A,B,C] to [fdfw, ens, inh] (ePreA, ePreB, ePreC)
        2: readout decoders for fdfw and inh (dFdfw, dInh)
        3: encoders from fdfw to ens (eFdfw)
        4: readout decoders for ens (dBio), with dNeg = -dBio
        5: encoders from ens to ens (eBio)
        6: encoders from [ens, inh] to fdfw (eNeg, eInh)

    Testing performs ``nTest`` pairs of simulations: a "flat" test in which
    stimA is held at a constant value in [-1, 1] until time ``c``, and a
    "white" test driven by ``makeSignal``. Figures are saved under "plots/".

    Args:
        NPre, N: population sizes passed to ``go``.
        t: duration of each simulation run (same time units as ``dt``).
        nTrain: number of runs used to fit each set of decoders.
        nTest: number of test iterations.
        nEnc: number of runs used to learn each set of encoders.
        dt: simulation time step, also used for all offline filtering.
        c: cutoff time for the flat test and for stimC in stage 6; any
            falsy value (None, 0) is replaced by ``t``.
        fPre, fNMDA, fGABA: filters applied to targets and spike readouts.
        fS: filter passed to ``plotActivity``.
        Tff: scale applied to the fdfw decoding target (also passed to ``go``).
        Tneg: scale applied to dNeg during the flat test.
        reg: regularization passed to ``decodeNoF``.
        load: collection of stage indices to load from ``file``.
        file: path of a .npz archive with previously saved arrays.
        DATest: function passed to ``go`` as ``DA`` during testing.
    """
    if not c:
        c = t

    # Stage 0: readout decoders for preA, preB, preC.
    if 0 in load:
        dPreA = np.load(file)['dPreA']
        dPreB = np.load(file)['dPreB']
        dPreC = np.load(file)['dPreC']
    else:
        print('readout decoders for pre[A,B,C]')
        spikesInptA = np.zeros((nTrain, int(t/dt), NPre))
        spikesInptB = np.zeros((nTrain, int(t/dt), NPre))
        spikesInptC = np.zeros((nTrain, int(t/dt), NPre))
        targetsInptA = np.zeros((nTrain, int(t/dt), 1))
        targetsInptB = np.zeros((nTrain, int(t/dt), 1))
        targetsInptC = np.zeros((nTrain, int(t/dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimB,
                stage=0)
            spikesInptA[n] = data['preA']
            spikesInptB[n] = data['preB']
            spikesInptC[n] = data['preC']
            targetsInptA[n] = fPre.filt(data['inptA'], dt=dt)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dt)
            targetsInptC[n] = fPre.filt(data['inptC'], dt=dt)
        dPreA, XA, YA, errorA = decodeNoF(
            spikesInptA, targetsInptA, nTrain, fPre, dt=dt, reg=reg)
        dPreB, XB, YB, errorB = decodeNoF(
            spikesInptB, targetsInptB, nTrain, fPre, dt=dt, reg=reg)
        dPreC, XC, YC, errorC = decodeNoF(
            spikesInptC, targetsInptC, nTrain, fPre, dt=dt, reg=reg)
        np.savez("data/integrateNMDA.npz",
                 dPreA=dPreA, dPreB=dPreB, dPreC=dPreC)
        # Trim in case the float arange yields a sample more than the decode.
        times = np.arange(0, t*nTrain, dt)[:len(XA)]
        plotState(times, XA, YA, errorA, "integrateNMDA", "preA", t*nTrain)
        plotState(times, XB, YB, errorB, "integrateNMDA", "preB", t*nTrain)
        plotState(times, XC, YC, errorC, "integrateNMDA", "preC", t*nTrain)

    # Stage 1: encoders from pre[A,B,C] to [fdfw, ens, inh].
    if 1 in load:
        ePreA = np.load(file)['ePreA']
        ePreB = np.load(file)['ePreB']
        ePreC = np.load(file)['ePreC']
    else:
        print("encoders for pre[A,B,C] to [fdfw, ens, inh]")
        ePreA = np.zeros((NPre, N, 1))
        ePreB = np.zeros((NPre, N, 1))
        ePreC = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimB,
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                stage=1)
            # Each run continues learning from the previous run's encoders.
            ePreA = data['ePreA']
            ePreB = data['ePreB']
            ePreC = data['ePreC']
            plotActivity(t, dt, fS, data['times'], data['fdfw'], data['tarFdfw'], "integrateNMDA", "preAFdfw")
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDA", "preBEns")
            plotActivity(t, dt, fS, data['times'], data['inh'], data['tarInh'], "integrateNMDA", "preCInh")
            np.savez("data/integrateNMDA.npz",
                     dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC)

    # Stage 2: readout decoders for fdfw and inh.
    if 2 in load:
        dFdfw = np.load(file)['dFdfw']
        dInh = np.load(file)['dInh']
    else:
        print('readout decoders for fdfw and inh')
        spikesFdfw = np.zeros((nTrain, int(t/dt), N))
        spikesInh = np.zeros((nTrain, int(t/dt), N))
        targetsFdfw = np.zeros((nTrain, int(t/dt), 1))
        targetsInh = np.zeros((nTrain, int(t/dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimC = lambda x: 0.5+0.5*np.sin(2*np.pi*x/t)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimC,
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                stage=2)
            spikesFdfw[n] = data['fdfw']
            spikesInh[n] = data['inh']
            # dt is passed to every filter so the offline targets use the
            # simulation time step rather than the filter's default one.
            targetsFdfw[n] = fNMDA.filt(Tff*fPre.filt(data['inptA'], dt=dt), dt=dt)
            targetsInh[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt), dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(
            spikesFdfw, targetsFdfw, nTrain, fNMDA, dt=dt, reg=reg)
        dInh, X2, Y2, error2 = decodeNoF(
            spikesInh, targetsInh, nTrain, fGABA, dt=dt, reg=reg)
        np.savez("data/integrateNMDA.npz",
                 dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dFdfw=dFdfw, dInh=dInh,
                 ePreA=ePreA, ePreB=ePreB, ePreC=ePreC)
        times = np.arange(0, t*nTrain, dt)[:len(X1)]
        plotState(times, X1, Y1, error1, "integrateNMDA", "fdfw", t*nTrain)
        plotState(times, X2, Y2, error2, "integrateNMDA", "inh", t*nTrain)

    # Stage 3: encoders from fdfw to ens.
    if 3 in load:
        eFdfw = np.load(file)['eFdfw']
    else:
        print("encoders for fdfw-to-ens")
        eFdfw = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, Tff=Tff,
                dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw,
                ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw,
                stage=3)
            eFdfw = data['eFdfw']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDA", "fdfwEns")
            np.savez("data/integrateNMDA.npz",
                     dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dFdfw=dFdfw, dInh=dInh,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw)

    # Stage 4: readout decoders for ens; dNeg is their negation.
    if 4 in load:
        dBio = np.load(file)['dBio']
        dNeg = np.load(file)['dNeg']
        # dNeg = -np.array(dBio)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t/dt), N))
        targets = np.zeros((nTrain, int(t/dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB,
                dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw,
                ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw,
                stage=4)
            spikes[n] = data['ens']
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
            # targets[n] = fNMDA.filt(
            #     fPre.filt(fNMDA.filt(data['inptB'])) +
            #     fNMDA.filt(Tff*fPre.filt(data['inptA'])))
        dBio, X, Y, error = decodeNoF(
            spikes, targets, nTrain, fNMDA, dt=dt, reg=reg)
        dNeg = -np.array(dBio)
        np.savez("data/integrateNMDA.npz",
                 dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                 dFdfw=dFdfw, dInh=dInh, dBio=dBio, dNeg=dNeg,
                 ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw)
        times = np.arange(0, t*nTrain, dt)[:len(X)]
        plotState(times, X, Y, error, "integrateNMDA", "ens", t*nTrain)

    # Stage 5: encoders from ens to ens.
    if 5 in load:
        eBio = np.load(file)['eBio']
    else:
        print("encoders from ens to ens")
        eBio = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB,
                dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw, dBio=dBio,
                ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw, eBio=eBio,
                stage=5)
            eBio = data['eBio']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDA", "ensEns")
            np.savez("data/integrateNMDA.npz",
                     dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dFdfw=dFdfw, dBio=dBio, dInh=dInh, dNeg=dNeg,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                     eFdfw=eFdfw, eBio=eBio)

    # Stage 6: encoders from [ens, inh] to fdfw.
    if 6 in load:
        eNeg = np.load(file)['eNeg']
        eInh = np.load(file)['eInh']
    else:
        print("encoders from [ens, inh] to fdfw")
        eNeg = np.zeros((N, N, 1))
        eInh = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            # stimC switches from 0 to 1 at the cutoff time c.
            stimC = lambda t: 0 if t < c else 1
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimC,
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                dFdfw=dFdfw, dBio=dBio, dNeg=dNeg, dInh=dInh,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh,
                stage=6)
            eNeg = data['eNeg']
            eInh = data['eInh']
            plotActivity(t, dt, fS, data['times'], data['fdfw2'], data['tarFdfw2'], "integrateNMDA", "ensFdfw")
            plotActivity(t, dt, fS, data['times'], data['fdfw4'], data['tarFdfw4'], "integrateNMDA", "inhFdfw")
            np.savez("data/integrateNMDA.npz",
                     dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                     dFdfw=dFdfw, dBio=dBio, dNeg=dNeg, dInh=dInh,
                     ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                     eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh)

    # eNeg = None
    # eInh = None
    print("testing")
    vals = np.linspace(-1, 1, nTest)
    for test in range(nTest):
        # Flat test: stimA holds vals[test] until c, then drops to 0 while
        # stimC switches on.
        stimA = lambda t: vals[test] if t < c else 0
        stimC = lambda t: 0 if t < c else 1
        data = go(
            NPre=NPre, N=N, t=t, dt=dt,
            stimA=stimA, stimC=stimC,
            dPreA=dPreA, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio,
            dNeg=Tneg*dNeg, dInh=dInh,
            ePreA=ePreA, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio,
            eNeg=eNeg, eInh=eInh,
            stage=7, c=c, DA=DATest)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        aInh = fGABA.filt(data['inh'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)/Tff
        xhatBio = np.dot(aBio, dBio)
        xhatInh = np.dot(aInh, dInh)
        xFdfw = fNMDA.filt(fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt), dt=dt)
        # Sample index of the cutoff. With the default c (= t) it can equal
        # the number of recorded samples, so clamp it when reading the held
        # value; the error slice below is unaffected by the clamp.
        iCut = int(c/dt)
        xBio = xFdfw[min(iCut, len(xFdfw) - 1)] * np.ones_like(data['times'])
        errorBio = rmse(xhatBio[iCut:], xBio[iCut:])
        fig, ax = plt.subplots()
        ax.plot(data['times'], xFdfw, alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], xBio, label="target")
        ax.plot(data['times'], xhatBio, alpha=0.5, label="ens, rmse=%.3f"%errorBio)
        ax.plot(data['times'], xhatInh, alpha=0.5, label="inh")
        ax.axvline(c, label="cutoff")
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDA_testFlat%s.pdf"%test)

        # White test: makeSignal input, no negative feedback and stimC off.
        stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=200+test)
        stimC = lambda t: 0
        data = go(
            NPre=NPre, N=N, t=t, dt=dt,
            stimA=stimA, stimB=stimB, stimC=stimC, Tff=Tff,
            dPreA=dPreA, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio,
            dNeg=None, dInh=dInh,
            ePreA=ePreA, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio,
            eNeg=None, eInh=eInh,
            stage=7, DA=DATest)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)
        xhatBio = np.dot(aBio, dBio)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xBio = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        errorBio = rmse(xhatBio, xBio)
        fig, ax = plt.subplots()
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], fNMDA.filt(Tff*xFdfw, dt=dt), alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xBio, label="target")
        ax.plot(data['times'], xhatBio, alpha=0.5, label="ens, rmse=%.3f"%errorBio)
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDA_testWhite%s.pdf"%test)
        plt.close('all')
def run(NPre=100, N=100, t=10, nTrain=10, nTest=30, nEnc=10,
        neuron_type=LIF(), dt=0.001,
        f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1),
        Tff=0.1, Tfb=1.0, reg=1e-3, tauRiseMax=5e-2, tauFallMax=3e-1,
        load=False, file=None):
    """Train the integrateNew network for ``neuron_type``, then test it.

    When ``load`` is true, every stage reads its arrays from ``file``.
    Otherwise each stage is trained and the results so far are saved to
    "data/integrateNew_<neuron_type>.npz".

    Stages:
        1: readout decoders and fitted filters for preInpt and preIntg
           (these runs always use LIF() neurons and dt=0.001).
        2: encoders preIntg->ens and preInpt->ens (learned only when
           ``neuron_type`` is a ``Bio`` instance, zeros otherwise).
        3: readout decoders and fitted filter for ens, with target
           ``f``-filtered ``Tfb * intg``.
        4: encoders from ens to ens (Bio only, zeros otherwise).

    Testing runs ``nTest`` simulations seeded with 200 + test and compares
    the ens readout against the ``f``-filtered intg signal with ``rmse``.

    Args:
        NPre, N: population sizes passed to ``go``.
        t: duration of each simulation run (same time units as ``dt``).
        nTrain, nEnc, nTest: number of decoder, encoder and test runs.
        neuron_type: neuron model for ens; also used in output file names.
        dt: simulation time step for the ens stages and testing.
        f, fS: target filter and the filter passed to ``plotActivity``.
        Tff: scale applied to the inpt signal in the preInpt target.
        Tfb: scale applied to the intg signal in the ens target.
        reg, tauRiseMax, tauFallMax: passed through to ``decode``.
        load: if true, load all stages from ``file``.
        file: path of a .npz archive with previously saved arrays.

    Returns:
        ``errors``: array of length ``nTest`` with the per-test rmse values,
        also saved to the npz archive.
    """
    print('\nNeuron Type: %s' % neuron_type)

    # Stage 1: readout decoders for preInpt and preIntg.
    if load:
        d1a = np.load(file)['d1a']
        d1b = np.load(file)['d1b']
        tauRise1a = np.load(file)['tauRise1a']
        tauRise1b = np.load(file)['tauRise1b']
        tauFall1a = np.load(file)['tauFall1a']
        tauFall1b = np.load(file)['tauFall1b']
        f1a = DoubleExp(tauRise1a, tauFall1a)
        f1b = DoubleExp(tauRise1b, tauFall1b)
    else:
        print('readout decoders for preInpt and preIntg')
        spikesInpt = np.zeros((nTrain, int(t / 0.001), NPre))
        spikesIntg = np.zeros((nTrain, int(t / 0.001), NPre))
        targetsInpt = np.zeros((nTrain, int(t / 0.001), 1))
        targetsIntg = np.zeros((nTrain, int(t / 0.001), 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=0.001, seed=n)
            data = go(NPre=NPre, N=N, t=t, dt=0.001, f=f, fS=fS,
                      neuron_type=LIF(), stim=stim)
            spikesInpt[n] = data['preInpt']
            spikesIntg[n] = data['preIntg']
            targetsInpt[n] = f.filt(Tff * data['inpt'], dt=0.001)
            targetsIntg[n] = f.filt(data['intg'], dt=0.001)
        d1a, f1a, tauRise1a, tauFall1a, X1a, Y1a, error1a = decode(
            spikesInpt, targetsInpt, nTrain, dt=0.001, reg=reg,
            name="integrateNew", tauRiseMax=tauRiseMax, tauFallMax=tauFallMax)
        d1b, f1b, tauRise1b, tauFall1b, X1b, Y1b, error1b = decode(
            spikesIntg, targetsIntg, nTrain, dt=0.001, reg=reg,
            name="integrateNew", tauRiseMax=tauRiseMax, tauFallMax=tauFallMax)
        np.savez("data/integrateNew_%s.npz" % neuron_type,
                 d1a=d1a, d1b=d1b,
                 tauRise1a=tauRise1a, tauRise1b=tauRise1b,
                 tauFall1a=tauFall1a, tauFall1b=tauFall1b)
        # Trim in case the float arange yields a sample more than the decode.
        times = np.arange(0, t * nTrain, 0.001)[:len(X1a)]
        plotState(times, X1a, Y1a, error1a, "integrateNew", "%s_preInpt" % neuron_type, t * nTrain)
        plotState(times, X1b, Y1b, error1b, "integrateNew", "%s_preIntg" % neuron_type, t * nTrain)

    # Stage 2: encoders from preIntg and preInpt to ens (Bio neurons only).
    if load:
        e1a = np.load(file)['e1a']
        e1b = np.load(file)['e1b']
    elif isinstance(neuron_type, Bio):
        e1a = np.zeros((NPre, N, 1))
        e1b = np.zeros((NPre, N, 1))
        print("encoders for preIntg-to-ens")
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1a=d1a, d1b=d1b, e1a=e1a, e1b=e1b, f1a=f1a, f1b=f1b,
                      NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS,
                      neuron_type=neuron_type, stim=stim, l1b=True)
            e1b = data['e1b']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNew", "preIntgToEns")
            np.savez("data/integrateNew_%s.npz" % neuron_type,
                     d1a=d1a, d1b=d1b,
                     tauRise1a=tauRise1a, tauRise1b=tauRise1b,
                     tauFall1a=tauFall1a, tauFall1b=tauFall1b,
                     e1a=e1a, e1b=e1b)
        print("encoders for preInpt-to-ens")
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            # stim2 = makeSignal(t, f, dt=dt, seed=n, value=0.5)
            # stim2 = makeSignal(t, f, dt=dt, norm='u', freq=0.25, value=0.8, seed=100+n)
            stim2 = makeSin(t, f, dt=dt, seed=n)
            data = go(d1a=d1a, d1b=d1b, e1a=e1a, e1b=e1b, f1a=f1a, f1b=f1b,
                      NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS,
                      neuron_type=neuron_type, stim=stim, l1a=True,
                      stim2=stim2)
            e1a = data['e1a']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNew", "preInptToEns")
            np.savez("data/integrateNew_%s.npz" % neuron_type,
                     d1a=d1a, d1b=d1b,
                     tauRise1a=tauRise1a, tauRise1b=tauRise1b,
                     tauFall1a=tauFall1a, tauFall1b=tauFall1b,
                     e1a=e1a, e1b=e1b)
    else:
        e1a = np.zeros((NPre, N, 1))
        e1b = np.zeros((NPre, N, 1))

    # Stage 3: readout decoders for ens.
    if load:
        d2 = np.load(file)['d2']
        tauRise2 = np.load(file)['tauRise2']
        tauFall2 = np.load(file)['tauFall2']
        f2 = DoubleExp(tauRise2, tauFall2)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1a=d1a, d1b=d1b, e1a=e1a, e1b=e1b, f1a=f1a, f1b=f1b,
                      NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS,
                      neuron_type=neuron_type, stim=stim, l2=True)
            spikes[n] = data['ens']
            targets[n] = f.filt(Tfb * data['intg'], dt=dt)
        d2, f2, tauRise2, tauFall2, X, Y, error = decode(
            spikes, targets, nTrain, dt=dt, reg=reg,
            tauRiseMax=tauRiseMax, tauFallMax=tauFallMax, name="integrateNew")
        np.savez("data/integrateNew_%s.npz" % neuron_type,
                 d1a=d1a, d1b=d1b,
                 tauRise1a=tauRise1a, tauRise1b=tauRise1b,
                 tauFall1a=tauFall1a, tauFall1b=tauFall1b,
                 e1a=e1a, e1b=e1b,
                 d2=d2, tauRise2=tauRise2, tauFall2=tauFall2)
        times = np.arange(0, t * nTrain, dt)[:len(X)]
        plotState(times, X, Y, error, "integrateNew", "%s_ens" % neuron_type, t * nTrain)

    # Stage 4: encoders from ens back to ens (Bio neurons only).
    if load:
        e2 = np.load(file)['e2']
    elif isinstance(neuron_type, Bio):
        print("encoders from ens2 to ens")
        e2 = np.zeros((N, N, 1))
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            #stim = makeSin(t, f, dt=dt, seed=n)
            data = go(d1a=d1a, d1b=d1b, d2=d2, e1a=e1a, e1b=e1b, e2=e2,
                      f1a=f1a, f1b=f1b, f2=f2,
                      NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS,
                      neuron_type=neuron_type, stim=stim, l3=True)
            e2 = data['e2']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns2'], "integrateNew", "Ens2Ens")
            np.savez("data/integrateNew_%s.npz" % neuron_type,
                     d1a=d1a, d1b=d1b,
                     tauRise1a=tauRise1a, tauRise1b=tauRise1b,
                     tauFall1a=tauFall1a, tauFall1b=tauFall1b,
                     e1a=e1a, e1b=e1b,
                     d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)
    else:
        e2 = np.zeros((N, N, 1))

    print("testing")
    errors = np.zeros((nTest))
    for test in range(nTest):
        stim = makeSignal(t, f, dt=dt, seed=200 + test)
        data = go(d1a=d1a, d2=d2, e1a=e1a, e2=e2, f1a=f1a, f2=f2,
                  NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS,
                  neuron_type=neuron_type, stim=stim, test=True)
        A = f2.filt(data['ens'], dt=dt)
        X = np.dot(A, d2)
        Y = f.filt(data['intg'], dt=dt)
        error = rmse(X, Y)
        errors[test] = error
        plotState(data['times'], X, Y, error, "integrateNew", "%s_test%s" % (neuron_type, test), t)
    print('%s errors:' % neuron_type, errors)
    np.savez("data/integrateNew_%s.npz" % neuron_type,
             d1a=d1a, d1b=d1b,
             tauRise1a=tauRise1a, tauRise1b=tauRise1b,
             tauFall1a=tauFall1a, tauFall1b=tauFall1b,
             e1a=e1a, e1b=e1b,
             d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2,
             errors=errors)
    return errors
def run(NPre=300, NBias=100, N=30, t=10, nTrain=10, nTest=5, nEnc=10,
        dt=0.001, fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3), fGABA=DoubleExp(0.5e-3, 1.5e-3),
        fS=DoubleExp(2e-2, 1e-1), DATrain=lambda t: 0, DATest=lambda t: 0,
        Tff=0.3, reg=1e-2, load=(), file=None):
    """Train the integrateNMDAbias2 network stage by stage, then test it.

    Each training stage k (0-7) is skipped when ``k in load``; its arrays are
    then read from ``file``. Otherwise the stage is trained and everything
    learned so far is written to "data/integrateNMDAbias2.npz".

    Stages:
        0: readout decoders for preA (dPre)
        1: encoders from pre to fdfw and bias (ePreFdfw, ePreBias)
        2: readout decoders for fdfw and bias (dFdfw, dBias)
        3: encoders from bias to ens (eBiasEns)
        4: encoders from pre to ens (ePreEns)
        5: encoders from fdfw to ens (eFdfwEns)
        6: readout decoders for ens (dEns)
        7: encoders from ens to ens (eEnsEns)

    Testing runs ``nTest`` simulations driven by ``makeSignal`` (seeds
    200 + test) and saves one figure per test under "plots/".

    Args:
        NPre, NBias, N: population sizes passed to ``go``.
        t: duration of each simulation run (same time units as ``dt``).
        nTrain, nEnc, nTest: number of decoder, encoder and test runs.
        dt: simulation time step, also used for all offline filtering.
        fPre, fNMDA, fGABA: filters applied to targets and spike readouts.
        fS: filter passed to ``plotActivity``.
        DATrain, DATest: functions passed to ``go`` as ``DA`` during
            training and testing respectively.
        Tff: scale applied to the fdfw decoding target (also passed to ``go``).
        reg: regularization passed to ``decodeNoF``.
        load: collection of stage indices to load from ``file``.
        file: path of a .npz archive with previously saved arrays.
    """
    # Stage 0: readout decoders for pre (decoded from preA).
    if 0 in load:
        dPre = np.load(file)['dPre']
    else:
        print('readout decoders for [preA, preB, preC]')
        spikesInpt = np.zeros((nTrain, int(t / dt), NPre))
        targetsInpt = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, stimB=stimB, stage=0)
            spikesInpt[n] = data['preA']
            targetsInpt[n] = fPre.filt(data['inptA'], dt=dt)
        dPre, X, Y, error = decodeNoF(
            spikesInpt, targetsInpt, nTrain, fPre, dt=dt, reg=reg)
        np.savez(
            "data/integrateNMDAbias2.npz",
            dPre=dPre,
        )
        # Trim in case the float arange yields a sample more than the decode.
        times = np.arange(0, t * nTrain, dt)[:len(X)]
        plotState(times, X, Y, error, "integrateNMDAbias2", "pre", t * nTrain)

    # Stage 1: encoders from pre to fdfw and bias.
    if 1 in load:
        ePreFdfw = np.load(file)['ePreFdfw']
        ePreBias = np.load(file)['ePreBias']
    else:
        print("encoders for [preA, preC] to [fdfw, bias]")
        ePreFdfw = np.zeros((NPre, N, 1))
        ePreBias = np.zeros((NPre, NBias, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, stimB=stimB, dPre=dPre,
                      ePreFdfw=ePreFdfw, ePreBias=ePreBias, stage=1)
            ePreFdfw = data['ePreFdfw']
            ePreBias = data['ePreBias']
            plotActivity(t, dt, fS, data['times'], data['fdfw'], data['tarFdfw'], "integrateNMDAbias2", "pre_fdfw")
            plotActivity(t, dt, fS, data['times'], data['bias'], data['tarBias'], "integrateNMDAbias2", "pre_bias")
            np.savez("data/integrateNMDAbias2.npz",
                     dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias)

    # Stage 2: readout decoders for fdfw and bias.
    if 2 in load:
        dFdfw = np.load(file)['dFdfw']
        dBias = np.load(file)['dBias']
    else:
        print('readout decoders for fdfw and bias')
        spikesFdfw = np.zeros((nTrain, int(t / dt), N))
        spikesBias = np.zeros((nTrain, int(t / dt), NBias))
        targetsFdfw = np.zeros((nTrain, int(t / dt), 1))
        targetsBias = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, _ = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, dPre=dPre,
                      ePreFdfw=ePreFdfw, ePreBias=ePreBias, stage=2)
            spikesFdfw[n] = data['fdfw']
            spikesBias[n] = data['bias']
            # dt is passed to every filter so the offline targets use the
            # simulation time step rather than the filter's default one.
            targetsFdfw[n] = fNMDA.filt(Tff * fPre.filt(data['inptA'], dt=dt), dt=dt)
            targetsBias[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt), dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(
            spikesFdfw, targetsFdfw, nTrain, fNMDA, dt=dt, reg=reg)
        dBias, X2, Y2, error2 = decodeNoF(
            spikesBias, targetsBias, nTrain, fGABA, dt=dt, reg=reg)
        np.savez("data/integrateNMDAbias2.npz",
                 dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias,
                 dFdfw=dFdfw, dBias=dBias)
        times = np.arange(0, t * nTrain, dt)[:len(X1)]
        plotState(times, X1, Y1, error1, "integrateNMDAbias2", "fdfw", t * nTrain)
        plotState(times, X2, Y2, error2, "integrateNMDAbias2", "bias", t * nTrain)

    # Stage 3: encoders from bias to ens.
    if 3 in load:
        eBiasEns = np.load(file)['eBiasEns']
    else:
        print("encoders for [bias] to ens")
        eBiasEns = np.zeros((NBias, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, dPre=dPre, dBias=dBias,
                      ePreBias=ePreBias, eBiasEns=eBiasEns, stage=3)
            eBiasEns = data['eBiasEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDAbias2", "bias_ens")
            np.savez(
                "data/integrateNMDAbias2.npz",
                dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias,
                dFdfw=dFdfw, dBias=dBias, eBiasEns=eBiasEns,
            )

    # Stage 4: encoders from pre to ens.
    if 4 in load:
        ePreEns = np.load(file)['ePreEns']
    else:
        print("encoders for [preB] to [ens]")
        ePreEns = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, stimB=stimB, dPre=dPre, dBias=dBias,
                      ePreBias=ePreBias, eBiasEns=eBiasEns, ePreEns=ePreEns,
                      stage=4)
            ePreEns = data['ePreEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDAbias2", "pre_ens")
            np.savez(
                "data/integrateNMDAbias2.npz",
                dPre=dPre, ePreFdfw=ePreFdfw, ePreBias=ePreBias,
                dFdfw=dFdfw, dBias=dBias, eBiasEns=eBiasEns, ePreEns=ePreEns,
            )

    # Stage 5: encoders from fdfw to ens.
    if 5 in load:
        eFdfwEns = np.load(file)['eFdfwEns']
    else:
        print("encoders for [fdfw] to ens")
        eFdfwEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            # stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, Tff=Tff,
                      dPre=dPre, dFdfw=dFdfw, dBias=dBias,
                      ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, stage=5)
            eFdfwEns = data['eFdfwEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDAbias2", "fdfw_ens")
            np.savez(
                "data/integrateNMDAbias2.npz",
                dPre=dPre, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                ePreBias=ePreBias, dFdfw=dFdfw, dBias=dBias,
                eFdfwEns=eFdfwEns, eBiasEns=eBiasEns,
            )

    # Stage 6: readout decoders for ens.
    if 6 in load:
        dEns = np.load(file)['dEns']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            # stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, stimB=stimB,
                      dPre=dPre, dFdfw=dFdfw, dBias=dBias,
                      ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, stage=6)
            spikes[n] = data['ens']
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(
            spikes, targets, nTrain, fNMDA, dt=dt, reg=reg)
        np.savez("data/integrateNMDAbias2.npz",
                 dPre=dPre, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                 ePreBias=ePreBias, dFdfw=dFdfw, dBias=dBias,
                 eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, dEns=dEns)
        times = np.arange(0, t * nTrain, dt)[:len(X)]
        plotState(times, X, Y, error, "integrateNMDAbias2", "ens", t * nTrain)

    # Stage 7: encoders from ens to ens.
    if 7 in load:
        eEnsEns = np.load(file)['eEnsEns']
    else:
        print("encoders from ens to ens")
        eEnsEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            # stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATrain,
                      stimA=stimA, stimB=stimB,
                      dPre=dPre, dFdfw=dFdfw, dBias=dBias, dEns=dEns,
                      ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, eEnsEns=eEnsEns,
                      stage=7)
            eEnsEns = data['eEnsEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDAbias2", "ens_ens")
            np.savez("data/integrateNMDAbias2.npz",
                     dPre=dPre, ePreFdfw=ePreFdfw, ePreEns=ePreEns,
                     ePreBias=ePreBias, dFdfw=dFdfw, dBias=dBias,
                     eFdfwEns=eFdfwEns, eBiasEns=eBiasEns, dEns=dEns,
                     eEnsEns=eEnsEns)

    print("testing")
    for test in range(nTest):
        stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=200 + test)
        data = go(NPre=NPre, NBias=NBias, N=N, t=t, dt=dt, DA=DATest,
                  stimA=stimA, stimB=stimB, Tff=Tff,
                  dPre=dPre, dFdfw=dFdfw, dBias=dBias, dEns=dEns,
                  ePreFdfw=ePreFdfw, ePreEns=ePreEns, eFdfwEns=eFdfwEns,
                  eBiasEns=eBiasEns, eEnsEns=eEnsEns, stage=8)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)
        xhatEns = np.dot(aBio, dEns)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xEns = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        errorBio = rmse(xhatEns, xEns)
        fig, ax = plt.subplots()
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], Tff * xFdfw, alpha=0.5, label="input")
        ax.plot(data['times'], xEns, label="target")
        ax.plot(data['times'], xhatEns, alpha=0.5, label="ens, rmse=%.3f" % errorBio)
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDAbias2_testWhite%s.pdf" % test)
        plt.close('all')