Exemplo n.º 1
0
def df_opt_hx(x, ens, name='default', df_evals=100, seed=0, dt=0.001, dt_sample=0.001,
        tau_rise=1e-3, tau_fall=(3e-2, 3e-1), penalty=0.25, algo=tpe.suggest):
    """Search the fall times of an ensemble filter and a target filter with hyperopt.

    ``ens`` and ``x`` are written to ``data/<name>_ens.npz`` and
    ``data/<name>_x.npz`` and reloaded by every evaluation. Each evaluation
    filters the ensemble activity with ``DoubleExp(tau_rise, tau_ens)`` and the
    target with ``DoubleExp(tau_rise, tau_x)``, fits decoders with ``Lstsq``
    and scores them by RMSE plus ``penalty * tau_ens``. Both fall times are
    drawn uniformly from ``[tau_fall[0], tau_fall[1]]``.

    Parameters
    ----------
    x, ens : array_like
        Target signal and ensemble activity (saved with ``np.savez_compressed``).
    name : str
        Stem of the temporary files under ``data/``.
    df_evals : int
        Number of hyperopt evaluations (``max_evals``).
    seed : int
        Seed of the ``RandomState`` handed to ``fmin``.
    dt, dt_sample : float
        When they differ, the filtered signals are decimated by
        ``round(dt_sample / dt)`` before fitting.
    tau_rise : float
        Rise time shared by both filters.
    tau_fall : sequence of two floats
        Lower and upper bound of the fall-time search.
    penalty : float
        Weight of the ensemble fall time added to the loss.
    algo : callable
        hyperopt suggestion algorithm.

    Returns
    -------
    d_ens, h_ens, taus_ens, h_x, taus_x
        Decoders, ensemble filter and its taus, target filter and its taus of
        the best trial.
    """
    np.savez_compressed('data/%s_ens.npz' % name, ens=ens)
    np.savez_compressed('data/%s_x.npz' % name, x=x)
    # The objective reloads the data from disk; drop the local references.
    del ens, x

    space = {
        'name': name,
        'dt': dt,
        'dt_sample': dt_sample,
        'tau_rise': tau_rise,
        'ens': hp.uniform('ens', tau_fall[0], tau_fall[1]),
        'x': hp.uniform('x', tau_fall[0], tau_fall[1]),
    }

    def objective(params):
        taus_ens = [params['tau_rise'], params['ens']]
        taus_x = [params['tau_rise'], params['x']]
        h_ens = DoubleExp(taus_ens[0], taus_ens[1])
        h_x = DoubleExp(taus_x[0], taus_x[1])
        # Context managers close the NpzFile handles that np.load opens.
        with np.load('data/%s_ens.npz' % params['name']) as saved:
            A = h_ens.filt(saved['ens'], dt=params['dt_sample'])
        with np.load('data/%s_x.npz' % params['name']) as saved:
            target = h_x.filt(saved['x'], dt=params['dt_sample'])
        # NOTE(review): the signals are filtered with dt_sample and then
        # decimated by dt_sample/dt; confirm the saved data is sampled at dt.
        if params['dt'] != params['dt_sample']:
            # Round rather than truncate: int(0.003 / 0.001) == 2.
            stride = int(round(params['dt_sample'] / params['dt']))
            A = A[::stride]
            target = target[::stride]
        d_ens = Lstsq()(A, target)[0]
        xhat = np.dot(A, d_ens)
        loss = rmse(xhat, target)
        loss += penalty * taus_ens[1]
        return {'loss': loss, 'taus_ens': taus_ens, 'taus_x': taus_x,
                'd_ens': d_ens, 'status': STATUS_OK}

    trials = Trials()
    fmin(objective,
         rstate=np.random.RandomState(seed=seed),
         space=space,
         algo=algo,
         max_evals=df_evals,
         trials=trials)
    # best_trial picks the lowest-loss STATUS_OK trial (first one on ties).
    best = trials.best_trial['result']
    taus_ens = best['taus_ens']
    taus_x = best['taus_x']
    d_ens = best['d_ens']
    h_ens = DoubleExp(taus_ens[0], taus_ens[1])
    h_x = DoubleExp(taus_x[0], taus_x[1])
    return d_ens, h_ens, taus_ens, h_x, taus_x
Exemplo n.º 2
0
 def objective(hyperparams):
     """Score one hyperopt sample of the ensemble and target fall times.

     Uses ``dt``, ``dt_sample`` and ``penalty`` from the enclosing scope.
     Returns a hyperopt result dict holding the loss (RMSE plus ``penalty``
     times the ensemble fall time), both tau pairs and the decoders.
     """
     taus_ens = [hyperparams['tau_rise'], hyperparams['ens']]
     taus_x = [hyperparams['tau_rise'], hyperparams['x']]
     h_ens = DoubleExp(taus_ens[0], taus_ens[1])
     h_x = DoubleExp(taus_x[0], taus_x[1])
     # Context managers close the NpzFile handles that np.load opens.
     with np.load('data/%s_ens.npz' % hyperparams['name']) as saved:
         A = h_ens.filt(saved['ens'], dt=hyperparams['dt_sample'])
     with np.load('data/%s_x.npz' % hyperparams['name']) as saved:
         x = h_x.filt(saved['x'], dt=hyperparams['dt_sample'])
     if dt != dt_sample:
         # Round rather than truncate: int(0.003 / 0.001) == 2.
         stride = int(round(dt_sample / dt))
         A = A[::stride]
         x = x[::stride]
     d_ens = Lstsq()(A, x)[0]
     xhat = np.dot(A, d_ens)
     loss = rmse(xhat, x)
     loss += penalty * taus_ens[1]
     return {'loss': loss, 'taus_ens': taus_ens, 'taus_x': taus_x, 'd_ens': d_ens, 'status': STATUS_OK}
Exemplo n.º 3
0
def decode(spikes,
           targets,
           nTrain,
           dt=0.001,
           dtSample=0.001,
           reg=1e-3,
           penalty=0,
           evals=100,
           name="default",
           tauRiseMax=3e-2,
           tauFallMax=3e-1):
    """Optimise a DoubleExp readout filter and decoders, then evaluate them.

    ``dfOpt`` searches the filter time constants and decoders. The first
    ``nTrain`` trials of ``spikes`` are then filtered with the chosen filter,
    stacked together with the matching ``targets``, and decoded.

    Parameters
    ----------
    spikes : ndarray
        Per-trial activity, indexed ``spikes[n]``; ``spikes.shape[2]`` is
        the number of neurons.
    targets : ndarray
        Per-trial targets; ``targets.shape[2]`` is the target dimensionality.
    nTrain : int
        Number of leading trials to use.
    dt, dtSample, reg, penalty, evals, name, tauRiseMax, tauFallMax
        Forwarded to ``dfOpt``. ``dt`` is also the filtering step here.

    Returns
    -------
    d, f, tauRise, tauFall, X, Y, error
        Decoders reshaped to ``(-1, dims)``, the filter, its time constants,
        the decoded estimate, the stacked targets and ``rmse(X, Y)``.
    """
    d, tauRise, tauFall = dfOpt(spikes,
                                targets,
                                nTrain,
                                name=name,
                                evals=evals,
                                reg=reg,
                                penalty=penalty,
                                dt=dt,
                                dtSample=dtSample,
                                tauRiseMax=tauRiseMax,
                                tauFallMax=tauFallMax)
    print("tauRise: %.3f, tauFall: %.3f" % (tauRise, tauFall))
    f = DoubleExp(tauRise, tauFall)
    nDims = targets.shape[2]
    # Each trial gets its own filt call. The arrays are stacked once instead of
    # via repeated np.append, which is quadratic. The empty leading arrays keep
    # the column count defined when nTrain == 0.
    A = np.concatenate(
        [np.zeros((0, spikes.shape[2]))]
        + [f.filt(spikes[n], dt=dt) for n in range(nTrain)],
        axis=0)
    Y = np.concatenate(
        [np.zeros((0, nDims))] + [targets[n] for n in range(nTrain)],
        axis=0)
    # Reshape before decoding. A flat (N,) decoder would otherwise give X of
    # shape (T,), and rmse(X, Y) would broadcast it against Y's (T, nDims).
    d = d.reshape((-1, nDims))
    X = np.dot(A, d)
    error = rmse(X, Y)
    return d, f, tauRise, tauFall, X, Y, error
Exemplo n.º 4
0
 def objective(hyperparams):
     """Score one hyperopt sample of the readout filter time constants.

     The spikes and targets saved under ``data/<name>_*.npz`` are loaded. The
     first ``nTrain`` trials are filtered and stacked, optionally decimated to
     ``dtSample``, and fitted with ``LstsqL2``. The loss is the RMSE plus
     ``penalty * (10 * tauRise + tauFall)``; ``penalty`` comes from the
     enclosing scope.
     """
     tauRise = hyperparams['tauRise']
     tauFall = hyperparams['tauFall']
     dt = hyperparams['dt']
     dtSample = hyperparams['dtSample']
     nTrain = hyperparams['nTrain']
     f = DoubleExp(tauRise, tauFall)
     # Context managers close the NpzFile handles that np.load opens.
     with np.load('data/%s_spikes.npz' % hyperparams['name']) as saved:
         spikes = saved['spikes']
     with np.load('data/%s_target.npz' % hyperparams['name']) as saved:
         targets = saved['target']
     # Filter each trial separately and stack once. Repeated np.append is
     # quadratic. The empty leading arrays keep shapes valid for nTrain == 0.
     A = np.concatenate(
         [np.zeros((0, spikes.shape[2]))]
         + [f.filt(spikes[n], dt=dt) for n in range(nTrain)],
         axis=0)
     Y = np.concatenate(
         [np.zeros((0, targets.shape[2]))] + [targets[n] for n in range(nTrain)],
         axis=0)
     if dt != dtSample:
         # Round rather than truncate: int(0.003 / 0.001) == 2.
         stride = int(round(dtSample / dt))
         A = A[::stride]
         Y = Y[::stride]
     d, _ = LstsqL2(reg=hyperparams['reg'])(A, Y)
     X = np.dot(A, d)
     loss = rmse(X, Y)
     loss += penalty * (10 * tauRise + tauFall)
     return {
         'loss': loss,
         'd': d,
         'tauRise': tauRise,
         'tauFall': tauFall,
         'status': STATUS_OK
     }
Exemplo n.º 5
0
def go(d_ens, f_ens, n_neurons=3000, t=100, L=False, neuron_type=LIF(),
       m=Uniform(30, 40), i=Uniform(-1, 1), r=40, IC=np.array([1,1,1]),
       seed=0, dt=0.001, dt_sample=0.001, f=DoubleExp(1e-3, 1e-1)):
    """Simulate a 3-D target system next to a spiking ensemble.

    A one-neuron ``nengo.Direct`` ensemble ``x`` receives ``IC`` while
    t <= 1.0 and feeds back onto itself through the module-level ``feedback``
    function with synapse ``~s``. The spiking ensemble ``ens`` is driven in
    one of two ways:

    * ``L`` True: by a ``SpikingRectifiedLinear`` supervisor that encodes ``x``.
    * ``L`` False: by its own recurrent connection, with decoders ``d_ens``
      and synapse ``f_ens``.

    Parameters
    ----------
    d_ens, f_ens
        Recurrent decoders and synapse, used only when ``L`` is False.
    n_neurons : int
        Size of ``ens`` (and of the supervisor when ``L`` is True).
    t : float
        Simulated time passed to ``sim.run``.
    L : bool
        Drive ``ens`` from the supervisor instead of recurrently.
    neuron_type, m, i, r
        Neuron model, max rates, intercepts and radius of ``ens``.
    IC : array of length 3
        Input applied to ``x`` while t <= 1.0. The default array is shared
        between calls; it is only read here.
    seed : int
        Seed for the network, ``ens``, the connections and the simulator.
    dt, dt_sample : float
        Simulator step and probe sampling interval.
    f : synapse
        Synapse from the supervisor to ``ens``.

    Returns
    -------
    dict
        ``times`` from ``sim.trange()``, ``x`` (probe on the direct-mode
        state) and ``ens`` (neuron output routed through ``DownsampleNode``).
        NOTE(review): ``sim.trange()`` is called without ``sample_every``.
        When dt != dt_sample it may not match the length of the probed data;
        confirm.
    """

    with nengo.Network(seed=seed) as model:
        # Ensembles
        u = nengo.Node(lambda t: IC*(t<=1.0))  # outputs IC until t = 1.0, then zeros
        x = nengo.Ensemble(1, 3, neuron_type=nengo.Direct())
        ens = nengo.Ensemble(n_neurons, 3, max_rates=m, intercepts=i, neuron_type=neuron_type, seed=seed, radius=r)
        dss = nengo.Node(DownsampleNode(size_in=n_neurons, size_out=n_neurons, dt=dt, dt_sample=dt_sample), size_in=n_neurons, size_out=n_neurons)

        # Connections
        nengo.Connection(u, x, synapse=None)
        # `feedback` and `s` are defined outside this block; presumably ~s is
        # an integrator (inverse of the Laplace variable s) -- confirm.
        nengo.Connection(x, x, function=feedback, synapse=~s)
        if L:
            # Training mode: a supervisor population encodes x and drives ens.
            supv = nengo.Ensemble(n_neurons, 3, neuron_type=SpikingRectifiedLinear(), radius=r, seed=seed)
            nengo.Connection(x, supv, synapse=None)
            nengo.Connection(supv, ens, synapse=f, seed=seed)
        else:
            # Recurrent mode: ens feeds back onto itself with fixed decoders.
            nengo.Connection(ens, ens, synapse=f_ens, solver=NoSolver(d_ens), seed=seed)

        # Probes
        nengo.Connection(ens.neurons, dss, synapse=None)
        p_x = nengo.Probe(x, synapse=None, sample_every=dt_sample)
        # NOTE(review): the neuron output is downsampled by DownsampleNode and
        # the probe also samples every dt_sample; confirm this is intended.
        p_ens = nengo.Probe(dss, synapse=None, sample_every=dt_sample)

    with nengo.Simulator(model, seed=seed, dt=dt) as sim:
        sim.run(t)

    return dict(
        times=sim.trange(),
        x=sim.data[p_x],
        ens=sim.data[p_ens])
Exemplo n.º 6
0
 def objective(hyperparams):
     """Score one hyperopt sample of the ensemble filter time constants.

     Filters the saved ensemble activity with ``DoubleExp(tau_rise,
     tau_fall)`` and fits decoders to the saved target. The fit uses
     ``LstsqL2`` when ``reg`` is truthy, otherwise ``Lstsq``. The loss is the
     RMSE plus ``penalty * (10 * tau_rise + tau_fall)``. ``dt``, ``dt_sample``
     and ``penalty`` come from the enclosing scope.
     """
     taus_ens = [hyperparams['tau_rise'], hyperparams['tau_fall']]
     h_ens = DoubleExp(taus_ens[0], taus_ens[1])
     # Context managers close the NpzFile handles that np.load opens.
     with np.load('data/%s_ens.npz' % hyperparams['name']) as saved:
         A = h_ens.filt(saved['ens'], dt=hyperparams['dt'])
     with np.load('data/%s_x.npz' % hyperparams['name']) as saved:
         x = saved['x']
     if dt != dt_sample:
         # Round rather than truncate: int(0.003 / 0.001) == 2.
         stride = int(round(dt_sample / dt))
         A = A[::stride]
         x = x[::stride]
     if hyperparams['reg']:
         d_ens = LstsqL2(reg=hyperparams['reg'])(A, x)[0]
     else:
         d_ens = Lstsq()(A, x)[0]
     xhat = np.dot(A, d_ens)
     loss = rmse(xhat, x)
     loss += penalty * (10*taus_ens[0] + taus_ens[1])
     return {'loss': loss, 'taus_ens': taus_ens, 'd_ens': d_ens, 'status': STATUS_OK}
Exemplo n.º 7
0
def go(t=10, m=Uniform(30, 30), i=Uniform(0, 0), seed=0, dt=0.001, f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1), d1=None, f1=None, e1=None, l1=False, stim=lambda t: np.sin(t)):
    """Drive single LIF, Wilson and Bio neurons from one shared pre population.

    ``stim`` feeds a 100-neuron LIF ensemble ``pre``. ``pre`` connects to
    three one-neuron ensembles (``LIF``, ``Wilson``, ``Bio("Pyramidal")``)
    using the same decoders ``d1`` and synapse ``f1``; ``f1`` falls back to
    ``f`` when it is falsy. With ``l1``, the encoders of the Bio connection
    are learned against the LIF neuron via ``learnEncoders`` with synapse
    ``fS``.

    Returns
    -------
    dict
        ``times``, the unfiltered probe data ``inpt``, ``pre``, ``lif``,
        ``wilson`` and ``bio`` (neuron outputs), and ``e1``. ``e1`` is the
        learned ``cBio.e`` when ``l1`` is set, otherwise the argument
        passed in.
    """

    if not f1: f1=f
    with nengo.Network(seed=seed) as model:
        # Stimulus and Nodes
        inpt = nengo.Node(stim)
        # NOTE(review): `tar` is never connected or probed in this block.
        tar = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
        pre = nengo.Ensemble(100, 1, max_rates=m, seed=seed, neuron_type=LIF())
        lif = nengo.Ensemble(1, 1, max_rates=m, intercepts=i, encoders=Choice([[1]]), neuron_type=LIF(), seed=seed)
        wilson = nengo.Ensemble(1, 1, max_rates=m, intercepts=i, encoders=Choice([[1]]), neuron_type=Wilson(), seed=seed)
        bio = nengo.Ensemble(1, 1, max_rates=m, intercepts=i, encoders=Choice([[1]]), neuron_type=Bio("Pyramidal"), seed=seed)
        nengo.Connection(inpt, pre, synapse=None, seed=seed)
        # All three targets share the same decoders and synapse.
        cLif = nengo.Connection(pre, lif, synapse=f1, seed=seed, solver=NoSolver(d1))
        cWilson = nengo.Connection(pre, wilson, synapse=f1, seed=seed, solver=NoSolver(d1))
        cBio = nengo.Connection(pre, bio, synapse=f1, seed=seed, solver=NoSolver(d1))
        pInpt = nengo.Probe(inpt, synapse=None)
        pPre = nengo.Probe(pre.neurons, synapse=None)
        pLif = nengo.Probe(lif.neurons, synapse=None)
        pWilson = nengo.Probe(wilson.neurons, synapse=None)
        pBio = nengo.Probe(bio.neurons, synapse=None)
        if l1: learnEncoders(cBio, lif, fS, alpha=3e-7) # Encoder Learning (Bio)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        # Install decoders/encoders on the Bio connection before the simulator
        # is initialised and run.
        setWeights(cBio, d1, e1)
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model) 
      
    # Keep the learned encoders only when learning was enabled.
    e1 = cBio.e if l1 else e1

    return dict(
        times=sim.trange(),
        inpt=sim.data[pInpt],
        pre=sim.data[pPre],
        lif=sim.data[pLif],
        wilson=sim.data[pWilson],
        bio=sim.data[pBio],
        e1=e1,
    )
Exemplo n.º 8
0
def go(NPre=100,
       N=100,
       t=10,
       m=Uniform(30, 30),
       i=Uniform(-0.8, 0.8),
       neuron_type=LIF(),
       seed=0,
       dt=0.001,
       fPre=None,
       fEns=None,
       fS=DoubleExp(2e-2, 2e-1),
       dPreA=None,
       dPreB=None,
       dEns=None,
       ePreA=None,
       ePreB=None,
       eBio=None,
       stage=None,
       alpha=1e-6,
       eMax=1e0,
       stimA=lambda t: np.sin(t),
       stimB=lambda t: 0):
    """Build and run one stage of the staged integrator-training network.

    Two inputs (``stimA``, ``stimB``) drive the pre ensembles ``preInptA`` and
    ``preInptB``. ``ens`` uses ``neuron_type``; ``tarEns`` is a LIF ensemble
    with the same parameters. What connects to them depends on ``stage``:

    * ``None``: no connections into ``ens``/``tarEns`` (only the pre
      populations are driven).
    * 1: preInptB -> tarEns and ens; learn encoders of preInptB -> ens.
    * 2: preInptA and preInptB -> tarEns and ens; learn encoders of
      preInptA -> ens.
    * 3: preInptA and preInptB -> ens with fixed decoders (no learning).
    * 4: ``ens2`` (fed by preInptB) -> ens with ``dEns``/``fEns``. ``ens3``,
      fed by preInptA and by ``preInptC`` (inptB filtered with ``fEns``), is
      the learning target for the encoders of ens2 -> ens. ``tarEns`` data is
      then taken from ``ens3``.
    * 5: preInptA -> ens plus the recurrent ens -> ens with ``dEns``/``fEns``.

    For ``Bio`` neurons the decoders/encoders are installed with
    ``setWeights`` before ``neuron.h.init()`` and the run. Other neuron types
    are simply run.

    Returns
    -------
    dict
        ``times``, unfiltered probe data (``inptA``, ``inptB``, ``preInptA``,
        ``preInptB``, ``ens``, ``tarEns``), and the encoders ``ePreA``,
        ``ePreB`` and ``eBio``. An encoder is replaced by the learned ``.e``
        of its connection in the stage that learns it (1, 2, 4 respectively).
    """

    with nengo.Network(seed=seed) as model:
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        preInptA = nengo.Ensemble(NPre, 1, radius=3, max_rates=m, seed=seed)
        preInptB = nengo.Ensemble(NPre, 1, max_rates=m, seed=seed)
        ens = nengo.Ensemble(N,
                             1,
                             max_rates=m,
                             intercepts=i,
                             neuron_type=neuron_type,
                             seed=seed)
        # LIF reference population with the same parameters as `ens`.
        tarEns = nengo.Ensemble(N,
                                1,
                                max_rates=m,
                                intercepts=i,
                                neuron_type=nengo.LIF(),
                                seed=seed)
        cpa = nengo.Connection(inptA, preInptA, synapse=None, seed=seed)
        cpb = nengo.Connection(inptB, preInptB, synapse=None, seed=seed)
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pPreInptA = nengo.Probe(preInptA.neurons, synapse=None)
        pPreInptB = nengo.Probe(preInptB.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        # Connections from preInptB use seed + 1, so they get a different
        # connection seed from the preInptA connections.
        if stage == 1:
            # Stage 1: learn encoders for preInptB -> ens against tarEns.
            c0b = nengo.Connection(preInptB,
                                   tarEns,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed + 1)
            c1b = nengo.Connection(preInptB,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed + 1)
            learnEncoders(c1b, tarEns, fS, alpha=alpha, eMax=eMax)
        if stage == 2:
            # Stage 2: both inputs drive tarEns and ens; learn preInptA -> ens.
            c0a = nengo.Connection(preInptA,
                                   tarEns,
                                   synapse=fPre,
                                   solver=NoSolver(dPreA),
                                   seed=seed)
            c0b = nengo.Connection(preInptB,
                                   tarEns,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed + 1)
            c1a = nengo.Connection(preInptA,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreA),
                                   seed=seed)
            c1b = nengo.Connection(preInptB,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed + 1)
            learnEncoders(c1a, tarEns, fS, alpha=alpha, eMax=eMax)
        if stage == 3:
            # Stage 3: both inputs drive ens with fixed decoders (data collection).
            c1a = nengo.Connection(preInptA,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreA),
                                   seed=seed)
            c1b = nengo.Connection(preInptB,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed + 1)
        if stage == 4:
            # Stage 4: learn the encoders of ens2 -> ens, using ens3 as target.
            preInptC = nengo.Ensemble(NPre, 1, max_rates=m, seed=seed)
            ens2 = nengo.Ensemble(N,
                                  1,
                                  max_rates=m,
                                  intercepts=i,
                                  neuron_type=neuron_type,
                                  seed=seed)
            ens3 = nengo.Ensemble(N,
                                  1,
                                  max_rates=m,
                                  intercepts=i,
                                  neuron_type=neuron_type,
                                  seed=seed)
            # preInptC sees inptB filtered by fEns.
            nengo.Connection(inptB, preInptC, synapse=fEns, seed=seed)
            c1a = nengo.Connection(preInptA,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreA),
                                   seed=seed)
            c2b = nengo.Connection(preInptB,
                                   ens2,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed + 1)
            c3 = nengo.Connection(ens2,
                                  ens,
                                  synapse=fEns,
                                  solver=NoSolver(dEns),
                                  seed=seed)
            c4a = nengo.Connection(preInptA,
                                   ens3,
                                   synapse=fPre,
                                   solver=NoSolver(dPreA),
                                   seed=seed)
            c4b = nengo.Connection(preInptC,
                                   ens3,
                                   synapse=fPre,
                                   solver=NoSolver(dPreB),
                                   seed=seed)
            learnEncoders(c3, ens3, fS, alpha=alpha / 10, eMax=eMax)
            # Report ens3 as the "tarEns" data for this stage.
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
        if stage == 5:
            # Stage 5: test the recurrent network driven by preInptA only.
            c1a = nengo.Connection(preInptA,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreA),
                                   seed=seed)
            c5 = nengo.Connection(ens,
                                  ens,
                                  synapse=fEns,
                                  solver=NoSolver(dEns),
                                  seed=seed)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        if isinstance(neuron_type, Bio):
            # Install the decoders/encoders of every connection into `ens`
            # (and ens2/ens3 in stage 4) before the simulator is initialised.
            if stage == 1:
                setWeights(c1b, dPreB, ePreB)
            if stage == 2:
                setWeights(c1a, dPreA, ePreA)
                setWeights(c1b, dPreB, ePreB)
            if stage == 3:
                setWeights(c1a, dPreA, ePreA)
                setWeights(c1b, dPreB, ePreB)
            if stage == 4:
                setWeights(c1a, dPreA, ePreA)
                setWeights(c2b, dPreB, ePreB)
                setWeights(c4a, dPreA, ePreA)
                setWeights(c4b, dPreB, ePreB)
                setWeights(c3, dEns, eBio)
            if stage == 5:
                setWeights(c1a, dPreA, ePreA)
                setWeights(c5, dEns, eBio)
            neuron.h.init()
            sim.run(t, progress_bar=True)
            reset_neuron(sim, model)
        else:
            sim.run(t, progress_bar=True)

    # Return the learned encoders from the stage that learned them.
    ePreB = c1b.e if stage == 1 else ePreB
    ePreA = c1a.e if stage == 2 else ePreA
    eBio = c3.e if stage == 4 else eBio

    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        preInptA=sim.data[pPreInptA],
        preInptB=sim.data[pPreInptB],
        ens=sim.data[pEns],
        tarEns=sim.data[pTarEns],
        ePreA=ePreA,
        ePreB=ePreB,
        eBio=eBio,
    )
Exemplo n.º 9
0
def run(NPre=100,
        N=30,
        t=10,
        nTrain=5,
        nEnc=5,
        nTest=10,
        dt=0.001,
        neuron_type=LIF(),
        fPre=DoubleExp(1e-3, 1e-1),
        fS=DoubleExp(2e-2, 2e-1),
        Tff=0.3,
        reg=1e-1,
        tauRiseMax=1e-1,
        tauFallMax=3e-1,
        load=(),
        file="data/integrate"):
    """Train the integrator network stage by stage, then test it.

    A stage is restored from ``file`` when its index is in ``load``. Otherwise
    it is computed by simulating ``go`` and the parameters so far are saved
    back to ``file``:

    0. readout decoders ``dPreA``/``dPreB`` of the pre populations (these
       simulations always use a 1 ms step);
    1. encoders ``ePreB`` for preInptB -> ens;
    2. encoders ``ePreA`` for preInptA -> ens;
    3. decoders ``dEns`` and synapse ``fEns`` of ens (via ``decode``);
    4. encoders ``eBio`` for the recurrent connection.

    Stages 1, 2 and 4 are only trained for ``Bio`` neurons. Other neuron types
    use zero encoders unless the stage is loaded. Finally ``nTest`` test runs
    of the recurrent network are simulated and each is plotted to ``plots/``.

    Parameters
    ----------
    load : container of int
        Stage indices to restore from ``file`` instead of recomputing.
    file : str
        Path prefix; ``f"{neuron_type}.npz"`` is appended.
    Tff : float
        Scale applied to input A in the stage-0 target.

    Returns
    -------
    times, X, Y
        Time points, decoded estimate and target of the last test run.

    Raises
    ------
    ValueError
        If ``nTest < 1``, because there is no test run to return. Training
        and saving still happen first.
    """

    print('\nNeuron Type: %s' % neuron_type)
    file = file + f"{neuron_type}.npz"

    if 0 in load:
        with np.load(file) as saved:
            dPreA = saved['dPreA']  # fix indexing of decoders
            dPreB = saved['dPreB']
    else:
        print('readout decoders for preInptA and preInptB')
        spikesInptA = np.zeros((nTrain, int(t / 0.001), NPre))
        spikesInptB = np.zeros((nTrain, int(t / 0.001), NPre))
        targetsInptA = np.zeros((nTrain, int(t / 0.001), 1))
        targetsInptB = np.zeros((nTrain, int(t / 0.001), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, dt=0.001, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=0.001,
                      neuron_type=neuron_type,
                      fPre=fPre,
                      fS=fS,
                      stimA=stimA,
                      stimB=stimB)
            spikesInptA[n] = data['preInptA']
            spikesInptB[n] = data['preInptB']
            targetsInptA[n] = fPre.filt(Tff * data['inptA'], dt=0.001)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=0.001)
        dPreA, X1a, Y1a, error1a = decodeNoF(spikesInptA,
                                             targetsInptA,
                                             nTrain,
                                             fPre,
                                             dt=0.001,
                                             reg=reg)
        dPreB, X1b, Y1b, error1b = decodeNoF(spikesInptB,
                                             targetsInptB,
                                             nTrain,
                                             fPre,
                                             dt=0.001,
                                             reg=reg)
        np.savez(file, dPreA=dPreA, dPreB=dPreB)
        # Derive the time axis from the decoded length. A float-step
        # np.arange(0, t * nTrain, 0.001) can come out one sample too long.
        times = np.arange(len(X1a)) * 0.001
        plotState(times, X1a, Y1a, error1a, "integrate",
                  "%s_preInptA" % neuron_type, t * nTrain)
        plotState(times, X1b, Y1b, error1b, "integrate",
                  "%s_preInptB" % neuron_type, t * nTrain)

    if 1 in load:
        with np.load(file) as saved:
            ePreB = saved['ePreB']
    elif isinstance(neuron_type, Bio):
        print("encoders for preInptB-to-ens")
        ePreB = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA,
                      dPreB=dPreB,
                      ePreB=ePreB,
                      NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      neuron_type=neuron_type,
                      fPre=fPre,
                      fS=fS,
                      stimA=stimA,
                      stimB=stimB,
                      stage=1)
            # Encoders learned in this run seed the next one.
            ePreB = data['ePreB']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrate", "preIntgToEns")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreB=ePreB)
    else:
        ePreB = np.zeros((NPre, N, 1))

    if 2 in load:
        with np.load(file) as saved:
            ePreA = saved['ePreA']
    elif isinstance(neuron_type, Bio):
        print("encoders for preInptA-to-ens")
        ePreA = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA,
                      dPreB=dPreB,
                      ePreA=ePreA,
                      ePreB=ePreB,
                      fPre=fPre,
                      NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      neuron_type=neuron_type,
                      fS=fS,
                      stimA=stimA,
                      stimB=stimB,
                      stage=2)
            ePreA = data['ePreA']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrate", "preInptToEns")
            np.savez(file, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB)
    else:
        ePreA = np.zeros((NPre, N, 1))

    if 3 in load:
        with np.load(file) as saved:
            dEns = saved['dEns']
            tauRiseEns = saved['tauRiseEns']
            tauFallEns = saved['tauFallEns']
        fEns = DoubleExp(tauRiseEns, tauFallEns)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA,
                      dPreB=dPreB,
                      ePreA=ePreA,
                      ePreB=ePreB,
                      NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      neuron_type=neuron_type,
                      fPre=fPre,
                      fS=fS,
                      stimA=stimA,
                      stimB=stimB,
                      stage=3)
            spikes[n] = data['ens']
            targets[n] = fPre.filt(data['inptB'], dt=dt)
        dEns, fEns, tauRiseEns, tauFallEns, X2, Y2, error2 = decode(
            spikes,
            targets,
            nTrain,
            dt=dt,
            reg=reg,
            tauRiseMax=tauRiseMax,
            tauFallMax=tauFallMax,
            name="integrate")
        np.savez(file,
                 dPreA=dPreA,
                 dPreB=dPreB,
                 ePreA=ePreA,
                 ePreB=ePreB,
                 dEns=dEns,
                 tauRiseEns=tauRiseEns,
                 tauFallEns=tauFallEns)
        times = np.arange(len(X2)) * dt
        plotState(times, X2, Y2, error2, "integrate", "%s_ens" % neuron_type,
                  t * nTrain)

    if 4 in load:
        with np.load(file) as saved:
            eBio = saved['eBio']
    elif isinstance(neuron_type, Bio):
        print("encoders from ens2 to ens")
        eBio = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, dt=dt, seed=n)
            data = go(dPreA=dPreA,
                      dPreB=dPreB,
                      dEns=dEns,
                      ePreA=ePreA,
                      ePreB=ePreB,
                      eBio=eBio,
                      NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      neuron_type=neuron_type,
                      fPre=fPre,
                      fEns=fEns,
                      fS=fS,
                      stimA=stimA,
                      stimB=stimB,
                      stage=4)
            eBio = data['eBio']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrate", "Ens2Ens")
            np.savez(file,
                     dPreA=dPreA,
                     dPreB=dPreB,
                     ePreA=ePreA,
                     ePreB=ePreB,
                     dEns=dEns,
                     tauRiseEns=tauRiseEns,
                     tauFallEns=tauFallEns,
                     eBio=eBio)
    else:
        eBio = np.zeros((N, N, 1))

    # Without at least one test run there is nothing to return; fail clearly
    # instead of with an UnboundLocalError at the return statement.
    if nTest < 1:
        raise ValueError("nTest must be at least 1, got %r" % (nTest,))

    print("testing")
    errors = np.zeros(nTest)
    for test in range(nTest):
        stimA, stimB = makeSignal(t, fPre, dt=dt, seed=200 + test)
        data = go(dPreA=dPreA,
                  dEns=dEns,
                  ePreA=ePreA,
                  eBio=eBio,
                  NPre=NPre,
                  N=N,
                  t=t,
                  dt=dt,
                  neuron_type=neuron_type,
                  fPre=fPre,
                  fEns=fEns,
                  fS=fS,
                  stimA=stimA,
                  stimB=stimB,
                  stage=5)
        A = fEns.filt(data['ens'], dt=dt)
        X = np.dot(A, dEns)
        # The network is scored against filtered input B.
        Y = fPre.filt(data['inptB'], dt=dt)
        error = rmse(X, Y)
        errors[test] = error
        fig, ax = plt.subplots()
        ax.plot(data['times'], X, label="estimate")
        ax.plot(data['times'], Y, label="target")
        ax.set(xlabel='time',
               ylabel='state',
               title='rmse=%.3f' % error,
               xlim=((0, t)),
               ylim=((-1, 1)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrate_%s_test%s.pdf" % (neuron_type, test))
        plt.close('all')
    # print('errors:', errors)
    # np.savez(file, errors=errors, dPreA=dPreA, dPreB=dPreB, ePreA=ePreA, ePreB=ePreB, dEns=dEns, tauRiseEns=tauRiseEns, tauFallEns=tauFallEns, eBio=eBio)

    return data['times'], X, Y
Exemplo n.º 10
0
def go(NPre=100, N=30, t=10, m=Uniform(30, 30), i=Uniform(-0.8, 0.8), seed=0, dt=0.001, f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1), neuron_type=LIF(), d1=None, d2=None, f1=None, f2=None, e1=None, e2=None, l1=False, l2=False, test=False, freq=1, phase=0, tDrive=0.2):
    """Simulate a 2-D oscillator network in training or test configuration.

    The input is ``[sin(2*pi*freq*t + phase), cos(2*pi*freq*t + phase)]``.
    ``tar`` is a direct-mode ensemble receiving the input through the 2x2
    matrix ``A``. ``pre`` (SpikingRectifiedLinear, radius 2) drives ``ens``
    with decoders ``d1`` and synapse ``f1``. For ``Bio`` neurons, ``f1`` and
    ``f2`` default to ``DoubleExp(1e-3, 1e-1)`` when falsy.

    * ``l1``: learn encoders of pre -> ens against a LIF ``tarEns`` that
      receives the input directly.
    * ``l2``: learn encoders of ens -> ens2 (synapse ``f2``, alpha=3e-7)
      against a LIF ``tarEns2`` driven through ``pre2``.
    * ``test``: add the recurrent ens -> ens connection (``d2``, ``f2``) and
      strongly inhibit ``pre``'s neurons once t > ``tDrive``.

    Returns
    -------
    dict
        ``times``, unfiltered probe data (``inpt``, ``tar``, ``pre``,
        ``ens``, plus ``tarEns``/``tarEns2``/``ens2`` or None when the
        corresponding learning flag is off), and encoders ``e1``/``e2``. The
        encoders are the learned ``c1.e``/``c3.e`` when ``l1``/``l2`` is set.
    """

    A = [[1, 1e-1*2*np.pi*freq], [-1e-1*2*np.pi*freq, 1]]  # tau*A + I
    if isinstance(neuron_type, Bio) and not f1: f1=DoubleExp(1e-3, 1e-1)
    if isinstance(neuron_type, Bio) and not f2: f2=DoubleExp(1e-3, 1e-1)
    stim = lambda t: [np.sin(2*np.pi*freq*t+phase), np.cos(2*np.pi*freq*t+phase)]

    with nengo.Network(seed=seed) as model:          
        inpt = nengo.Node(stim)
        tar = nengo.Ensemble(1, 2, neuron_type=nengo.Direct())
        pre = nengo.Ensemble(NPre, 2, max_rates=m, neuron_type=nengo.SpikingRectifiedLinear(), radius=2, seed=seed)
        ens = nengo.Ensemble(N, 2, max_rates=m, intercepts=i, neuron_type=neuron_type, radius=2, seed=seed)
        nengo.Connection(inpt, tar, synapse=None, transform=A, seed=seed)
        nengo.Connection(inpt, pre, synapse=None, seed=seed)
        c1 = nengo.Connection(pre, ens, synapse=f1, seed=seed, solver=NoSolver(d1))
        pInpt = nengo.Probe(inpt, synapse=None)
        pTar = nengo.Probe(tar, synapse=None)
        pPre = nengo.Probe(pre.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        # Encoder Learning (Bio)
        if l1:
            tarEns = nengo.Ensemble(N, 2, max_rates=m, intercepts=i, neuron_type=nengo.LIF(), seed=seed)
            nengo.Connection(inpt, tarEns, synapse=None, seed=seed)
            learnEncoders(c1, tarEns, fS)
            pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        if l2:
            pre2 = nengo.Ensemble(NPre, 2, max_rates=m, neuron_type=nengo.LIF(), seed=seed, radius=2)
            tarEns2 = nengo.Ensemble(N, 2, max_rates=m, intercepts=i, neuron_type=nengo.LIF(), seed=seed)
            ens2 = nengo.Ensemble(N, 2, max_rates=m, intercepts=i, neuron_type=neuron_type, seed=seed, radius=2)
            
#             ens3 = nengo.Ensemble(N, 2, max_rates=m, intercepts=i, neuron_type=neuron_type, seed=seed, radius=2)
#             nengo.Connection(tar, pre2, synapse=f)
#             c3 = nengo.Connection(ens, ens2, synapse=f2, seed=seed)
#             c4 = nengo.Connection(pre2, ens3, synapse=f1, seed=seed)
#             learnEncoders(c3, ens3, fS)
#             pTarEns2 = nengo.Probe(ens3.neurons, synapse=None)
#             pEns2 = nengo.Probe(ens2.neurons, synapse=None)

            # The target for ens2 is a LIF population fed by the input via pre2.
            nengo.Connection(inpt, pre2, synapse=f)
            nengo.Connection(pre2, tarEns2, synapse=f, seed=seed)
            # No solver is given for c3; NOTE(review): for Bio neurons its
            # weights are installed with setWeights(c3, d2, e2) below.
            c3 = nengo.Connection(ens, ens2, synapse=f2, seed=seed)
            learnEncoders(c3, tarEns2, fS, alpha=3e-7)
            pTarEns2 = nengo.Probe(tarEns2.neurons, synapse=None)
            pEns2 = nengo.Probe(ens2.neurons, synapse=None)
        if test:
            c2 = nengo.Connection(ens, ens, synapse=f2, seed=seed, solver=NoSolver(d2))
            # After tDrive a large negative current silences pre, so ens must
            # sustain the oscillation through its recurrent connection alone.
            off = nengo.Node(lambda t: 1 if t>tDrive else 0)
            nengo.Connection(off, pre.neurons, synapse=None, transform=-1e4*np.ones((NPre, 1)))

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        if isinstance(neuron_type, Bio):
            setWeights(c1, d1, e1)
            if l2: setWeights(c3, d2, e2)
#             if l2: setWeights(c4, d1, e1)
            if test: setWeights(c2, d2, e2)
            neuron.h.init()
            sim.run(t, progress_bar=True)
            reset_neuron(sim, model) 
        else:
            sim.run(t, progress_bar=True)
      
    # Keep learned encoders only for the connections that were trained.
    e1 = c1.e if l1 else e1
    e2 = c3.e if l2 else e2

    return dict(
        times=sim.trange(),
        inpt=sim.data[pInpt],
        tar=sim.data[pTar],
        pre=sim.data[pPre],
        ens=sim.data[pEns],
        tarEns=sim.data[pTarEns] if l1 else None,
        tarEns2=sim.data[pTarEns2] if l2 else None,
        ens2=sim.data[pEns2] if l2 else None,
        e1=e1,
        e2=e2,
    )
Exemplo n.º 11
0
def run(NPre=200,
        N=100,
        N2=100,
        t=10,
        nTrain=10,
        nEnc=20,
        nTest=10,
        neuron_type=LIF(),
        dt=0.001,
        f=DoubleExp(1e-3, 1e-1),
        fS=DoubleExp(1e-3, 1e-1),
        tauRiseMax=1e-2,
        load=False,
        file=None):
    """Train (or load) and then test a three-population multiplication network.

    Each stage is either loaded from ``file`` (when ``load`` is true) or
    trained here by repeatedly simulating ``go`` on ``makeSignal`` stimuli:

      1. ``d1``/``f1``: decoders and filter for ``pre`` (always trained with
         LIF neurons at dt=0.001, regardless of ``neuron_type``/``dt``).
      2. ``e1``: pre -> ens encoders, learned over ``nEnc`` runs only when
         ``neuron_type`` is a ``Bio`` instance (zeros otherwise).
      3. ``d2``/``f2``: decoders and filter reading the product of the two
         target dimensions out of ``ens``.
      4. ``e2``: ens -> ens2 encoders, learned only for ``Bio`` neurons.
      5. ``d3``/``f3``: decoders and filter for ``ens2`` (target ``tar2``).

    After every trained stage, all parameters obtained so far are written to
    ``data/multiplyNew_<neuron_type>.npz``; ``load=True`` expects a file that
    contains every stage's keys. Finally ``nTest`` stimuli with seeds
    ``100 + k`` are simulated and the ens2 readout is scored with RMSE.

    Returns:
        numpy array of length ``nTest`` holding each test run's RMSE.
    """

    def _checkpoint(**arrays):
        # Each checkpoint overwrites the same file with everything so far.
        np.savez("data/multiplyNew_%s.npz" % neuron_type, **arrays)

    print('\nNeuron Type: %s' % neuron_type)

    # Stage 1: readout decoders/filter for pre.
    if load:
        # Open the archive once and close it promptly, rather than
        # re-opening it for every key.
        with np.load(file) as saved:
            d1 = saved['d1']
            tauRise1 = saved['tauRise1']
            tauFall1 = saved['tauFall1']
        f1 = DoubleExp(tauRise1, tauFall1)
    else:
        print('readout decoders for pre')
        spikes = np.zeros((nTrain, int(t / 0.001), NPre))
        targets = np.zeros((nTrain, int(t / 0.001), 2))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=0.001, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      N2=N2,
                      t=t,
                      dt=0.001,
                      f=f,
                      fS=fS,
                      neuron_type=LIF(),
                      stim=stim)
            spikes[n] = data['pre']
            targets[n] = f.filt(data['inpt'], dt=0.001)
        d1, f1, tauRise1, tauFall1, X, Y, error = decode(spikes,
                                                         targets,
                                                         nTrain,
                                                         dt=0.001,
                                                         tauRiseMax=tauRiseMax,
                                                         name="multiplyNew")
        _checkpoint(d1=d1, tauRise1=tauRise1, tauFall1=tauFall1)
        times = np.arange(0, t * nTrain, 0.001)
        plotState(times, X, Y, error, "multiplyNew", "%s_pre" % neuron_type,
                  t * nTrain)

    # Stage 2: pre -> ens encoders (only learned for Bio neurons).
    if load:
        with np.load(file) as saved:
            e1 = saved['e1']
    elif isinstance(neuron_type, Bio):
        print("ens1 encoders")
        e1 = np.zeros((NPre, N, 2))
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1,
                      e1=e1,
                      f1=f1,
                      NPre=NPre,
                      N=N,
                      N2=N2,
                      t=t,
                      dt=dt,
                      f=f,
                      fS=fS,
                      neuron_type=neuron_type,
                      stim=stim,
                      l1=True)
            # Carry the updated encoders into the next training run.
            e1 = data['e1']
            _checkpoint(d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1)
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "multiplyNew", "ens")
    else:
        e1 = np.zeros((NPre, N, 2))

    # Stage 3: decoders/filter that read tar[:, 0] * tar[:, 1] out of ens.
    if load:
        with np.load(file) as saved:
            d2 = saved['d2']
            tauRise2 = saved['tauRise2']
            tauFall2 = saved['tauFall2']
        f2 = DoubleExp(tauRise2, tauFall2)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1,
                      e1=e1,
                      f1=f1,
                      NPre=NPre,
                      N=N,
                      N2=N2,
                      t=t,
                      dt=dt,
                      f=f,
                      fS=fS,
                      neuron_type=neuron_type,
                      stim=stim)
            spikes[n] = data['ens']
            targets[n] = f.filt(data['tar'][:, 0] * data['tar'][:, 1],
                                dt=dt).reshape(-1, 1)
        d2, f2, tauRise2, tauFall2, X, Y, error = decode(spikes,
                                                         targets,
                                                         nTrain,
                                                         dt=dt,
                                                         tauRiseMax=tauRiseMax,
                                                         name="multiplyNew")
        _checkpoint(d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1,
                    d2=d2, tauRise2=tauRise2, tauFall2=tauFall2)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "multiplyNew", "%s_ens" % neuron_type,
                  t * nTrain)

    # Stage 4: ens -> ens2 encoders (only learned for Bio neurons).
    if load:
        with np.load(file) as saved:
            e2 = saved['e2']
    elif isinstance(neuron_type, Bio):
        print("ens2 encoders")
        e2 = np.zeros((N, N2, 1))
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1,
                      e1=e1,
                      f1=f1,
                      d2=d2,
                      e2=e2,
                      f2=f2,
                      NPre=NPre,
                      N=N,
                      N2=N2,
                      t=t,
                      dt=dt,
                      f=f,
                      fS=fS,
                      neuron_type=neuron_type,
                      stim=stim,
                      l2=True)
            e2 = data['e2']
            _checkpoint(d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1,
                        d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)
            plotActivity(t, dt, fS, data['times'], data['ens2'],
                         data['tarEns2'], "multiplyNew", "ens2")
    else:
        e2 = np.zeros((N, N2, 1))

    # Stage 5: readout decoders/filter for ens2.
    if load:
        with np.load(file) as saved:
            d3 = saved['d3']
            tauRise3 = saved['tauRise3']
            tauFall3 = saved['tauFall3']
        f3 = DoubleExp(tauRise3, tauFall3)
    else:
        print('readout decoders for ens2')
        spikes = np.zeros((nTrain, int(t / dt), N2))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1=d1,
                      e1=e1,
                      f1=f1,
                      d2=d2,
                      e2=e2,
                      f2=f2,
                      NPre=NPre,
                      N=N,
                      N2=N2,
                      t=t,
                      dt=dt,
                      f=f,
                      fS=fS,
                      neuron_type=neuron_type,
                      stim=stim)
            spikes[n] = data['ens2']
            targets[n] = f.filt(data['tar2'], dt=dt)
        # NOTE(review): unlike stages 1 and 3, tauRiseMax is not passed here,
        # so decode's own default applies — confirm this is intended.
        d3, f3, tauRise3, tauFall3, X, Y, error = decode(spikes,
                                                         targets,
                                                         nTrain,
                                                         dt=dt,
                                                         name="multiplyNew")
        _checkpoint(d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1,
                    d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2,
                    d3=d3, tauRise3=tauRise3, tauFall3=tauFall3)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "multiplyNew", "%s_ens2" % neuron_type,
                  t * nTrain)

    # Testing: seeds start at 100 so test stimuli differ from training seeds.
    errors = np.zeros((nTest))
    print("testing")
    for test in range(nTest):
        stim = makeSignal(t, f, dt=dt, seed=100 + test)
        data = go(d1=d1,
                  e1=e1,
                  f1=f1,
                  d2=d2,
                  e2=e2,
                  f2=f2,
                  NPre=NPre,
                  N=N,
                  N2=N2,
                  t=t,
                  dt=dt,
                  f=f,
                  fS=fS,
                  neuron_type=neuron_type,
                  stim=stim)
        # Scored readout: ens2 decoded with (f3, d3) against filtered tar2.
        A = f3.filt(data['ens2'], dt=dt)
        X = np.dot(A, d3)
        Y = f.filt(data['tar2'], dt=dt)
        error = rmse(X, Y)
        errors[test] = error
        plotState(data['times'], X, Y, error, "multiplyNew",
                  "%s_test%s" % (neuron_type, test), t)
        # Diagnostic only (not stored in errors): the ens product readout.
        A = f2.filt(data['ens'], dt=dt)
        X = np.dot(A, d2)
        Y = f.filt(data['tar'][:, 0] * data['tar'][:, 1], dt=dt).reshape(-1, 1)
        plotState(data['times'], X, Y, rmse(X, Y), "multiplyNew",
                  "%s_pretest%s" % (neuron_type, test), t)

    print('%s errors:' % neuron_type, errors)
    _checkpoint(d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1,
                d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2,
                d3=d3, tauRise3=tauRise3, tauFall3=tauFall3,
                errors=errors)
    return errors
Exemplo n.º 12
0
def goLIF(N=100,
          t=10,
          dt=0.001,
          stim=lambda t: 0,
          m=Uniform(30, 30),
          i=Uniform(0, 0.8),
          e=Choice([[1]]),
          i2=Uniform(0.4, 1),
          dFdfwEnsAMPA=None,
          dFdfwEnsNMDA=None,
          dEnsEnsAMPA=None,
          dEnsEnsNMDA=None,
          dEnsInhAMPA=None,
          dEnsInhNMDA=None,
          dInhFdfwGABA=None,
          dInhEnsGABA=None,
          fAMPA=DoubleExp(0.55e-3, 2.2e-3),
          fNMDA=DoubleExp(2.3e-3, 95.0e-3),
          fGABA=DoubleExp(0.5e-3, 1.5e-3),
          x0=0,
          seed=0):
    """Simulate a LIF network with feed-forward, recurrent and inhibitory paths.

    Topology (all spiking ensembles are 1-D with ``N`` neurons and ``seed``):
      * ``inpt`` (driven by ``stim``) feeds ``fdfw`` directly and feeds the
        Direct-mode ensemble ``intg`` through the synapse ``1 / s``
        (presumably an ideal integrator — confirm ``s``).
      * ``fdfw -> ens``, ``ens -> ens`` and ``ens -> inh`` each use two
        parallel connections, one filtered by ``fAMPA`` and one by ``fNMDA``.
      * ``inh -> fdfw`` and ``inh -> ens`` are filtered by ``fGABA``.
    Every connection uses the fixed decoders passed in (via ``NoSolver``).

    Args:
        dt: simulator time step, forwarded to ``nengo.Simulator``.
        x0: currently unused (initial-state setup is disabled); kept so
            existing callers keep working.

    Returns:
        dict with ``times``, the raw input ``inpt``, the integral ``intg``,
        and the unfiltered spike trains of ``fdfw``, ``inh`` and ``ens``.
    """

    with nengo.Network(seed=seed) as model:
        inpt = nengo.Node(stim)
        intg = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
        fdfw = nengo.Ensemble(N, 1, max_rates=m, intercepts=i, seed=seed)
        inh = nengo.Ensemble(N,
                             1,
                             encoders=e,
                             max_rates=m,
                             intercepts=i2,
                             neuron_type=LIF(),
                             seed=seed)
        ens = nengo.Ensemble(N,
                             1,
                             encoders=e,
                             max_rates=m,
                             intercepts=i,
                             neuron_type=LIF(),
                             seed=seed)
        nengo.Connection(inpt, intg, synapse=1 / s)
        nengo.Connection(inpt, fdfw, synapse=None)
        # Feed-forward excitation: fast (AMPA) and slow (NMDA) components.
        nengo.Connection(fdfw,
                         ens,
                         synapse=fAMPA,
                         solver=NoSolver(dFdfwEnsAMPA))
        nengo.Connection(fdfw,
                         ens,
                         synapse=fNMDA,
                         solver=NoSolver(dFdfwEnsNMDA))
        # Recurrent excitation within ens.
        nengo.Connection(ens, ens, synapse=fAMPA, solver=NoSolver(dEnsEnsAMPA))
        nengo.Connection(ens, ens, synapse=fNMDA, solver=NoSolver(dEnsEnsNMDA))
        # ens drives the inhibitory population...
        nengo.Connection(ens, inh, synapse=fAMPA, solver=NoSolver(dEnsInhAMPA))
        nengo.Connection(ens, inh, synapse=fNMDA, solver=NoSolver(dEnsInhNMDA))
        # ...which projects back onto both fdfw and ens.
        nengo.Connection(inh,
                         fdfw,
                         synapse=fGABA,
                         solver=NoSolver(dInhFdfwGABA))
        nengo.Connection(inh, ens, synapse=fGABA, solver=NoSolver(dInhEnsGABA))
        pInpt = nengo.Probe(inpt, synapse=None)
        pIntg = nengo.Probe(intg, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pInh = nengo.Probe(inh.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
    # dt was previously accepted but never forwarded, so the simulator always
    # used nengo's default step regardless of the caller's value.
    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        sim.run(t, progress_bar=True)
    return dict(times=sim.trange(),
                inpt=sim.data[pInpt],
                intg=sim.data[pIntg],
                fdfw=sim.data[pFdfw],
                inh=sim.data[pInh],
                ens=sim.data[pEns])
Exemplo n.º 13
0
def go(d_ens,
       f_ens,
       n_neurons=100,
       t=10,
       m=Uniform(30, 40),
       i=Uniform(-1, 0.6),
       seed=0,
       dt=0.001,
       neuron_type=nengo.LIF(),
       f=DoubleExp(1e-2, 2e-1),
       f_smooth=DoubleExp(1e-2, 2e-1),
       freq=1,
       w_ff=None,
       w_fb=None,
       e_ff=None,
       e_fb=None,
       L_ff=False,
       L_fb=False,
       L_fd=False,
       supervised=False):
    """Build and run a 2-D network driven by a sin/cos input at ``freq`` Hz.

    Modes, selected by the flags:
      * ``L_ff``: a LIF supervisor ``supv`` receives the same ``pre`` input,
        and a ``LearningNode2`` (k=3e-6) is attached to the pre -> ens
        connection.
      * ``L_fb`` / ``supervised``: a ``supv`` ensemble (same ``neuron_type``)
        drives ``ens`` through the fixed decoders ``d_ens`` with filter
        ``f_ens``. With ``L_fb``, a ``LearningNode2`` is attached to that
        supv -> ens connection.
      * none of ``L_ff``/``L_fb``/``L_fd``/``supervised``: ``pre`` is
        silenced after t = 1.0 s and ``ens`` runs on its own recurrent
        connection (``d_ens``, ``f_ens``).

    Previously obtained weights (``w_ff``, ``w_fb``) and encoders (``e_ff``,
    ``e_fb``) are written into the connections before NEURON is initialised.

    Returns:
        dict with ``times``, ``u`` (input), ``ens`` (unfiltered spikes),
        ``supv`` (unfiltered spikes, or None when no supervisor exists), and
        ``w_ff``/``w_fb``/``e_ff``/``e_fb``, updated from the learned
        connection when ``L_ff``/``L_fb`` was set.

    Raises:
        ValueError: if ``w_fb`` is given in a mode that builds neither a
            supv -> ens nor an ens -> ens connection.
    """

    def _write_weights(conn, weights, n_pre, n_post):
        # Copy a signed weight matrix into a NEURON-backed connection: each
        # netcon stores |w|, and the sign of w selects the synapse's ``e``
        # (0 for positive weights, -70 otherwise; presumably the reversal
        # potential in mV).
        for pre_idx in range(n_pre):
            for post_idx in range(n_post):
                w_ij = weights[pre_idx, post_idx]
                conn.weights[pre_idx, post_idx] = w_ij
                conn.netcons[pre_idx, post_idx].weight[0] = np.abs(w_ij)
                conn.netcons[pre_idx, post_idx].syn().e = 0 if w_ij > 0 else -70

    w = 2 * np.pi * freq
    # The recurrent ens -> ens connection only exists in autonomous mode.
    ens_ens = None

    with nengo.Network(seed=seed) as model:

        u = nengo.Node(lambda t: [np.sin(w * t), np.cos(w * t)])

        # Ensembles
        pre = nengo.Ensemble(300,
                             2,
                             neuron_type=SpikingRectifiedLinear(),
                             seed=seed,
                             radius=2)
        ens = nengo.Ensemble(n_neurons,
                             2,
                             max_rates=m,
                             intercepts=i,
                             neuron_type=neuron_type,
                             seed=seed,
                             radius=2)
        nengo.Connection(u, pre, synapse=None, seed=seed)
        pre_ens = nengo.Connection(pre, ens, synapse=f, seed=seed)

        if L_ff:
            supv = nengo.Ensemble(n_neurons,
                                  2,
                                  max_rates=m,
                                  intercepts=i,
                                  neuron_type=LIF(),
                                  seed=seed,
                                  radius=2)
            nengo.Connection(pre, supv, synapse=f, seed=seed)
            # Learning-node input slices: pre spikes, ens activity, supv
            # activity, and (last two entries) the input u.
            node = LearningNode2(n_neurons, pre.n_neurons, pre_ens, k=3e-6)
            nengo.Connection(pre.neurons, node[0:pre.n_neurons], synapse=f)
            nengo.Connection(ens.neurons,
                             node[pre.n_neurons:pre.n_neurons + n_neurons],
                             synapse=f_smooth)
            nengo.Connection(supv.neurons,
                             node[pre.n_neurons + n_neurons:pre.n_neurons +
                                  2 * n_neurons],
                             synapse=f_smooth)
            nengo.Connection(u, node[-2:], synapse=f)
            p_supv = nengo.Probe(supv.neurons, synapse=None)

        if L_fb or supervised:
            supv = nengo.Ensemble(n_neurons,
                                  2,
                                  max_rates=m,
                                  intercepts=i,
                                  neuron_type=neuron_type,
                                  seed=seed,
                                  radius=2)
            pre_supv = nengo.Connection(pre, supv, synapse=f, seed=seed)
            supv_ens = nengo.Connection(supv,
                                        ens,
                                        synapse=f_ens,
                                        seed=seed,
                                        solver=NoSolver(d_ens))
            p_supv = nengo.Probe(supv.neurons, synapse=None)

        if L_fb:
            # Learning-node input slices: supv spikes, ens activity, supv
            # activity (smoothed), and (last two entries) the input u.
            node = LearningNode2(n_neurons, n_neurons, supv_ens, k=3e-6)
            nengo.Connection(supv.neurons, node[0:n_neurons], synapse=f_ens)
            nengo.Connection(ens.neurons,
                             node[n_neurons:2 * n_neurons],
                             synapse=f_smooth)
            nengo.Connection(supv.neurons,
                             node[2 * n_neurons:3 * n_neurons],
                             synapse=f_smooth)
            nengo.Connection(u, node[-2:], synapse=f)

        if not L_ff and not L_fb and not L_fd and not supervised:
            # Autonomous mode: switch pre off after 1 s so ens must sustain
            # itself through its recurrent connection.
            off = nengo.Node(lambda t: (t > 1.0))
            nengo.Connection(off,
                             pre.neurons,
                             synapse=None,
                             transform=-1e3 * np.ones((pre.n_neurons, 1)))
            ens_ens = nengo.Connection(ens,
                                       ens,
                                       synapse=f_ens,
                                       seed=seed,
                                       solver=NoSolver(d_ens))

        # Probes
        p_u = nengo.Probe(u, synapse=None)
        p_ens = nengo.Probe(ens.neurons, synapse=None)

    with nengo.Simulator(model, seed=seed, dt=dt) as sim:
        if np.any(w_ff):
            # With a supervisor, the feed-forward weights belong to pre -> supv.
            ff_conn = pre_supv if (L_fb or supervised) else pre_ens
            _write_weights(ff_conn, w_ff, pre.n_neurons, n_neurons)
        if np.any(e_ff) and L_ff:
            pre_ens.e = e_ff
        if np.any(w_fb):
            if L_fb or supervised:
                fb_conn = supv_ens
            elif ens_ens is not None:
                fb_conn = ens_ens
            else:
                raise ValueError(
                    "w_fb was given, but this mode builds neither a "
                    "supv -> ens nor an ens -> ens connection")
            _write_weights(fb_conn, w_fb, n_neurons, n_neurons)
        if np.any(e_fb) and L_fb:
            supv_ens.e = e_fb

        neuron.h.init()
        sim.run(t)
        reset_neuron(sim, model)

    # Report the learned parameters back to the caller.
    if L_ff and hasattr(pre_ens, 'weights'):
        w_ff = pre_ens.weights
        e_ff = pre_ens.e
    if L_fb and hasattr(supv_ens, 'weights'):
        w_fb = supv_ens.weights
        e_fb = supv_ens.e

    return dict(
        times=sim.trange(),
        u=sim.data[p_u],
        ens=sim.data[p_ens],
        supv=sim.data[p_supv] if L_ff or L_fb or supervised else None,
        w_ff=w_ff,
        w_fb=w_fb,
        e_ff=e_ff,
        e_fb=e_fb,
    )
Exemplo n.º 14
0
def go(NPre=100,
       N=100,
       t=10,
       m=Uniform(30, 30),
       i=Uniform(-0.8, 0.8),
       seed=0,
       dt=0.001,
       f=DoubleExp(1e-3, 1e-2),
       fS=DoubleExp(1e-3, 1e-1),
       neuron_type=LIF(),
       d1a=None,
       d1b=None,
       d2=None,
       f1a=None,
       f1b=None,
       f2=None,
       e1a=None,
       e1b=None,
       e2=None,
       l1a=False,
       l1b=False,
       l2=False,
       l3=False,
       test=False,
       stim=lambda t: np.sin(t),
       stim2=lambda t: 0):
    """Build and simulate the integrator network for one training/test stage.

    ``inpt`` (from ``stim``) drives ``preInpt``. The Direct-mode ensemble
    ``intg`` receives ``inpt`` through the synapse ``1 / s`` (presumably an
    ideal integrator — confirm ``s``) and drives ``preIntg``. ``preInpt``
    always projects to ``ens`` (decoders ``d1a``, filter ``f1a``). The flags
    add further connections:

      * ``l1b``: preIntg -> ens, trained with ``learnEncoders`` against a LIF
        target ensemble ``tarEns``.
      * ``l1a``: preInpt -> ens trained against a LIF ``tarEns``, with
        preIntg fed by ``stim2`` instead of the integral.
      * ``l2``: preIntg -> ens without learning (used for ens readout).
      * ``l3``: ens2 -> ens (decoders ``d2``, filter ``f2``), where ``ens2``
        receives both pre populations; trained with ``learnEncoders``.
      * ``test``: recurrent ens -> ens (decoders ``d2``, filter ``f2``).

    For ``Bio`` neurons, ``setWeights`` writes the given decoders/encoders
    into the relevant connections and NEURON is initialised before running.

    Returns:
        dict of times, ``inpt``/``intg`` signals, unfiltered spike probes of
        preInpt/preIntg/ens, target spikes ``tarEns``/``tarEns2`` (or None),
        and encoders e1a/e1b/e2. Each encoder is read back from the trained
        connection's ``.e`` when its stage learned it in this run; otherwise
        the passed-in value is returned.
    """

    with nengo.Network(seed=seed) as model:
        inpt = nengo.Node(stim)
        intg = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
        preInpt = nengo.Ensemble(NPre, 1, radius=3, max_rates=m, seed=seed)
        preIntg = nengo.Ensemble(NPre, 1, max_rates=m, seed=seed)
        ens = nengo.Ensemble(N,
                             1,
                             max_rates=m,
                             intercepts=i,
                             neuron_type=neuron_type,
                             seed=seed)
        nengo.Connection(inpt, intg, synapse=1 / s, seed=seed)
        c0a = nengo.Connection(inpt, preInpt, synapse=None, seed=seed)
        c0b = nengo.Connection(intg, preIntg, synapse=None, seed=seed)
        # preInpt -> ens is present in every mode.
        c1a = nengo.Connection(preInpt,
                               ens,
                               synapse=f1a,
                               solver=NoSolver(d1a),
                               seed=seed)
        pInpt = nengo.Probe(inpt, synapse=None)
        pIntg = nengo.Probe(intg, synapse=None)
        pPreInpt = nengo.Probe(preInpt.neurons, synapse=None)
        pPreIntg = nengo.Probe(preIntg.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        if l1b:  # preIntg-to-ens
            # LIF target that receives the same preIntg input as ens.
            tarEns = nengo.Ensemble(N,
                                    1,
                                    max_rates=m,
                                    intercepts=i,
                                    neuron_type=nengo.LIF(),
                                    seed=seed)
            nengo.Connection(preIntg,
                             tarEns,
                             synapse=f1b,
                             solver=NoSolver(d1b),
                             seed=seed + 1)
            c1b = nengo.Connection(preIntg,
                                   ens,
                                   synapse=f1b,
                                   solver=NoSolver(d1b),
                                   seed=seed + 1)
            learnEncoders(c1b, tarEns, fS)
            pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        if l1a:  # preInpt-to-ens, given preIntg-to-ens
            inpt2 = nengo.Node(stim2)
            # LIF target that receives both pre populations, like ens.
            tarEns = nengo.Ensemble(N,
                                    1,
                                    max_rates=m,
                                    intercepts=i,
                                    neuron_type=nengo.LIF(),
                                    seed=seed)
            nengo.Connection(preInpt,
                             tarEns,
                             synapse=f1a,
                             solver=NoSolver(d1a),
                             seed=seed)
            # Cut the integral's drive to preIntg; stim2 drives it instead.
            c0b.transform = 0
            nengo.Connection(inpt2, preIntg, synapse=None, seed=seed)
            nengo.Connection(preIntg,
                             tarEns,
                             synapse=f1b,
                             solver=NoSolver(d1b),
                             seed=seed + 1)
            c1b = nengo.Connection(preIntg,
                                   ens,
                                   synapse=f1b,
                                   solver=NoSolver(d1b),
                                   seed=seed + 1)
            learnEncoders(c1a, tarEns, fS)
            pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        if l2:  # ens readout, given preIntg and preInpt
            c1b = nengo.Connection(preIntg,
                                   ens,
                                   synapse=f1b,
                                   solver=NoSolver(d1b),
                                   seed=seed + 1)
        if l3:  # ens2-to-ens, given preInpt-ens and preIntg-ens2
            ens2 = nengo.Ensemble(N,
                                  1,
                                  max_rates=m,
                                  intercepts=i,
                                  neuron_type=neuron_type,
                                  seed=seed)
            # In this mode the input to preInpt is filtered by f.
            c0a.synapse = f
            c2a = nengo.Connection(preInpt,
                                   ens2,
                                   synapse=f1a,
                                   solver=NoSolver(d1a),
                                   seed=seed)
            c2b = nengo.Connection(preIntg,
                                   ens2,
                                   synapse=f1b,
                                   solver=NoSolver(d1b),
                                   seed=seed + 1)
            c3 = nengo.Connection(ens2,
                                  ens,
                                  synapse=f2,
                                  solver=NoSolver(d2),
                                  seed=seed)
            # NOTE(review): the learning target passed here is ens2, which is
            # also c3's presynaptic population — confirm this is intended.
            learnEncoders(c3, ens2, fS)
            pTarEns2 = nengo.Probe(ens2.neurons, synapse=None)
        if test:
            # Recurrent ens -> ens; preIntg -> ens is not built in this mode.
            c5 = nengo.Connection(ens,
                                  ens,
                                  synapse=f2,
                                  solver=NoSolver(d2),
                                  seed=seed)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        if isinstance(neuron_type, Bio):
            # Write the stored decoders/encoders into each connection that
            # exists in the selected mode(s).
            if l1b:
                setWeights(c1b, d1b, e1b)
            if l1a:
                setWeights(c1a, d1a, e1a)
                setWeights(c1b, d1b, e1b)
            if l2:
                setWeights(c1a, d1a, e1a)
                setWeights(c1b, d1b, e1b)
            if l3:
                setWeights(c1a, d1a, e1a)
                setWeights(c2a, d1a, e1a)
                setWeights(c2b, d1b, e1b)
                setWeights(c3, d2, e2)
            if test:
                setWeights(c1a, d1a, e1a)
                setWeights(c5, d2, e2)
            neuron.h.init()
            sim.run(t, progress_bar=True)
            reset_neuron(sim, model)
        else:
            sim.run(t, progress_bar=True)

    # Read back encoders from the connection trained in this run, if any.
    e1a = c1a.e if l1a else e1a
    e1b = c1b.e if l1b else e1b
    e2 = c3.e if l3 else e2

    return dict(
        times=sim.trange(),
        inpt=sim.data[pInpt],
        intg=sim.data[pIntg],
        preInpt=sim.data[pPreInpt],
        preIntg=sim.data[pPreIntg],
        ens=sim.data[pEns],
        tarEns=sim.data[pTarEns] if l1a or l1b else None,
        tarEns2=sim.data[pTarEns2] if l3 else None,
        e1a=e1a,
        e1b=e1b,
        e2=e2,
    )
Exemplo n.º 15
0
def run(NPre=100,
        N=100,
        t=10,
        nTrain=10,
        nTest=30,
        nEnc=10,
        neuron_type=LIF(),
        dt=0.001,
        f=DoubleExp(1e-3, 1e-1),
        fS=DoubleExp(1e-3, 1e-1),
        Tff=0.1,
        Tfb=1.0,
        reg=1e-3,
        tauRiseMax=5e-2,
        tauFallMax=3e-1,
        load=False,
        file=None):
    """Train (or load) and then test the "integrateNew" network.

    Four stages are either read from ``file`` (``load=True``) or learned by
    simulating ``go``:

    1. readout decoders/filters for preInpt (``d1a``/``f1a``) and preIntg
       (``d1b``/``f1b``);
    2. encoders ``e1b`` (preIntg -> ens) then ``e1a`` (preInpt -> ens);
       learned only when ``neuron_type`` is a ``Bio`` instance, zeros
       otherwise;
    3. readout decoders/filter for ens (``d2``/``f2``);
    4. encoders ``e2`` (ens2 -> ens); learned only for ``Bio``, zeros
       otherwise.

    When learning, everything obtained so far is written to
    ``data/integrateNew_<neuron_type>.npz`` after each stage (and after each
    encoder-learning iteration). The network is then run on ``nTest`` test
    signals (seeds 200, 201, ...) and the errors are saved to the same file.

    Args:
        NPre, N: sizes of the pre populations and of ``ens``.
        t: seconds simulated per signal.
        nTrain, nTest, nEnc: number of signals used for decoder training,
            for testing, and for each encoder-learning loop.
        neuron_type: neuron model of ``ens``.
        dt: simulation step for every stage except the pre-decoder stage,
            which always runs at 1 ms.
        f: filter passed to ``makeSignal`` and used to build decoding targets.
        fS: forwarded to ``go`` and ``plotActivity``.
        Tff: gain on the input when building the preInpt decoding target.
        Tfb: gain on the integral when building the ens decoding target.
        reg, tauRiseMax, tauFallMax: forwarded to ``decode``.
        load: read every learned quantity from ``file`` instead of training.
        file: ``.npz`` path read when ``load`` is true.

    Returns:
        np.ndarray of shape (nTest,): ``rmse`` between the decoded ens
        estimate and ``f``-filtered ``intg`` for each test signal.
    """
    print('\nNeuron Type: %s' % neuron_type)
    path = "data/integrateNew_%s.npz" % neuron_type
    # The pre-population decoders are always fit from a LIF network simulated
    # at a fixed 1 ms step, independent of ``dt``.
    dtPre = 0.001
    # nengo's Simulator.run rounds t/dt to get its step count; a bare int()
    # truncates and can come out one step short (e.g. t=0.3), which would make
    # the per-signal array assignments below fail.
    nStepsPre = int(np.round(t / dtPre))
    nSteps = int(np.round(t / dt))
    # Arguments shared by every ``go`` call that uses the trained neuron_type.
    netArgs = dict(NPre=NPre,
                   N=N,
                   t=t,
                   dt=dt,
                   f=f,
                   fS=fS,
                   neuron_type=neuron_type)
    # Everything learned/loaded so far; this is what each checkpoint writes.
    state = {}

    # --- stage 1: readout decoders for preInpt and preIntg ---
    if load:
        with np.load(file) as saved:
            d1a, d1b = saved['d1a'], saved['d1b']
            tauRise1a, tauRise1b = saved['tauRise1a'], saved['tauRise1b']
            tauFall1a, tauFall1b = saved['tauFall1a'], saved['tauFall1b']
        f1a = DoubleExp(tauRise1a, tauFall1a)
        f1b = DoubleExp(tauRise1b, tauFall1b)
    else:
        print('readout decoders for preInpt and preIntg')
        spikesInpt = np.zeros((nTrain, nStepsPre, NPre))
        spikesIntg = np.zeros((nTrain, nStepsPre, NPre))
        targetsInpt = np.zeros((nTrain, nStepsPre, 1))
        targetsIntg = np.zeros((nTrain, nStepsPre, 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dtPre, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dtPre,
                      f=f,
                      fS=fS,
                      neuron_type=LIF(),
                      stim=stim)
            spikesInpt[n] = data['preInpt']
            spikesIntg[n] = data['preIntg']
            targetsInpt[n] = f.filt(Tff * data['inpt'], dt=dtPre)
            targetsIntg[n] = f.filt(data['intg'], dt=dtPre)
        d1a, f1a, tauRise1a, tauFall1a, X1a, Y1a, error1a = decode(
            spikesInpt,
            targetsInpt,
            nTrain,
            dt=dtPre,
            reg=reg,
            name="integrateNew",
            tauRiseMax=tauRiseMax,
            tauFallMax=tauFallMax)
        d1b, f1b, tauRise1b, tauFall1b, X1b, Y1b, error1b = decode(
            spikesIntg,
            targetsIntg,
            nTrain,
            dt=dtPre,
            reg=reg,
            name="integrateNew",
            tauRiseMax=tauRiseMax,
            tauFallMax=tauFallMax)
    state.update(d1a=d1a,
                 d1b=d1b,
                 tauRise1a=tauRise1a,
                 tauRise1b=tauRise1b,
                 tauFall1a=tauFall1a,
                 tauFall1b=tauFall1b)
    if not load:
        np.savez(path, **state)
        times = np.arange(0, t * nTrain, dtPre)
        plotState(times, X1a, Y1a, error1a, "integrateNew",
                  "%s_preInpt" % neuron_type, t * nTrain)
        plotState(times, X1b, Y1b, error1b, "integrateNew",
                  "%s_preIntg" % neuron_type, t * nTrain)

    # --- stage 2: encoders from preIntg and preInpt into ens ---
    if load:
        with np.load(file) as saved:
            e1a, e1b = saved['e1a'], saved['e1b']
    elif isinstance(neuron_type, Bio):
        e1a = np.zeros((NPre, N, 1))
        e1b = np.zeros((NPre, N, 1))
        print("encoders for preIntg-to-ens")
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1a=d1a,
                      d1b=d1b,
                      e1a=e1a,
                      e1b=e1b,
                      f1a=f1a,
                      f1b=f1b,
                      stim=stim,
                      l1b=True,
                      **netArgs)
            e1b = data['e1b']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateNew", "preIntgToEns")
            state.update(e1a=e1a, e1b=e1b)
            np.savez(path, **state)
        print("encoders for preInpt-to-ens")
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            stim2 = makeSin(t, f, dt=dt, seed=n)
            data = go(d1a=d1a,
                      d1b=d1b,
                      e1a=e1a,
                      e1b=e1b,
                      f1a=f1a,
                      f1b=f1b,
                      stim=stim,
                      l1a=True,
                      stim2=stim2,
                      **netArgs)
            e1a = data['e1a']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateNew", "preInptToEns")
            state.update(e1a=e1a, e1b=e1b)
            np.savez(path, **state)
    else:
        e1a = np.zeros((NPre, N, 1))
        e1b = np.zeros((NPre, N, 1))
    state.update(e1a=e1a, e1b=e1b)

    # --- stage 3: readout decoders for ens ---
    if load:
        with np.load(file) as saved:
            d2 = saved['d2']
            tauRise2, tauFall2 = saved['tauRise2'], saved['tauFall2']
        f2 = DoubleExp(tauRise2, tauFall2)
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1a=d1a,
                      d1b=d1b,
                      e1a=e1a,
                      e1b=e1b,
                      f1a=f1a,
                      f1b=f1b,
                      stim=stim,
                      l2=True,
                      **netArgs)
            spikes[n] = data['ens']
            targets[n] = f.filt(Tfb * data['intg'], dt=dt)
        d2, f2, tauRise2, tauFall2, X, Y, error = decode(spikes,
                                                         targets,
                                                         nTrain,
                                                         dt=dt,
                                                         reg=reg,
                                                         tauRiseMax=tauRiseMax,
                                                         tauFallMax=tauFallMax,
                                                         name="integrateNew")
    state.update(d2=d2, tauRise2=tauRise2, tauFall2=tauFall2)
    if not load:
        np.savez(path, **state)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateNew", "%s_ens" % neuron_type,
                  t * nTrain)

    # --- stage 4: encoders from ens2 into ens ---
    if load:
        with np.load(file) as saved:
            e2 = saved['e2']
    elif isinstance(neuron_type, Bio):
        print("encoders from ens2 to ens")
        e2 = np.zeros((N, N, 1))
        for n in range(nEnc):
            stim = makeSignal(t, f, dt=dt, seed=n)
            data = go(d1a=d1a,
                      d1b=d1b,
                      d2=d2,
                      e1a=e1a,
                      e1b=e1b,
                      e2=e2,
                      f1a=f1a,
                      f1b=f1b,
                      f2=f2,
                      stim=stim,
                      l3=True,
                      **netArgs)
            e2 = data['e2']
            plotActivity(t, dt, fS, data['times'], data['ens'],
                         data['tarEns2'], "integrateNew", "Ens2Ens")
            state['e2'] = e2
            np.savez(path, **state)
    else:
        e2 = np.zeros((N, N, 1))
    state['e2'] = e2

    # --- testing ---
    print("testing")
    errors = np.zeros(nTest)
    for test in range(nTest):
        stim = makeSignal(t, f, dt=dt, seed=200 + test)
        # NOTE: d1b/e1b/f1b are not passed to the test network.
        data = go(d1a=d1a,
                  d2=d2,
                  e1a=e1a,
                  e2=e2,
                  f1a=f1a,
                  f2=f2,
                  stim=stim,
                  test=True,
                  **netArgs)
        A = f2.filt(data['ens'], dt=dt)
        X = np.dot(A, d2)
        Y = f.filt(data['intg'], dt=dt)
        error = rmse(X, Y)
        errors[test] = error
        plotState(data['times'], X, Y, error, "integrateNew",
                  "%s_test%s" % (neuron_type, test), t)
    print('%s errors:' % neuron_type, errors)
    state['errors'] = errors
    np.savez(path, **state)
    return errors
Exemplo n.º 16
0
def go(NPre=300,
       NBias=100,
       N=30,
       t=10,
       seed=1,
       dt=0.001,
       Tff=0.3,
       tTrans=0.01,
       stage=None,
       alpha=3e-7,
       eMax=1e-1,
       fPre=DoubleExp(1e-3, 1e-1),
       fNMDA=DoubleExp(10.6e-3, 285e-3),
       fGABA=DoubleExp(0.5e-3, 1.5e-3),
       fS=DoubleExp(1e-3, 1e-1),
       dPre=None,
       dFdfw=None,
       dEns=None,
       dBias=None,
       ePreFdfw=None,
       ePreEns=None,
       ePreBias=None,
       eFdfwEns=None,
       eBiasEns=None,
       eEnsEns=None,
       stimA=lambda t: 0,
       stimB=lambda t: 0,
       stimC=lambda t: 0.01,
       DA=lambda t: 0):
    """Build and simulate one training or testing stage of a Bio network.

    The network always contains:

    * input nodes ``inptA``/``inptB``/``inptC`` (driven by ``stimA``/
      ``stimB``/``stimC``), each feeding an ``NPre``-neuron pre ensemble
      ``preA``/``preB``/``preC``;
    * ``fdfw`` and ``ens`` (``N`` Bio "Pyramidal" neurons) and ``bias``
      (``NBias`` Bio "Interneuron" neurons), all constructed with ``DA``;
    * target ensembles ``tarFdfw``/``tarEns``/``tarBias`` (default nengo
      neuron type) built with the same seeds as ``fdfw``/``ens``/``bias``;
      these are the teaching populations handed to ``learnEncoders``.

    ``stage`` selects which extra connections are built, which decoder
    (``d*``) / encoder (``e*``) arrays are installed with ``setWeights``,
    and which connection's encoders are learned:

    0. nothing extra (only the pre populations are probed);
    1. learn ``ePreFdfw`` (preA -> fdfw) and ``ePreBias`` (preC -> bias);
    2. preA -> fdfw and preC -> bias with fixed weights;
    3. learn ``eBiasEns`` (bias -> ens, ``GABA()`` synapse), preC -> bias fixed;
    4. learn ``ePreEns`` (preB -> ens) with preC -> bias -> ens fixed;
    5. learn ``eFdfwEns`` (fdfw -> ens, ``NMDA()`` synapse); the tarEns
       teacher is driven through Direct-mode copies of the filtered inputs;
    6. preA -> fdfw -> ens, preB -> ens, preC -> bias -> ens, all fixed;
    7. learn ``eEnsEns`` on an ens2 -> ens connection: ens2 (same seed as
       ens) is driven by preB, and a second copy ``ens3`` (driven by preB2)
       is the teacher;
    8. test: preA -> fdfw -> ens and preC -> bias -> ens plus the recurrent
       ens -> ens connection (no preB -> ens connection is built).

    Returns:
        dict with ``times``; the raw inputs ``inptA``/``inptB``/``inptC``;
        unfiltered neuron-output probes ``preA``, ``fdfw``, ``ens``,
        ``bias``, ``tarFdfw``, ``tarEns`` (``ens3`` in stage 7) and
        ``tarBias``; and the six encoder arrays. Each encoder array is the
        learned connection's ``.e`` for the stage that learns it, otherwise
        the value that was passed in.
    """

    with nengo.Network(seed=seed) as model:
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        inptC = nengo.Node(stimC)
        preA = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preB = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preC = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        fdfw = nengo.Ensemble(N,
                              1,
                              neuron_type=Bio("Pyramidal", DA=DA),
                              seed=seed)
        ens = nengo.Ensemble(N,
                             1,
                             neuron_type=Bio("Pyramidal", DA=DA),
                             seed=seed + 1)
        bias = nengo.Ensemble(NBias,
                              1,
                              neuron_type=Bio("Interneuron", DA=DA),
                              seed=seed + 2)
        # Teachers: seeds mirror fdfw (seed), ens (seed + 1), bias (seed + 2).
        tarFdfw = nengo.Ensemble(N,
                                 1,
                                 max_rates=Uniform(30, 30),
                                 intercepts=Uniform(-0.8, 0.8),
                                 seed=seed)
        tarEns = nengo.Ensemble(N,
                                1,
                                max_rates=Uniform(30, 30),
                                intercepts=Uniform(-0.8, 0.8),
                                seed=seed + 1)
        tarBias = nengo.Ensemble(NBias,
                                 1,
                                 max_rates=Uniform(30, 30),
                                 intercepts=Uniform(-0.8, -0.2),
                                 encoders=Choice([[1]]),
                                 seed=seed + 2)
        cA = nengo.Connection(inptA, preA, synapse=None, seed=seed)
        cB = nengo.Connection(inptB, preB, synapse=None, seed=seed)
        cC = nengo.Connection(inptC, preC, synapse=None, seed=seed)
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pInptC = nengo.Probe(inptC, synapse=None)
        pPreA = nengo.Probe(preA.neurons, synapse=None)
        # NOTE: pPreB/pPreC are recorded but not included in the returned dict.
        pPreB = nengo.Probe(preB.neurons, synapse=None)
        pPreC = nengo.Probe(preC.neurons, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pTarFdfw = nengo.Probe(tarFdfw.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        pBias = nengo.Probe(bias.neurons, synapse=None)
        pTarBias = nengo.Probe(tarBias.neurons, synapse=None)
        if stage == 0:  # readout decoders for [preA, preB, preC]
            pass
        if stage == 1:  # encoders for [preA, preC] to [fdfw, bias]
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarBias, synapse=fPre, seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            learnEncoders(c1,
                          tarFdfw,
                          fS,
                          alpha=3 * alpha,
                          eMax=3 * eMax,
                          tTrans=tTrans)
            learnEncoders(c2,
                          tarBias,
                          fS,
                          alpha=alpha,
                          eMax=eMax,
                          tTrans=tTrans)
        if stage == 2:  # readout decoders for fdfw and bias
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
        if stage == 3:  # encoders for [bias] to ens
            # nengo.Connection(inptC, tarBias, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarEns, synapse=fGABA, seed=seed)
            c1 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(bias,
                                  ens,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            learnEncoders(c2,
                          tarEns,
                          fS,
                          alpha=1e3 * alpha,
                          eMax=1e3 * eMax,
                          tTrans=tTrans)
        if stage == 4:  # encoders for [preB] to [ens]
            # nengo.Connection(inptC, tarBias, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarEns, synapse=fGABA, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            c1 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(bias,
                                  ens,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            c3 = nengo.Connection(preB,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            learnEncoders(c3,
                          tarEns,
                          fS,
                          alpha=3 * alpha,
                          eMax=3 * eMax,
                          tTrans=tTrans)
        if stage == 5:  # encoders for [fdfw] to ens
            # From here on, inptB -> preB is filtered by fNMDA.
            cB.synapse = fNMDA
            # Direct-mode stand-ins for the filtered inputs; they drive the
            # tarEns teacher through the same synapse types ens receives.
            tarPreA = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            tarPreB = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            tarPreC = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            nengo.Connection(inptA, tarPreA, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarPreB, synapse=fNMDA, seed=seed)
            nengo.Connection(inptC, tarPreC, synapse=fPre, seed=seed)
            nengo.Connection(tarPreA,
                             tarEns,
                             synapse=fNMDA,
                             transform=Tff,
                             seed=seed)
            nengo.Connection(tarPreB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(tarPreC, tarEns, synapse=fGABA, seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(preB,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c3 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c4 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c5 = nengo.Connection(bias,
                                  ens,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            learnEncoders(c4,
                          tarEns,
                          fS,
                          alpha=3 * alpha,
                          eMax=3 * eMax,
                          tTrans=tTrans)
        if stage == 6:  # readout decoders for ens
            cB.synapse = fNMDA
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(preB,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c3 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c4 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c5 = nengo.Connection(bias,
                                  ens,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
        if stage == 7:  # encoders from ens to ens
            # preB2 receives inptB through fNMDA and drives the teacher copy ens3.
            preB2 = nengo.Ensemble(NPre,
                                   1,
                                   max_rates=Uniform(30, 30),
                                   seed=seed)
            ens2 = nengo.Ensemble(N,
                                  1,
                                  neuron_type=Bio("Pyramidal", DA=DA),
                                  seed=seed + 1)  # acts as preB input to ens
            ens3 = nengo.Ensemble(N,
                                  1,
                                  neuron_type=Bio("Pyramidal", DA=DA),
                                  seed=seed + 1)  # acts as tarEns
            nengo.Connection(inptB, preB2, synapse=fNMDA, seed=seed)
            # Rebinds pTarEns: in this stage the returned "tarEns" is ens3.
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(preB,
                                  ens2,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c3 = nengo.Connection(preB2,
                                  ens3,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c4 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c5 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c6 = nengo.Connection(fdfw,
                                  ens3,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c7 = nengo.Connection(bias,
                                  ens,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            c8 = nengo.Connection(bias,
                                  ens2,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            c9 = nengo.Connection(bias,
                                  ens3,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            c10 = nengo.Connection(ens2,
                                   ens,
                                   synapse=NMDA(),
                                   solver=NoSolver(dEns),
                                   seed=seed)
            learnEncoders(c10, ens3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
        if stage == 8:  # test
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c2 = nengo.Connection(preC,
                                  bias,
                                  synapse=fPre,
                                  solver=NoSolver(dPre),
                                  seed=seed)
            c3 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c4 = nengo.Connection(bias,
                                  ens,
                                  synapse=GABA(),
                                  solver=NoSolver(dBias),
                                  seed=seed)
            c5 = nengo.Connection(ens,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dEns),
                                  seed=seed)

    # Once the model is built, install this stage's decoder/encoder arrays
    # on its connections with setWeights, then run.
    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        if stage == 0:
            pass
        if stage == 1:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreBias)
        if stage == 2:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreBias)
        if stage == 3:
            setWeights(c1, dPre, ePreBias)
            setWeights(c2, dBias, eBiasEns)
        if stage == 4:
            setWeights(c1, dPre, ePreBias)
            setWeights(c2, dBias, eBiasEns)
            setWeights(c3, dPre, ePreEns)
        if stage == 5:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreEns)
            setWeights(c3, dPre, ePreBias)
            setWeights(c4, dFdfw, eFdfwEns)
            setWeights(c5, dBias, eBiasEns)
        if stage == 6:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreEns)
            setWeights(c3, dPre, ePreBias)
            setWeights(c4, dFdfw, eFdfwEns)
            setWeights(c5, dBias, eBiasEns)
        if stage == 7:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreEns)
            setWeights(c3, dPre, ePreEns)
            setWeights(c4, dPre, ePreBias)
            setWeights(c5, dFdfw, eFdfwEns)
            setWeights(c6, dFdfw, eFdfwEns)
            setWeights(c7, dBias, eBiasEns)
            setWeights(c8, dBias, eBiasEns)
            setWeights(c9, dBias, eBiasEns)
            setWeights(c10, dEns, eEnsEns)
        if stage == 8:
            setWeights(c1, dPre, ePreFdfw)
            setWeights(c2, dPre, ePreBias)
            setWeights(c3, dFdfw, eFdfwEns)
            setWeights(c4, dBias, eBiasEns)
            setWeights(c5, dEns, eEnsEns)

        # NOTE(review): presumably (re)initialises the NEURON simulator that
        # backs the Bio neurons; confirm.
        neuron.h.init()
        sim.run(t, progress_bar=True)
        # NOTE(review): presumably clears NEURON state so the next go() call
        # starts fresh; confirm.
        reset_neuron(sim, model)

    # Pick up the encoders learned in this stage; all others pass through.
    ePreFdfw = c1.e if stage == 1 else ePreFdfw
    ePreBias = c2.e if stage == 1 else ePreBias
    eBiasEns = c2.e if stage == 3 else eBiasEns
    ePreEns = c3.e if stage == 4 else ePreEns
    eFdfwEns = c4.e if stage == 5 else eFdfwEns
    eEnsEns = c10.e if stage == 7 else eEnsEns

    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        inptC=sim.data[pInptC],
        preA=sim.data[pPreA],
        fdfw=sim.data[pFdfw],
        ens=sim.data[pEns],
        bias=sim.data[pBias],
        tarFdfw=sim.data[pTarFdfw],
        tarEns=sim.data[pTarEns],
        tarBias=sim.data[pTarBias],
        ePreFdfw=ePreFdfw,
        ePreEns=ePreEns,
        ePreBias=ePreBias,
        eFdfwEns=eFdfwEns,
        eBiasEns=eBiasEns,
        eEnsEns=eEnsEns,
    )
Exemplo n.º 17
0
def run(NPre=300,
        NBias=100,
        N=30,
        t=10,
        nTrain=10,
        nTest=5,
        nEnc=10,
        dt=0.001,
        fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3),
        fGABA=DoubleExp(0.5e-3, 1.5e-3),
        fS=DoubleExp(2e-2, 1e-1),
        DATrain=lambda t: 0,
        DATest=lambda t: 0,
        Tff=0.3,
        reg=1e-2,
        load=(),
        file=None):
    """Train and test the 'integrateNMDAbias2' network stage by stage.

    Each numbered stage either loads its learned quantities from ``file``
    (when the stage number is in ``load``) or learns them by running ``go``,
    then checkpoints everything learned so far to
    ``data/integrateNMDAbias2.npz``:

    0. ``dPre``                    readout decoders for preA
    1. ``ePreFdfw``, ``ePreBias``  encoders pre -> fdfw / pre -> bias
    2. ``dFdfw``, ``dBias``        readout decoders for fdfw / bias
    3. ``eBiasEns``                encoders bias -> ens
    4. ``ePreEns``                 encoders pre -> ens
    5. ``eFdfwEns``                encoders fdfw -> ens
    6. ``dEns``                    readout decoders for ens
    7. ``eEnsEns``                 recurrent encoders ens -> ens

    Finally, ``nTest`` test runs (``go`` stage 8) are plotted to
    ``plots/integrateNMDAbias2_testWhite<i>.pdf``.

    Parameters
    ----------
    NPre, NBias, N : int
        Population sizes forwarded to ``go``.
    t : float
        Duration of every simulation run.
    nTrain, nEnc, nTest : int
        Number of decoder-training, encoder-training and test runs.
    dt : float
        Simulation time step; also passed to every ``filt`` call.
    fPre, fNMDA, fGABA, fS : DoubleExp
        Synaptic filters used to build targets, decoders and plots.
    DATrain, DATest : callable
        Forwarded to ``go`` as ``DA`` during training / testing.
    Tff : float
        Scale applied to the filtered ``inptA`` when building fdfw targets;
        also forwarded to ``go`` in stages 5 and 8.
    reg : float
        Regularization passed to ``decodeNoF``.
    load : container of int
        Stage numbers to load from ``file`` instead of re-learning.
    file : str or None
        ``.npz`` archive read for the stages listed in ``load``.
    """

    def _load(key):
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # read the one array we need and close the file immediately.
        with np.load(file) as archive:
            return archive[key]

    if 0 in load:
        dPre = _load('dPre')
    else:
        print('readout decoders for [preA, preB, preC]')
        spikesInpt = np.zeros((nTrain, int(t / dt), NPre))
        targetsInpt = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      stimB=stimB,
                      stage=0)
            spikesInpt[n] = data['preA']
            targetsInpt[n] = fPre.filt(data['inptA'], dt=dt)
        dPre, X, Y, error = decodeNoF(spikesInpt,
                                      targetsInpt,
                                      nTrain,
                                      fPre,
                                      dt=dt,
                                      reg=reg)
        np.savez(
            "data/integrateNMDAbias2.npz",
            dPre=dPre,
        )
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateNMDAbias2", "pre", t * nTrain)

    if 1 in load:
        ePreFdfw = _load('ePreFdfw')
        ePreBias = _load('ePreBias')
    else:
        print("encoders for [preA, preC] to [fdfw, bias]")
        ePreFdfw = np.zeros((NPre, N, 1))
        ePreBias = np.zeros((NPre, NBias, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      stimB=stimB,
                      dPre=dPre,
                      ePreFdfw=ePreFdfw,
                      ePreBias=ePreBias,
                      stage=1)
            # encoders are refined run after run: feed the updated ones back in
            ePreFdfw = data['ePreFdfw']
            ePreBias = data['ePreBias']
            plotActivity(t, dt, fS, data['times'], data['fdfw'],
                         data['tarFdfw'], "integrateNMDAbias2", "pre_fdfw")
            plotActivity(t, dt, fS, data['times'], data['bias'],
                         data['tarBias'], "integrateNMDAbias2", "pre_bias")
            np.savez("data/integrateNMDAbias2.npz",
                     dPre=dPre,
                     ePreFdfw=ePreFdfw,
                     ePreBias=ePreBias)

    if 2 in load:
        dFdfw = _load('dFdfw')
        dBias = _load('dBias')
    else:
        print('readout decoders for fdfw and bias')
        spikesFdfw = np.zeros((nTrain, int(t / dt), N))
        spikesBias = np.zeros((nTrain, int(t / dt), NBias))
        targetsFdfw = np.zeros((nTrain, int(t / dt), 1))
        targetsBias = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, _ = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      dPre=dPre,
                      ePreFdfw=ePreFdfw,
                      ePreBias=ePreBias,
                      stage=2)
            spikesFdfw[n] = data['fdfw']
            spikesBias[n] = data['bias']
            # Pass dt to the outer filter as well, as every other filt call in
            # this function does, so targets stay correct for any dt.
            targetsFdfw[n] = fNMDA.filt(Tff * fPre.filt(data['inptA'], dt=dt),
                                        dt=dt)
            targetsBias[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt),
                                        dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(spikesFdfw,
                                          targetsFdfw,
                                          nTrain,
                                          fNMDA,
                                          dt=dt,
                                          reg=reg)
        dBias, X2, Y2, error2 = decodeNoF(spikesBias,
                                          targetsBias,
                                          nTrain,
                                          fGABA,
                                          dt=dt,
                                          reg=reg)
        np.savez("data/integrateNMDAbias2.npz",
                 dPre=dPre,
                 ePreFdfw=ePreFdfw,
                 ePreBias=ePreBias,
                 dFdfw=dFdfw,
                 dBias=dBias)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X1, Y1, error1, "integrateNMDAbias2", "fdfw",
                  t * nTrain)
        plotState(times, X2, Y2, error2, "integrateNMDAbias2", "bias",
                  t * nTrain)

    if 3 in load:
        eBiasEns = _load('eBiasEns')
    else:
        print("encoders for [bias] to ens")
        eBiasEns = np.zeros((NBias, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      dPre=dPre,
                      dBias=dBias,
                      ePreBias=ePreBias,
                      eBiasEns=eBiasEns,
                      stage=3)
            eBiasEns = data['eBiasEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateNMDAbias2", "bias_ens")
        np.savez(
            "data/integrateNMDAbias2.npz",
            dPre=dPre,
            ePreFdfw=ePreFdfw,
            ePreBias=ePreBias,
            dFdfw=dFdfw,
            dBias=dBias,
            eBiasEns=eBiasEns,
        )

    if 4 in load:
        ePreEns = _load('ePreEns')
    else:
        print("encoders for [preB] to [ens]")
        ePreEns = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      stimB=stimB,
                      dPre=dPre,
                      dBias=dBias,
                      ePreBias=ePreBias,
                      eBiasEns=eBiasEns,
                      ePreEns=ePreEns,
                      stage=4)
            ePreEns = data['ePreEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateNMDAbias2", "pre_ens")
            np.savez(
                "data/integrateNMDAbias2.npz",
                dPre=dPre,
                ePreFdfw=ePreFdfw,
                ePreBias=ePreBias,
                dFdfw=dFdfw,
                dBias=dBias,
                eBiasEns=eBiasEns,
                ePreEns=ePreEns,
            )

    if 5 in load:
        eFdfwEns = _load('eFdfwEns')
    else:
        print("encoders for [fdfw] to ens")
        eFdfwEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            # stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      Tff=Tff,
                      dPre=dPre,
                      dFdfw=dFdfw,
                      dBias=dBias,
                      ePreFdfw=ePreFdfw,
                      ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns,
                      eBiasEns=eBiasEns,
                      stage=5)
            eFdfwEns = data['eFdfwEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateNMDAbias2", "fdfw_ens")
        np.savez(
            "data/integrateNMDAbias2.npz",
            dPre=dPre,
            ePreFdfw=ePreFdfw,
            ePreEns=ePreEns,
            ePreBias=ePreBias,
            dFdfw=dFdfw,
            dBias=dBias,
            eFdfwEns=eFdfwEns,
            eBiasEns=eBiasEns,
        )

    if 6 in load:
        dEns = _load('dEns')
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            # stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      stimB=stimB,
                      dPre=dPre,
                      dFdfw=dFdfw,
                      dBias=dBias,
                      ePreFdfw=ePreFdfw,
                      ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns,
                      eBiasEns=eBiasEns,
                      stage=6)
            spikes[n] = data['ens']
            # same target construction as the test-time xEns below
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(spikes,
                                      targets,
                                      nTrain,
                                      fNMDA,
                                      dt=dt,
                                      reg=reg)
        np.savez("data/integrateNMDAbias2.npz",
                 dPre=dPre,
                 ePreFdfw=ePreFdfw,
                 ePreEns=ePreEns,
                 ePreBias=ePreBias,
                 dFdfw=dFdfw,
                 dBias=dBias,
                 eFdfwEns=eFdfwEns,
                 eBiasEns=eBiasEns,
                 dEns=dEns)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateNMDAbias2", "ens", t * nTrain)

    if 7 in load:
        eEnsEns = _load('eEnsEns')
    else:
        print("encoders from ens to ens")
        eEnsEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            # stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimA, stimB = makeCarrier(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(NPre=NPre,
                      NBias=NBias,
                      N=N,
                      t=t,
                      dt=dt,
                      DA=DATrain,
                      stimA=stimA,
                      stimB=stimB,
                      dPre=dPre,
                      dFdfw=dFdfw,
                      dBias=dBias,
                      dEns=dEns,
                      ePreFdfw=ePreFdfw,
                      ePreEns=ePreEns,
                      eFdfwEns=eFdfwEns,
                      eBiasEns=eBiasEns,
                      eEnsEns=eEnsEns,
                      stage=7)
            eEnsEns = data['eEnsEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateNMDAbias2", "ens_ens")
        np.savez("data/integrateNMDAbias2.npz",
                 dPre=dPre,
                 ePreFdfw=ePreFdfw,
                 ePreEns=ePreEns,
                 ePreBias=ePreBias,
                 dFdfw=dFdfw,
                 dBias=dBias,
                 eFdfwEns=eFdfwEns,
                 eBiasEns=eBiasEns,
                 dEns=dEns,
                 eEnsEns=eEnsEns)

    print("testing")
    for test in range(nTest):
        # seeds 200+ keep test signals distinct from the training seeds 0..n-1
        stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=200 + test)
        data = go(NPre=NPre,
                  NBias=NBias,
                  N=N,
                  t=t,
                  dt=dt,
                  DA=DATest,
                  stimA=stimA,
                  stimB=stimB,
                  Tff=Tff,
                  dPre=dPre,
                  dFdfw=dFdfw,
                  dBias=dBias,
                  dEns=dEns,
                  ePreFdfw=ePreFdfw,
                  ePreEns=ePreEns,
                  eFdfwEns=eFdfwEns,
                  eBiasEns=eBiasEns,
                  eEnsEns=eEnsEns,
                  stage=8)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)
        xhatEns = np.dot(aBio, dEns)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xEns = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        errorBio = rmse(xhatEns, xEns)
        fig, ax = plt.subplots()
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], Tff * xFdfw, alpha=0.5, label="input")
        ax.plot(data['times'], xEns, label="target")
        ax.plot(data['times'],
                xhatEns,
                alpha=0.5,
                label="ens, rmse=%.3f" % errorBio)
        ax.set(xlabel='time',
               ylabel='state',
               xlim=((0, t)),
               ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDAbias2_testWhite%s.pdf" % test)
        plt.close('all')
Exemplo n.º 18
0
def run(n_neurons=30, t=10, t_test=10, tDA=5, dt=0.001, n_encodes=10, n_tests=1, reg=0, f_NMDA=DoubleExp(1.06e-2, 2.85e-1), f_GABA=DoubleExp(5e-4, 1.5e-3), f_AMPA=DoubleExp(5.5e-4, 2.2e-3), f_s=DoubleExp(1e-2, 2e-1), load_w=None, load_fd=None, data_file="data/dale2.npz"):
    """Train and test the 'dale2' network in four stages.

    Stage 1: learn feedforward encoders/weights into P and I
             (ePExc/wPExc, eIExc/wIExc, ePInh/wPInh, eIInh/wIInh), or load
             them from ``data_file`` when ``load_w`` is truthy.
    Stage 2: fit readout decoders: ``dNMDA`` by regularized least squares on
             NMDA-filtered P activity, and ``dAMPA`` (>= 0) / ``dGABA`` (<= 0)
             jointly by non-negative least squares on [aAMPA, -aGABA]; or
             load them when ``load_fd`` is truthy.
    Stage 3: learn recurrent encoders/weights (ePP, ePI, eIP, eII) while DA
             steps from 0 to 0.5 at ``tDA``, or load them when ``load_w``.
    Stage 4: run ``n_tests`` test simulations (DA steps to 1.0 at ``tDA``)
             and plot the decoded state against the reference signals.

    Checkpoints go to 'data/dale2.npz'; plots go under 'plots/'.

    NOTE(review): ``t_test`` and ``reg`` are accepted but unused, and ``dt``
    is only used to build the stimulus; checkpoints are always written to
    'data/dale2.npz' even when ``data_file`` points elsewhere.
    """

    # Stage 1: feedforward encoders / weights
    if load_w:
        # read everything from one open archive and close it afterwards
        with np.load(data_file) as archive:
            ePExc = archive['ePExc']
            wPExc = archive['wPExc']
            eIExc = archive['eIExc']
            wIExc = archive['wIExc']
            ePInh = archive['ePInh']
            wPInh = archive['wPInh']
            eIInh = archive['eIInh']
            wIInh = archive['wIInh']
    else:
        ePExc = None
        eIExc = None
        ePInh = None
        eIInh = None
        print('Optimizing pre-ens encoders')
        for nenc in range(n_encodes):
            print("encoding trial %s"%nenc)
            s = make_signal(t=t, dt=dt, f=f_AMPA, seed=nenc)
            stim = lambda t: s[int(t/dt)]
            data = go(ePExc=ePExc, eIExc=eIExc, ePInh=ePInh, eIInh=eIInh,
                f_NMDA=f_NMDA, f_AMPA=f_AMPA, f_GABA=f_GABA, f_s=f_s,
                n_neurons=n_neurons, t=t, stim=stim, stage=1)
            # encoders are refined trial after trial: feed them back in
            ePExc = data['ePExc']
            wPExc = data['wPExc']
            eIExc = data['eIExc']
            wIExc = data['wIExc']
            ePInh = data['ePInh']
            wPInh = data['wPInh']
            eIInh = data['eIInh']
            wIInh = data['wIInh']
            np.savez('data/dale2.npz', ePExc=ePExc, wPExc=wPExc, eIExc=eIExc, wIExc=wIExc, ePInh=ePInh, wPInh=wPInh, eIInh=eIInh, wIInh=wIInh)
            aSupv = f_s.filt(data['supv'])
            aP = f_s.filt(data['P'])
            aI = f_s.filt(data['I'])
            for n in range(n_neurons):
                fig, ax = plt.subplots(1, 1)
                ax.plot(data['times'], aSupv[:,n], alpha=0.5, label='supv')
                ax.plot(data['times'], aP[:,n], alpha=0.5, label='P')
                ax.plot(data['times'], aI[:,n], alpha=0.5, label='I')
                ax.set(ylim=((0, 40)))
                plt.legend()
                plt.savefig('plots/tuning/dale_eFF_%s.pdf'%n)
                plt.close('all')

    # Stage 2: readout decoders
    if load_fd:
        with np.load(data_file) as archive:
            dNMDA = archive['dNMDA']
            dAMPA = archive['dAMPA']
            dGABA = archive['dGABA']
    else:
        print("Optimizing decoders")
        s = make_signal(t=t, dt=dt, f=f_AMPA, seed=0)
        stim = lambda t: s[int(t/dt)]
        data = go(
            wPExc=wPExc, wIExc=wIExc, wPInh=wPInh, wIInh=wIInh,
            f_NMDA=f_NMDA, f_AMPA=f_AMPA, f_GABA=f_GABA, f_s=f_s,
            n_neurons=n_neurons, t=t, stim=stim, stage=2)
        aNMDA = f_NMDA.filt(data['P'])
        xFB = f_NMDA.filt(data['u'])
        dNMDA, _ = nengo.solvers.LstsqL2(reg=1e-2)(aNMDA, xFB)
        # positive w fron mixed dFB enforced by learning node in stage 3
        aAMPA = f_AMPA.filt(data['P'])
        aGABA = f_GABA.filt(data['I'])
        xOut = data['u']
        # nnls returns non-negative weights; negating the GABA columns and
        # the GABA half of the solution yields dAMPA >= 0 and dGABA <= 0.
        aOut = np.hstack((aAMPA, -aGABA))
        dBoth, _ = nnls(aOut, np.ravel(xOut))
        dAMPA = dBoth[:n_neurons].reshape((n_neurons, 1))
        dGABA = -dBoth[n_neurons:].reshape((n_neurons, 1))
        np.savez('data/dale2.npz', ePExc=ePExc, wPExc=wPExc, eIExc=eIExc, wIExc=wIExc, ePInh=ePInh, wPInh=wPInh, eIInh=eIInh, wIInh=wIInh, dNMDA=dNMDA, dAMPA=dAMPA, dGABA=dGABA)
        xhatFB = np.dot(aNMDA, dNMDA)
        xhatOut = np.dot(aAMPA, dAMPA) + np.dot(aGABA, dGABA)
        xhatFB_rmse = rmse(np.ravel(xhatFB), np.ravel(xFB))
        xhatOut_rmse = rmse(np.ravel(xhatOut), np.ravel(xOut))
        fig, ax = plt.subplots()
        ax.plot(data['times'], xFB, alpha=0.5, linestyle="--", label='xFB')
        ax.plot(data['times'], xhatFB, label='xhatFB, rmse=%.3f' %xhatFB_rmse)
#         ax.plot(data['times'], xOut, alpha=0.5, linestyle="--", label='xOut')
#         ax.plot(data['times'], xhatOut, label='xhatOut, rmse=%.3f' %xhatOut_rmse)
        ax.set(xlabel='time (s)', ylabel=r'$\mathbf{x}$', title="train decoders")
        plt.legend(loc='upper right')
        plt.savefig("plots/dale2_train_decoders.pdf")
        plt.close(fig)

    # Stage 3: recurrent encoders / weights
#     dGABAFb = np.random.RandomState(seed=0).uniform(-1e-4, 0, size=(n_neurons, 1))
    dGABAFb = -3e-4*np.ones((n_neurons, 1))
    if load_w:
        with np.load(data_file) as archive:
            ePP = archive['ePP']
            wPP = archive['wPP']
            ePI = archive['ePI']
            wPI = archive['wPI']
            eIP = archive['eIP']
            wIP = archive['wIP']
            eII = archive['eII']
            wII = archive['wII']
    else:
        print('Optimizing ePP, ePI, eIP, eII')
        ePP = None
        ePI = None
        eIP = None
        eII = None
        DA = lambda t: 0 if t<tDA else 0.5
        for nenc in range(n_encodes):
            print("encoding trial %s"%nenc)
            s = make_signal(t=t, dt=dt, f=f_NMDA, seed=nenc)
            stim = lambda t: s[int(t/dt)]
            data = go(
                wPExc=wPExc, wIExc=wIExc, wPInh=wPInh, wIInh=wIInh,
                ePP=ePP, ePI=ePI, eIP=eIP, eII=eII,
                dNMDA=dNMDA, dAMPA=dAMPA, dGABA=dGABAFb,
                f_NMDA=f_NMDA, f_AMPA=f_AMPA, f_GABA=f_GABA, f_s=f_s,
                n_neurons=n_neurons, t=t, stim=stim, DA=DA, stage=3)
            ePP = data['ePP']
            wPP = data['wPP']
            ePI = data['ePI']
            wPI = data['wPI']
            eIP = data['eIP']
            wIP = data['wIP']
            eII = data['eII']
            wII = data['wII']
            np.savez('data/dale2.npz', ePExc=ePExc, wPExc=wPExc, eIExc=eIExc, wIExc=wIExc, ePInh=ePInh, wPInh=wPInh, eIInh=eIInh, wIInh=wIInh, dNMDA=dNMDA, dAMPA=dAMPA, dGABA=dGABA, ePP=ePP, wPP=wPP, ePI=ePI, wPI=wPI, eIP=eIP, wIP=wIP, eII=eII, wII=wII)
            aP = f_s.filt(data['P'])
            aI = f_s.filt(data['I'])
            aBuffer = f_s.filt(data['buffer'])
            for n in range(n_neurons):
                fig, (ax, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True)
                ax.plot(data['times'], aBuffer[:,n], alpha=0.5, label='buffer')
                ax.plot(data['times'], aP[:,n], alpha=0.5, label='P')
                ax2.plot(data['times'], aBuffer[:,n], alpha=0.5, label='buffer')
                ax2.plot(data['times'], aI[:,n], alpha=0.5, label='I')
                ax.set(ylim=((0, 40)))
                ax.legend()
                ax2.legend()
                plt.savefig('plots/tuning/dale_eFB_%s.pdf'%n)
                plt.close('all')
            # confirm gated integrator (which generates targets) is working in state-space
            fig, ax = plt.subplots()
            ax.plot(data['times'], data['uDA'], label="DA")
            ax.plot(data['times'], data['gate_x'], label="gate")
            ax.plot(data['times'], data['buffer_x'], label="buffer")
            ax.plot(data['times'], data['fdbk_x'], label="fdbk")
            ax.legend()
            fig.savefig("plots/dale2_train_gatedIntegrator.pdf")
            plt.close(fig)

    # Stage 4: testing
    checkWeights(wPExc, wIExc, wPInh, wIInh, wPP, wPI, wIP, wII, dAMPA, dGABA)
    DA = lambda t: 0 if t < tDA else 1.0
    # NOTE(review): stage 3 trained with dGABAFb = -3e-4 (constant) and has
    # this exact random formula commented out; confirm that testing with a
    # different dGABAFb than the one used in training is intended.
    dGABAFb = np.random.RandomState(seed=0).uniform(-1e-4, 0, size=(n_neurons, 1))
    for test in range(n_tests):
        print("Test %s"%test)
        s = make_signal(t=t, dt=dt, f=f_NMDA, seed=test)
        stim = lambda t: s[int(t/dt)]
        data = go(
            wPExc=wPExc, wIExc=wIExc, wPInh=wPInh, wIInh=wIInh,
            wPP=wPP, wPI=wPI, wIP=wIP, wII=wII,
            dNMDA=dNMDA, dAMPA=dAMPA, dGABA=dGABAFb,
            f_NMDA=f_NMDA, f_AMPA=f_AMPA, f_GABA=f_GABA, f_s=f_s,
            n_neurons=n_neurons, t=t, stim=stim, DA=DA, stage=4)
        aNMDA = f_NMDA.filt(data['P'])
        aAMPA = f_AMPA.filt(data['P'])
        aGABA = f_GABA.filt(data['I'])
        u = data['u']
        da = data['uDA']
        gate = data['gate_x']
        buffer = data['buffer_x']
        fdbk = data['fdbk_x']
        xhat = np.dot(aNMDA, dNMDA)
        xhat_rmse = rmse(xhat, buffer)
        fig, ax = plt.subplots()
        ax.plot(data['times'], u, linestyle="--", label='u')
        ax.plot(data['times'], da, linestyle="--", label='DA')
        ax.plot(data['times'], gate, label="gate")
        ax.plot(data['times'], buffer, label="buffer")
        ax.plot(data['times'], fdbk, label="fdbk")
        ax.plot(data['times'], xhat, label='xhat, rmse=%.3f' %xhat_rmse)
        ax.set(xlabel='time (s)', ylabel=r'$\mathbf{x}$', title="test")
        plt.legend(loc='upper right')
        plt.savefig("plots/dale2_test_DA1_%s.pdf"%test)
        # release the figure so repeated tests do not accumulate open figures
        plt.close(fig)
Exemplo n.º 19
0
def run(n_neurons=100,
        t=20,
        t_test=10,
        t_enc=30,
        dt=0.001,
        f=DoubleExp(1e-2, 2e-1),
        penalty=0,
        reg=1e-1,
        freq=1,
        tt=5.0,
        tt_test=5.0,
        neuron_type=LIF(),
        load_fd=False,
        load_w=None,
        supervised=False):
    """Train and test the 2-D 'oscillate' network.

    Steps:
    1. ``DurstewitzNeuron`` only: learn feedforward weights/encoders
       pre -> ens (``L_ff=True``), or load ``w_ff``/``e_ff`` from ``load_w``.
    2. Fit the ens decoders and filter with ``df_opt`` on a ``t + tt`` run,
       discarding the first ``tt`` seconds, or load them from ``load_fd``.
    3. ``DurstewitzNeuron`` only: learn feedback weights/encoders
       supv -> ens (``L_fb=True``), or load ``w_fb``/``e_fb`` from ``load_w``.
    4. Test for ``t_test + tt_test`` seconds.
       - With ``supervised=True``, plot the decoded ens and supv states and
         return ``None``.
       - Otherwise fit a filtered sinusoid (frequency, phase, magnitude) to
         each decoded dimension after the ``tt_test`` transient.

    For each dimension the scaled RMSE is
    ``(1 + |freq - fitted_frequency|) * rmse(xhat, best_fit_sinusoid)``.

    Returns
    -------
    float or None
        Mean scaled RMSE over both dimensions in unsupervised mode, else None.
    """

    d_ens = np.zeros((n_neurons, 2))
    f_ens = f
    w_ff = None
    w_fb = None
    e_ff = None
    e_fb = None
    f_smooth = DoubleExp(1e-2, 2e-1)
    print('Neuron Type: %s' % neuron_type)

    if isinstance(neuron_type, DurstewitzNeuron):
        if load_w:
            # read both arrays from one open archive and close it afterwards
            with np.load(load_w) as archive:
                w_ff = archive['w_ff']
                e_ff = archive['e_ff']
        else:
            print('optimizing encoders from pre to ens')
            data = go(d_ens,
                      f_ens,
                      n_neurons=n_neurons,
                      t=t_enc + tt,
                      f=f,
                      dt=dt,
                      neuron_type=neuron_type,
                      w_ff=w_ff,
                      e_ff=e_ff,
                      L_ff=True)
            w_ff = data['w_ff']
            e_ff = data['e_ff']
            np.savez('data/oscillate_w.npz', w_ff=w_ff, e_ff=e_ff)

            fig, ax = plt.subplots()
            sns.distplot(np.ravel(w_ff), ax=ax, kde=False)
            ax.set(xlabel='weights', ylabel='frequency')
            plt.savefig("plots/tuning/oscillate_%s_w_ff.pdf" % (neuron_type))
            plt.close(fig)

            a_ens = f_smooth.filt(data['ens'], dt=dt)
            a_supv = f_smooth.filt(data['supv'], dt=dt)
            for n in range(n_neurons):
                fig, ax = plt.subplots(1, 1)
                ax.plot(data['times'], a_supv[:, n], alpha=0.5, label='supv')
                ax.plot(data['times'], a_ens[:, n], alpha=0.5, label='ens')
                ax.set(ylim=((0, 40)))
                plt.legend()
                plt.savefig('plots/tuning/oscillate_pre_ens_activity_%s.pdf' %
                            (n))
                plt.close('all')

    if load_fd:
        with np.load(load_fd) as archive:
            d_ens = archive['d_ens']
            taus_ens = archive['taus_ens']
        f_ens = DoubleExp(taus_ens[0], taus_ens[1])
    else:
        print('gathering filter/decoder training data for ens')
        data = go(d_ens,
                  f_ens,
                  n_neurons=n_neurons,
                  t=t + tt,
                  f=f,
                  dt=dt,
                  neuron_type=neuron_type,
                  w_ff=w_ff,
                  L_fd=True)
        trans = int(tt / dt)
        d_ens, f_ens, taus_ens = df_opt(data['u'][trans:],
                                        data['ens'][trans:],
                                        f,
                                        dt=dt,
                                        name='oscillate_%s' % neuron_type,
                                        reg=reg,
                                        penalty=penalty)
        np.savez('data/oscillate_%s_fd.npz' % neuron_type,
                 d_ens=d_ens,
                 taus_ens=taus_ens)

        times = np.arange(0, 1, 0.0001)
        fig, ax = plt.subplots()
        ax.plot(times,
                f.impulse(len(times), dt=0.0001),
                label=r"$f^x, \tau_1=%.3f, \tau_2=%.3f$" %
                (-1. / f.poles[0], -1. / f.poles[1]))
        ax.plot(times,
                f_ens.impulse(len(times), dt=0.0001),
                label=r"$f^{ens}, \tau_1=%.3f, \tau_2=%.3f, d: %s/%s$" %
                (-1. / f_ens.poles[0], -1. / f_ens.poles[1],
                 np.count_nonzero(d_ens), n_neurons))
        ax.set(xlabel='time (seconds)',
               ylabel='impulse response',
               ylim=((0, 10)))
        ax.legend(loc='upper right')
        plt.tight_layout()
        plt.savefig("plots/oscillate_%s_filters_ens.pdf" % neuron_type)
        plt.close(fig)

        a_ens = f_ens.filt(data['ens'], dt=dt)
        x = f.filt(data['u'], dt=dt)
        xhat_ens = np.dot(a_ens, d_ens)
        rmse_ens = rmse(xhat_ens, x)
        fig, ax = plt.subplots()
        ax.plot(data['times'], x, linestyle="--", label='x')
        ax.plot(data['times'], xhat_ens, label='ens, rmse=%.3f' % rmse_ens)
        ax.set(xlabel='time (s)', ylabel=r'$\mathbf{x}$', title="pre_ens")
        plt.legend(loc='upper right')
        plt.savefig("plots/oscillate_%s_pre_ens_train.pdf" % neuron_type)
        plt.close(fig)

    if isinstance(neuron_type, DurstewitzNeuron):
        if load_w:
            with np.load(load_w) as archive:
                w_fb = archive['w_fb']
                e_fb = archive['e_fb']
        else:
            print('optimizing encoders from supv to ens')
            data = go(d_ens,
                      f_ens,
                      n_neurons=n_neurons,
                      t=t_enc + tt,
                      f=f,
                      dt=dt,
                      neuron_type=neuron_type,
                      w_ff=w_ff,
                      w_fb=w_fb,
                      e_fb=e_fb,
                      L_fb=True)
            w_fb = data['w_fb']
            e_fb = data['e_fb']
            np.savez('data/oscillate_w.npz',
                     w_ff=w_ff,
                     e_ff=e_ff,
                     w_fb=w_fb,
                     e_fb=e_fb)

            fig, ax = plt.subplots()
            sns.distplot(np.ravel(w_fb), ax=ax, kde=False)
            ax.set(xlabel='weights', ylabel='frequency')
            plt.savefig("plots/tuning/oscillate_%s_w_fb.pdf" % (neuron_type))
            plt.close(fig)

            a_ens = f_smooth.filt(data['ens'], dt=dt)
            a_supv = f_smooth.filt(data['supv'], dt=dt)
            #             a_supv2 = f_smooth.filt(data['supv2'], dt=dt)
            for n in range(n_neurons):
                fig, ax = plt.subplots(1, 1)
                ax.plot(data['times'], a_supv[:, n], alpha=0.5, label='supv')
                #                 ax.plot(data['times'], a_supv2[:,n], alpha=0.5, label='supv2')
                ax.plot(data['times'], a_ens[:, n], alpha=0.5, label='ens')
                ax.set(ylim=((0, 40)))
                plt.legend()
                plt.savefig('plots/tuning/oscillate_supv_ens_activity_%s.pdf' %
                            (n))
                plt.close('all')

    print("Testing")
    if supervised:
        data = go(d_ens,
                  f_ens,
                  n_neurons=n_neurons,
                  t=t_test + tt_test,
                  f=f,
                  dt=dt,
                  neuron_type=neuron_type,
                  w_ff=w_ff,
                  w_fb=w_fb,
                  supervised=True)

        a_ens = f_ens.filt(data['ens'], dt=dt)
        a_supv = f_ens.filt(data['supv'], dt=dt)
        #         a_supv2 = f_ens.filt(data['supv2'], dt=dt)
        xhat_ens_0 = np.dot(a_ens, d_ens)[:, 0]
        xhat_ens_1 = np.dot(a_ens, d_ens)[:, 1]
        xhat_supv_0 = np.dot(a_supv, d_ens)[:, 0]
        xhat_supv_1 = np.dot(a_supv, d_ens)[:, 1]
        #         xhat_supv2_0 = np.dot(a_supv2, d_ens)[:,0]
        #         xhat_supv2_1 = np.dot(a_supv2, d_ens)[:,1]
        x_0 = f.filt(data['u'], dt=dt)[:, 0]
        x_1 = f.filt(data['u'], dt=dt)[:, 1]
        x2_0 = f.filt(x_0, dt=dt)
        x2_1 = f.filt(x_1, dt=dt)
        times = data['times']

        fig, ax = plt.subplots()
        ax.plot(times, x_0, linestyle="--", label='x_0')
        ax.plot(times, x2_0, linestyle="--", label='x2_0')
        ax.plot(times, xhat_supv_0, label='supv')
        ax.plot(times, xhat_ens_0, label='ens')
        #         ax.plot(times, xhat_supv2_0, label='supv2')
        ax.set(xlim=((0, t_test)),
               ylim=((-1, 1)),
               xlabel='time (s)',
               ylabel=r'$\mathbf{x}$')
        plt.legend(loc='upper right')
        plt.savefig("plots/oscillate_%s_supervised_0.pdf" % neuron_type)
        plt.close(fig)

        fig, ax = plt.subplots()
        ax.plot(times, x_1, linestyle="--", label='x_1')
        ax.plot(times, x2_1, linestyle="--", label='x2_1')
        ax.plot(times, xhat_supv_1, label='supv')
        ax.plot(times, xhat_ens_1, label='ens')
        #         ax.plot(times, xhat_supv2_1, label='supv2')
        ax.set(xlim=((0, t_test)),
               ylim=((-1, 1)),
               xlabel='time (s)',
               ylabel=r'$\mathbf{x}$')
        plt.legend(loc='upper right')
        plt.savefig("plots/oscillate_%s_supervised_1.pdf" % neuron_type)
        plt.close(fig)

    else:
        data = go(d_ens,
                  f_ens,
                  n_neurons=n_neurons,
                  t=t_test + tt_test,
                  f=f,
                  dt=dt,
                  neuron_type=neuron_type,
                  w_ff=w_ff,
                  w_fb=w_fb)

        a_ens = f_ens.filt(data['ens'], dt=dt)
        xhat_ens_0 = np.dot(a_ens, d_ens)[:, 0]
        xhat_ens_1 = np.dot(a_ens, d_ens)[:, 1]
        x_0 = f.filt(data['u'], dt=dt)[:, 0]
        x_1 = f.filt(data['u'], dt=dt)[:, 1]
        times = data['times']

        # curve fit to a sinusoid of arbitrary frequency, phase, magnitude
        print('Curve fitting')
        trans = int(tt_test / dt)
        # subsample to ~1 ms resolution; clamp to 1 because a slice step of
        # 0 (dt > 1 ms) raises ValueError
        step = max(1, int(0.001 / dt))

        def sinusoid(t, freq, phase, mag, dt=dt):  # mag
            return f.filt(mag *
                          np.sin(t * 2 * np.pi * freq + 2 * np.pi * phase),
                          dt=dt)

        # p0 has 3 entries, so curve_fit fits (freq, phase, mag) and leaves
        # the keyword dt at its default
        p0 = [1, 0, 1]
        param_0, _ = curve_fit(sinusoid,
                               times[trans:],
                               xhat_ens_0[trans:],
                               p0=p0)
        param_1, _ = curve_fit(sinusoid,
                               times[trans:],
                               xhat_ens_1[trans:],
                               p0=p0)
        print('param0', param_0)
        print('param1', param_1)
        sinusoid_0 = sinusoid(times, param_0[0], param_0[1], param_0[2])
        sinusoid_1 = sinusoid(times, param_1[0], param_1[1], param_1[2])

        # error = rmse(xhat, best-fit sinusoid to xhat), scaled by
        # (1 + |target freq - fitted freq|).
        # param_*[0] is the fitted frequency (param_*[1] is the phase). sin is
        # odd, so a negative fitted frequency describes the same oscillation
        # as its absolute value (with a shifted phase).
        freq_error_0 = np.abs(freq - np.abs(param_0[0]))
        freq_error_1 = np.abs(freq - np.abs(param_1[0]))
        rmse_0 = rmse(xhat_ens_0[trans::step], sinusoid_0[trans::step])
        rmse_1 = rmse(xhat_ens_1[trans::step], sinusoid_1[trans::step])
        scaled_rmse_0 = (1 + freq_error_0) * rmse_0
        scaled_rmse_1 = (1 + freq_error_1) * rmse_1

        fig, ax = plt.subplots()
        ax.plot(times, x_0, linestyle="--", label='x0')
        ax.plot(times, sinusoid_0, label='best fit sinusoid_0')
        ax.plot(times,
                xhat_ens_0,
                label='ens, scaled rmse=%.3f' % scaled_rmse_0)
        ax.axvline(tt_test, label=r"$t_{transient}$")
        ax.set(ylim=((-1, 1)), xlabel='time (s)', ylabel=r'$\mathbf{x}$')
        plt.legend(loc='upper right')
        plt.savefig("plots/oscillate_%s_test_0.pdf" % neuron_type)
        plt.close(fig)

        fig, ax = plt.subplots()
        ax.plot(times, x_1, linestyle="--", label='x1')
        ax.plot(times, sinusoid_1, label='best fit sinusoid_1')
        ax.plot(times,
                xhat_ens_1,
                label='ens, scaled rmse=%.3f' % scaled_rmse_1)
        ax.axvline(tt_test, label=r"$t_{transient}$")
        ax.set(ylim=((-1, 1)), xlabel='time (s)', ylabel=r'$\mathbf{x}$')
        plt.legend(loc='upper right')
        plt.savefig("plots/oscillate_%s_test_1.pdf" % neuron_type)
        plt.close(fig)

        print('scaled rmses: ', scaled_rmse_0, scaled_rmse_1)
        mean = np.mean([scaled_rmse_0, scaled_rmse_1])
        fig, ax = plt.subplots()
        sns.barplot(data=np.array([mean]))
        ax.set(ylabel='Scaled RMSE', title="mean=%.3f" % mean)
        plt.xticks()
        plt.savefig("plots/oscillate_%s_scaled_rmse.pdf" % neuron_type)
        plt.close(fig)
        np.savez('data/oscillate_%s_results.npz' % neuron_type,
                 scaled_rmse_0=scaled_rmse_0,
                 scaled_rmse_1=scaled_rmse_1)
        return mean
Exemplo n.º 20
0
def run(NPre=100, N=100, t=10, nTrain=10, nTest=3, nEnc=10, dt=0.001, c=None,
        fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3), fGABA=DoubleExp(0.5e-3, 1.5e-3), fS=DoubleExp(1e-3, 1e-1),
        Tff=0.3, Tneg=-1, reg=1e-2, load=(), file=None, DATest=lambda t: 0):
    """Train the NMDA integrator network stage by stage, then test it.

    Each training stage ``k`` (0-6) either reads its results from ``file``
    (when ``k in load``) or computes them by simulating ``go(..., stage=k)``.
    Freshly computed results are saved, together with everything obtained by
    the earlier stages, to ``data/integrateNMDA.npz``:

    * 0 -- readout decoders ``dPreA``/``dPreB``/``dPreC`` (via ``decodeNoF``)
    * 1 -- encoders ``ePreA``/``ePreB``/``ePreC`` from pre[A,B,C] to [fdfw, ens, inh]
    * 2 -- readout decoders ``dFdfw`` and ``dInh``
    * 3 -- encoders ``eFdfw`` for fdfw -> ens
    * 4 -- readout decoders ``dBio`` for ens, and ``dNeg = -dBio``
    * 5 -- encoders ``eBio`` for ens -> ens
    * 6 -- encoders ``eNeg``/``eInh`` for [ens, inh] -> fdfw

    Then, for each of ``nTest`` trials, two stage-7 simulations are run and
    plotted to ``plots/``:

    * a "flat" test: input ``vals[test]`` until the cutoff ``c``, after which
      the input is 0 and ``stimC`` switches to 1;
    * a "white" test: ``makeSignal`` input with ``dNeg``/``eNeg`` set to None.

    Parameters
    ----------
    t : float
        Duration of every simulated trial (s).
    dt : float
        Simulation time step; also used for all offline filtering.
    c : float or None
        Cutoff time of the flat test. ``None`` (or 0) means ``t``.
    Tff : float
        Feedforward scale used in the stage-2 fdfw target and passed to ``go``.
    Tneg : float
        Multiplier applied to ``dNeg`` in the flat test.
    reg : float
        Regularization passed to ``decodeNoF``.
    load : container of int
        Stage numbers to read from ``file`` instead of training them.
    file : str or None
        ``.npz`` file read for the stages listed in ``load``.
    DATest : callable
        Passed to ``go`` as ``DA`` during testing.
    """
    if not c: c = t
    nSteps = int(t/dt)  # samples per simulated trial, as allocated below

    if 0 in load:
        # `with` closes the NpzFile handle, which np.load otherwise leaves open
        with np.load(file) as saved:
            dPreA, dPreB, dPreC = saved['dPreA'], saved['dPreB'], saved['dPreC']
    else:
        print('readout decoders for pre[A,B,C]')
        spikesInptA = np.zeros((nTrain, nSteps, NPre))
        spikesInptB = np.zeros((nTrain, nSteps, NPre))
        spikesInptC = np.zeros((nTrain, nSteps, NPre))
        targetsInptA = np.zeros((nTrain, nSteps, 1))
        targetsInptB = np.zeros((nTrain, nSteps, 1))
        targetsInptC = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimB,
                stage=0)
            spikesInptA[n] = data['preA']
            spikesInptB[n] = data['preB']
            spikesInptC[n] = data['preC']
            targetsInptA[n] = fPre.filt(data['inptA'], dt=dt)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dt)
            targetsInptC[n] = fPre.filt(data['inptC'], dt=dt)
        dPreA, XA, YA, errorA = decodeNoF(spikesInptA, targetsInptA, nTrain, fPre, dt=dt, reg=reg)
        dPreB, XB, YB, errorB = decodeNoF(spikesInptB, targetsInptB, nTrain, fPre, dt=dt, reg=reg)
        dPreC, XC, YC, errorC = decodeNoF(spikesInptC, targetsInptC, nTrain, fPre, dt=dt, reg=reg)
        np.savez("data/integrateNMDA.npz",
            dPreA=dPreA, dPreB=dPreB, dPreC=dPreC)
        times = np.arange(0, t*nTrain, dt)
        plotState(times, XA, YA, errorA, "integrateNMDA", "preA", t*nTrain)
        plotState(times, XB, YB, errorB, "integrateNMDA", "preB", t*nTrain)
        plotState(times, XC, YC, errorC, "integrateNMDA", "preC", t*nTrain)

    if 1 in load:
        with np.load(file) as saved:
            ePreA, ePreB, ePreC = saved['ePreA'], saved['ePreB'], saved['ePreC']
    else:
        print("encoders for pre[A,B,C] to [fdfw, ens, inh]")
        ePreA = np.zeros((NPre, N, 1))
        ePreB = np.zeros((NPre, N, 1))
        ePreC = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimB,
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                stage=1)
            # feed the updated encoders into the next learning trial
            ePreA = data['ePreA']
            ePreB = data['ePreB']
            ePreC = data['ePreC']
            plotActivity(t, dt, fS, data['times'], data['fdfw'], data['tarFdfw'], "integrateNMDA", "preAFdfw")
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDA", "preBEns")
            plotActivity(t, dt, fS, data['times'], data['inh'], data['tarInh'], "integrateNMDA", "preCInh")
            np.savez("data/integrateNMDA.npz",
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC)

    if 2 in load:
        with np.load(file) as saved:
            dFdfw, dInh = saved['dFdfw'], saved['dInh']
    else:
        print('readout decoders for fdfw and inh')
        spikesFdfw = np.zeros((nTrain, nSteps, N))
        spikesInh = np.zeros((nTrain, nSteps, N))
        targetsFdfw = np.zeros((nTrain, nSteps, 1))
        targetsInh = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimC = lambda x: 0.5+0.5*np.sin(2*np.pi*x/t)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, stimC=stimC,
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC,
                stage=2)
            spikesFdfw[n] = data['fdfw']
            spikesInh[n] = data['inh']
            # the outer filters previously omitted dt and silently used filt's default step
            targetsFdfw[n] = fNMDA.filt(Tff*fPre.filt(data['inptA'], dt=dt), dt=dt)
            targetsInh[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt), dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(spikesFdfw, targetsFdfw, nTrain, fNMDA, dt=dt, reg=reg)
        dInh, X2, Y2, error2 = decodeNoF(spikesInh, targetsInh, nTrain, fGABA, dt=dt, reg=reg)
        np.savez("data/integrateNMDA.npz",
            dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dFdfw=dFdfw, dInh=dInh,
            ePreA=ePreA, ePreB=ePreB, ePreC=ePreC)
        times = np.arange(0, t*nTrain, dt)
        plotState(times, X1, Y1, error1, "integrateNMDA", "fdfw", t*nTrain)
        plotState(times, X2, Y2, error2, "integrateNMDA", "inh", t*nTrain)

    if 3 in load:
        with np.load(file) as saved:
            eFdfw = saved['eFdfw']
    else:
        print("encoders for fdfw-to-ens")
        eFdfw = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt,
                stimA=stimA, stimB=stimB, Tff=Tff,
                dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw,
                ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw,
                stage=3)
            eFdfw = data['eFdfw']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDA", "fdfwEns")
        np.savez("data/integrateNMDA.npz",
            dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dFdfw=dFdfw, dInh=dInh,
            ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw)

    if 4 in load:
        with np.load(file) as saved:
            dBio, dNeg = saved['dBio'], saved['dNeg']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB,
                dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw,
                ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw,
                stage=4)
            spikes[n] = data['ens']
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        dBio, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, dt=dt, reg=reg)
        dNeg = -np.array(dBio)
        np.savez("data/integrateNMDA.npz",
            dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dFdfw=dFdfw, dInh=dInh, dBio=dBio, dNeg=dNeg,
            ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw)
        times = np.arange(0, t*nTrain, dt)
        plotState(times, X, Y, error, "integrateNMDA", "ens", t*nTrain)

    if 5 in load:
        with np.load(file) as saved:
            eBio = saved['eBio']
    else:
        print("encoders from ens to ens")
        eBio = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            data = go(
                NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB,
                dPreA=dPreA, dPreB=dPreB, dFdfw=dFdfw, dBio=dBio,
                ePreA=ePreA, ePreB=ePreB, eFdfw=eFdfw, eBio=eBio,
                stage=5)
            eBio = data['eBio']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "integrateNMDA", "ensEns")
            np.savez("data/integrateNMDA.npz",
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio, dInh=dInh, dNeg=dNeg,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio)

    if 6 in load:
        with np.load(file) as saved:
            eNeg, eInh = saved['eNeg'], saved['eInh']
    else:
        print("encoders from [ens, inh] to fdfw")
        eNeg = np.zeros((N, N, 1))
        eInh = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=n)
            stimC = lambda t: 0 if t<c else 1
            data = go(
                NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, stimC=stimC,
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio, dNeg=dNeg, dInh=dInh,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh,
                stage=6)
            eNeg = data['eNeg']
            eInh = data['eInh']
            plotActivity(t, dt, fS, data['times'], data['fdfw2'], data['tarFdfw2'], "integrateNMDA", "ensFdfw")
            plotActivity(t, dt, fS, data['times'], data['fdfw4'], data['tarFdfw4'], "integrateNMDA", "inhFdfw")
            np.savez("data/integrateNMDA.npz",
                dPreA=dPreA, dPreB=dPreB, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio, dNeg=dNeg, dInh=dInh,
                ePreA=ePreA, ePreB=ePreB, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh)

    print("testing")
    vals = np.linspace(-1, 1, nTest)
    for test in range(nTest):

        # flat test: hold vals[test] until the cutoff, then switch stimC on
        stimA = lambda t: vals[test] if t<c else 0
        stimC = lambda t: 0 if t<c else 1
        data = go(
            NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimC=stimC,
            dPreA=dPreA, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio, dNeg=Tneg*dNeg, dInh=dInh,
            ePreA=ePreA, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio, eNeg=eNeg, eInh=eInh,
            stage=7, c=c, DA=DATest)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        aInh = fGABA.filt(data['inh'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)/Tff
        xhatBio = np.dot(aBio, dBio)
        xhatInh = np.dot(aInh, dInh)
        xFdfw = fNMDA.filt(fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt), dt=dt)
        # with the default c == t, int(c/dt) equals the number of samples and
        # would index past the end; clamp to the last sample instead
        iCut = min(int(c/dt), len(xFdfw) - 1)
        # build the held-value target in the same shape as the estimate so the
        # rmse is elementwise rather than a (k, 1) vs (k,) broadcast
        xBio = xFdfw[iCut] * np.ones_like(xhatBio)
        errorBio = rmse(xhatBio[iCut:], xBio[iCut:])
        fig, ax = plt.subplots()
        ax.plot(data['times'], xFdfw, alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], xBio, label="target")
        ax.plot(data['times'], xhatBio, alpha=0.5, label="ens, rmse=%.3f"%errorBio)
        ax.plot(data['times'], xhatInh, alpha=0.5, label="inh")
        ax.axvline(c, label="cutoff")
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDA_testFlat%s.pdf"%test)

        # white test: noisy input, dNeg/eNeg passed as None, stimC held at 0
        stimA, stimB = makeSignal(t, fPre, fNMDA, dt=dt, seed=200+test)
        stimC = lambda t: 0
        data = go(NPre=NPre, N=N, t=t, dt=dt, stimA=stimA, stimB=stimB, stimC=stimC, Tff=Tff,
            dPreA=dPreA, dPreC=dPreC, dFdfw=dFdfw, dBio=dBio, dNeg=None, dInh=dInh,
            ePreA=ePreA, ePreC=ePreC, eFdfw=eFdfw, eBio=eBio, eNeg=None, eInh=eInh,
            stage=7, DA=DATest)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aBio = fNMDA.filt(data['ens'], dt=dt)
        xhatFdfw = np.dot(aFdfw, dFdfw)
        xhatBio = np.dot(aBio, dBio)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xBio = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        errorBio = rmse(xhatBio, xBio)
        fig, ax = plt.subplots()
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], fNMDA.filt(Tff*xFdfw, dt=dt), alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xBio,  label="target")
        ax.plot(data['times'], xhatBio, alpha=0.5, label="ens, rmse=%.3f"%errorBio)
        ax.set(xlabel='time', ylabel='state', xlim=((0, t)), ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateNMDA_testWhite%s.pdf"%test)
        plt.close('all')
Exemplo n.º 21
0
def medium(N=300,
           t=10,
           dt=0.001,
           nTrain=5,
           nTest=5,
           fAMPA=DoubleExp(0.55e-3, 2.2e-3),
           fNMDA=DoubleExp(2.3e-3, 95.0e-3),
           fGABA=DoubleExp(0.5e-3, 1.5e-3),
           daAMPA=0.8,
           daNMDA=0.7,
           daGABA=0.8,
           kEnsEnsAMPA=0.2,
           kEnsEnsNMDA=0.85,
           kEnsInhAMPA=0.1,
           kEnsInhNMDA=1.0,
           kGABAFdfw=-4.0,
           kGABAEns=-1e-3,
           T=0.05):
    """Fit and test the decoders of a gated-memory network built by ``goLIF``.

    Three fitting phases use ``nTrain`` simulations driven by the constant
    stimuli ``(n + 1) / nTrain``. Each phase uses nonnegative least squares
    (``nnls``) and concatenates all training runs:

    1. ``dFdfwEnsAMPA`` / ``dFdfwEnsNMDA`` decode the AMPA-/NMDA-filtered
       input from the ``fdfw`` population.
    2. ``dEnsEnsAMPA`` / ``dEnsEnsNMDA`` (scaled by ``kEnsEns*``) and
       ``dEnsInhAMPA`` / ``dEnsInhNMDA`` (scaled by ``kEnsInh*``) decode the
       doubly filtered input from ``ens``. This phase uses the feedforward NMDA
       weights only, with AMPA zeroed.
    3. ``dInhEnsGABA`` / ``dInhFdfwGABA`` are nonpositive decoders from
       ``inh`` targeting ``kGABAEns`` / ``kGABAFdfw`` times the filtered input.

    Each of ``nTest`` square-wave stimuli is then simulated twice:

    * "high DA": AMPA weights scaled by ``daAMPA``;
    * "low DA": NMDA weights scaled by ``daNMDA`` and GABA weights by ``daGABA``.

    In both runs the feedforward weights are also scaled by ``T``. All figures
    are written to ``plots/gatedMemory_goLIF_*.pdf`` and closed.

    ``dt`` is used both for simulation and for every offline ``filt`` call.
    """

    def trainingRuns(**weights):
        """Yield one goLIF result per constant training stimulus (n + 1) / nTrain."""
        for n in range(nTrain):
            stim = makeSignalEven(t, dt=dt, value=(n + 1) / nTrain)
            yield goLIF(N=N,
                        stim=stim,
                        t=t,
                        dt=dt,
                        fAMPA=fAMPA,
                        fNMDA=fNMDA,
                        fGABA=fGABA,
                        **weights)

    def fitNonneg(activities, targets, scale=1):
        """Return nonnegative (N, 1) decoders mapping activities onto scale * targets."""
        d, _ = nnls(activities, scale * np.ravel(targets))
        return d.reshape((N, 1))

    def plotTraining(targets, estimates, ylim, path):
        """Plot (array, label) targets dashed and estimates translucent, save, close."""
        # training runs are concatenated back to back; rebuild a matching time axis
        times = dt * np.arange(len(targets[0][0]))
        fig, ax = plt.subplots()
        for y, label in targets:
            ax.plot(times, y, linestyle="--", label=label)
        for y, label in estimates:
            ax.plot(times, y, alpha=0.5, label=label)
        ax.legend(loc="upper right")
        ax.set(ylim=ylim, xlabel="time (s)", ylabel=r"$\mathbf{\hat{x}}(t)$")
        fig.savefig(path)
        plt.close(fig)

    print("readout decoders for fdfw")
    # collect per-run arrays and concatenate once (repeated np.append is quadratic)
    asAMPA, asNMDA, targetsAMPA, targetsNMDA = [], [], [], []
    for data in trainingRuns():
        asAMPA.append(fAMPA.filt(data['fdfw'], dt=dt))
        asNMDA.append(fNMDA.filt(data['fdfw'], dt=dt))
        targetsAMPA.append(fAMPA.filt(data['inpt'], dt=dt))
        targetsNMDA.append(fNMDA.filt(data['inpt'], dt=dt))
    asAMPA = np.concatenate(asAMPA, axis=0)
    asNMDA = np.concatenate(asNMDA, axis=0)
    targetsAMPA = np.concatenate(targetsAMPA, axis=0)
    targetsNMDA = np.concatenate(targetsNMDA, axis=0)
    dFdfwEnsAMPA = fitNonneg(asAMPA, targetsAMPA)
    dFdfwEnsNMDA = fitNonneg(asNMDA, targetsNMDA)
    plotTraining(
        [(targetsAMPA, 'target (AMPA)'), (targetsNMDA, 'target (NMDA)')],
        [(np.dot(asAMPA, dFdfwEnsAMPA), 'fdfw (AMPA)'),
         (np.dot(asNMDA, dFdfwEnsNMDA), 'fdfw (NMDA)')],
        (0, 1), "plots/gatedMemory_goLIF_fdfw.pdf")

    print(
        "readout decoders for ens, given AMPA and NMDA from fdfw in high DA condition"
    )
    asAMPA, asNMDA, targetsAMPA, targetsNMDA = [], [], [], []
    for data in trainingRuns(dFdfwEnsAMPA=0 * dFdfwEnsAMPA,
                             dFdfwEnsNMDA=dFdfwEnsNMDA):
        asAMPA.append(fAMPA.filt(data['ens'], dt=dt))
        asNMDA.append(fNMDA.filt(data['ens'], dt=dt))
        targetsAMPA.append(fAMPA.filt(fNMDA.filt(data['inpt'], dt=dt), dt=dt))
        targetsNMDA.append(fNMDA.filt(fNMDA.filt(data['inpt'], dt=dt), dt=dt))
    asAMPA = np.concatenate(asAMPA, axis=0)
    asNMDA = np.concatenate(asNMDA, axis=0)
    targetsAMPA = np.concatenate(targetsAMPA, axis=0)
    targetsNMDA = np.concatenate(targetsNMDA, axis=0)
    dEnsEnsAMPA = fitNonneg(asAMPA, targetsAMPA, kEnsEnsAMPA)
    dEnsEnsNMDA = fitNonneg(asNMDA, targetsNMDA, kEnsEnsNMDA)
    dEnsInhAMPA = fitNonneg(asAMPA, targetsAMPA, kEnsInhAMPA)
    dEnsInhNMDA = fitNonneg(asNMDA, targetsNMDA, kEnsInhNMDA)
    plotTraining(
        [(kEnsEnsAMPA * targetsAMPA, 'target (AMPA)'),
         (kEnsEnsNMDA * targetsNMDA, 'target (NMDA)')],
        [(np.dot(asAMPA, dEnsEnsAMPA), 'ens (AMPA)'),
         (np.dot(asNMDA, dEnsEnsNMDA), 'ens (NMDA)')],
        (0, 1), "plots/gatedMemory_goLIF_ens.pdf")

    print(
        "readout decoders for inh, given ff and fb AMPA and NMDA and high DA")
    asGABA, targetsGABA = [], []
    for data in trainingRuns(dFdfwEnsAMPA=daAMPA * dFdfwEnsAMPA,
                             dFdfwEnsNMDA=dFdfwEnsNMDA,
                             dEnsInhAMPA=daAMPA * dEnsInhAMPA,
                             dEnsInhNMDA=dEnsInhNMDA):
        asGABA.append(fGABA.filt(data['inh'], dt=dt))
        targetsGABA.append(fGABA.filt(fNMDA.filt(data['inpt'], dt=dt), dt=dt))
    asGABA = np.concatenate(asGABA, axis=0)
    targetsGABA = np.concatenate(targetsGABA, axis=0)
    # inhibitory decoders must be nonpositive: fit nonnegative weights on the
    # negated activities, then flip the sign
    dInhEnsGABA = -fitNonneg(-asGABA, targetsGABA, kGABAEns)
    dInhFdfwGABA = -fitNonneg(-asGABA, targetsGABA, kGABAFdfw)
    plotTraining(
        [(kGABAEns * targetsGABA, 'target (ens)'),
         (kGABAFdfw * targetsGABA, 'target (fdfw)')],
        [(np.dot(asGABA, dInhEnsGABA), 'inh (ens)'),
         (np.dot(asGABA, dInhFdfwGABA), 'inh (fdfw)')],
        (-1, 0), "plots/gatedMemory_goLIF_inh.pdf")

    def plotTestPanel(ax, data):
        """Plot NMDA-filtered input/integral targets against the fdfw/ens estimates."""
        xhatFdfwNMDA = np.dot(fNMDA.filt(data['fdfw'], dt=dt), dFdfwEnsNMDA)
        xhatIntgNMDA = np.dot(fNMDA.filt(data['ens'], dt=dt), dEnsEnsNMDA)
        ax.plot(data['times'],
                fNMDA.filt(data['inpt'], dt=dt),
                linestyle="--",
                label='input (NMDA)')
        ax.plot(data['times'], xhatFdfwNMDA, alpha=0.5, label='fdfw (NMDA)')
        ax.plot(data['times'],
                fNMDA.filt(data['intg'], dt=dt),
                linestyle="--",
                label='integral (NMDA)')
        ax.plot(data['times'], xhatIntgNMDA, alpha=0.5, label='ens (NMDA)')
        ax.legend(loc="upper right")

    # Test in high and low DA
    print('testing with high and low DA')
    for n in range(nTest):
        stim = makeSignalSquare(t, dt=dt, seed=100 + n)
        data = goLIF(N=N,
                     stim=stim,
                     t=t,
                     dt=dt,
                     fAMPA=fAMPA,
                     fNMDA=fNMDA,
                     fGABA=fGABA,
                     dFdfwEnsAMPA=daAMPA * T * dFdfwEnsAMPA,
                     dFdfwEnsNMDA=T * dFdfwEnsNMDA,
                     dEnsEnsAMPA=daAMPA * dEnsEnsAMPA,
                     dEnsEnsNMDA=dEnsEnsNMDA,
                     dEnsInhAMPA=daAMPA * dEnsInhAMPA,
                     dEnsInhNMDA=dEnsInhNMDA,
                     dInhEnsGABA=dInhEnsGABA,
                     dInhFdfwGABA=dInhFdfwGABA)
        fig, (ax, ax2) = plt.subplots(ncols=1, nrows=2, sharex=True)
        plotTestPanel(ax, data)
        ax.set(ylim=((0, 1)), ylabel=r"$\mathbf{\hat{x}}(t)$", title="high DA")

        stim = makeSignalSquare(t, dt=dt, seed=100 + n)
        data = goLIF(N=N,
                     stim=stim,
                     t=t,
                     dt=dt,
                     fAMPA=fAMPA,
                     fNMDA=fNMDA,
                     fGABA=fGABA,
                     dFdfwEnsAMPA=T * dFdfwEnsAMPA,
                     dFdfwEnsNMDA=daNMDA * T * dFdfwEnsNMDA,
                     dEnsEnsAMPA=dEnsEnsAMPA,
                     dEnsEnsNMDA=daNMDA * dEnsEnsNMDA,
                     dEnsInhAMPA=dEnsInhAMPA,
                     dEnsInhNMDA=daNMDA * dEnsInhNMDA,
                     dInhEnsGABA=daGABA * dInhEnsGABA,
                     dInhFdfwGABA=daGABA * dInhFdfwGABA)
        plotTestPanel(ax2, data)
        ax2.set(ylim=((0, 1)),
                xlabel="time (s)",
                ylabel=r"$\mathbf{\hat{x}}(t)$",
                title="low DA")
        fig.savefig("plots/gatedMemory_goLIF_test%s.pdf" % n)
        plt.close(fig)
Exemplo n.º 22
0
def go(NPre=100, N=100, t=10, c=None, seed=0, dt=0.001, Tff=0.3, tTrans=0.01,
        stage=None, alpha=3e-7, eMax=1e-1,
        fPre=DoubleExp(1e-3, 1e-1), fNMDA=DoubleExp(10.6e-3, 285e-3), fS=DoubleExp(1e-3, 1e-1),
        dPreA=None, dPreB=None, dPreC=None, dFdfw=None, dBio=None, dNeg=None, dInh=None,
        ePreA=None, ePreB=None, ePreC=None, eFdfw=None, eBio=None, eNeg=None, eInh=None,
        stimA=lambda t: 0, stimB=lambda t: 0, stimC=lambda t: 0, DA=lambda t: 0):
    """Build and simulate the NMDA integrator network for one training/test stage.

    Every call builds a shared base network:

    * three input nodes (``stimA``/``stimB``/``stimC``) driving three LIF
      ``pre`` ensembles;
    * three ``Bio`` ensembles (``fdfw`` and ``ens`` "Pyramidal", ``inh``
      "Interneuron", all given ``DA``);
    * three LIF target ensembles (``tarFdfw``, ``tarEns``, ``tarInh``);
    * unfiltered probes on the inputs and on every population's spikes.

    ``stage`` then adds stage-specific connections:

    * 1 -- pre[A,B,C] -> [fdfw, ens, inh]; encoders learned against the LIF
      targets, which receive the fPre-filtered inputs
    * 2 -- preA -> fdfw and preC -> inh, no learning
    * 3 -- inptB -> preB filtered by fNMDA; adds fdfw -> ens (NMDA(), dFdfw);
      encoders learned against tarEns, which gets inptB plus Tff * tarFdfw
    * 4 -- same connections as stage 3, without learning
    * 5 -- ens2 -> ens (NMDA(), dBio), learned against ens3, which receives
      fdfw and an fNMDA-filtered copy of stimB (via preB2); ens3 is probed as
      ``tarEns``
    * 6 -- ens -> fdfw2 (dNeg), learned against fdfw3; inh -> fdfw4 (GABA(),
      dInh), learned against tarFdfw4, which is inhibited by tarInh
    * 7 -- test network: preA -> fdfw -> ens, ens -> ens (dBio),
      ens -> fdfw (dNeg), preC -> inh -> fdfw (dInh)

    Any other value (e.g. 0 or None) simulates only the base network.

    Returns
    -------
    dict
        Probed data:

        * ``times``;
        * ``inptA``/``inptB``/``inptC``;
        * spikes of ``preA``/``preB``/``preC``/``fdfw``/``ens``/``inh`` and
          of the ``tar*`` populations;
        * the stage-6-only probes ``fdfw2``/``fdfw4``/``tarFdfw2``/``tarFdfw4``
          (``None`` for other stages).

        It also holds the encoders ``ePreA``/``ePreB``/``ePreC``/``eFdfw``/
        ``eBio``/``eNeg``/``eInh``. Each is replaced by the connection's
        ``.e`` attribute for the stage that learns it, and otherwise returned
        unchanged.
    """

    # NOTE(review): `c` is not read anywhere else in this function
    if not c: c = t
    with nengo.Network(seed=seed) as model:
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        inptC = nengo.Node(stimC)
        preA = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preB = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preC = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        fdfw = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
        ens = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed+1)
        inh = nengo.Ensemble(N, 1, neuron_type=Bio("Interneuron", DA=DA), seed=seed+2)
        # LIF reference populations; they share seeds with the Bio ensembles above
        tarFdfw = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), neuron_type=nengo.LIF(), seed=seed)
        tarEns = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), neuron_type=nengo.LIF(), seed=seed+1)
        tarInh = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(0.2, 0.8), encoders=Choice([[1]]), neuron_type=nengo.LIF(), seed=seed+2)
        # input -> pre connections; cB's synapse is overridden in stages 3 and 4
        cA = nengo.Connection(inptA, preA, synapse=None, seed=seed)
        cB = nengo.Connection(inptB, preB, synapse=None, seed=seed)
        cC = nengo.Connection(inptC, preC, synapse=None, seed=seed)
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pInptC = nengo.Probe(inptC, synapse=None)
        pPreA = nengo.Probe(preA.neurons, synapse=None)
        pPreB = nengo.Probe(preB.neurons, synapse=None)
        pPreC = nengo.Probe(preC.neurons, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pTarFdfw = nengo.Probe(tarFdfw.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        pInh = nengo.Probe(inh.neurons, synapse=None)
        pTarInh = nengo.Probe(tarInh.neurons, synapse=None)
        # stage 1: learn pre[A,B,C] -> [fdfw, ens, inh] encoders against the LIF targets
        if stage==1:
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarInh, synapse=fPre, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            learnEncoders(c1, tarFdfw, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            learnEncoders(c2, tarEns, fS, alpha=3*alpha, eMax=10*eMax, tTrans=tTrans)
            learnEncoders(c3, tarInh, fS, alpha=alpha/3, eMax=eMax, tTrans=tTrans)
        # stage 2: drive fdfw and inh from their pre populations (no learning)
        if stage==2:
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
        # stage 3: learn fdfw -> ens encoders; tarEns gets inptB plus Tff * tarFdfw
        if stage==3:
            cB.synapse = fNMDA
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(tarFdfw, tarEns, synapse=fNMDA, transform=Tff, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            learnEncoders(c3, tarEns, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
        # stage 4: the stage-3 network without learning, for the ens readout
        if stage==4:
            cB.synapse = fNMDA
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c3 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
        # stage 5: learn ens2 -> ens (dBio) against ens3, which sees fdfw plus NMDA-filtered stimB
        if stage==5:
            preB2 = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
            ens2 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed+1)
            ens3 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed+1)
            nengo.Connection(inptB, preB2, synapse=fNMDA, seed=seed)
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c3 = nengo.Connection(preB, ens2, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c4 = nengo.Connection(ens2, ens, synapse=NMDA(), solver=NoSolver(dBio), seed=seed)
            c5 = nengo.Connection(fdfw, ens3, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c6 = nengo.Connection(preB2, ens3, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            learnEncoders(c4, ens3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            # the returned 'tarEns' is ens3's spikes in this stage
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
        # stage 6: learn ens -> fdfw2 (dNeg) against fdfw3, and inh -> fdfw4 (dInh) against tarFdfw4
        if stage==6:
            preA2 = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
            fdfw2 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
            fdfw3 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
            fdfw4 = nengo.Ensemble(N, 1, neuron_type=Bio("Pyramidal", DA=DA), seed=seed)
            tarFdfw4 = nengo.Ensemble(N, 1, max_rates=Uniform(30, 30), intercepts=Uniform(-0.8, 0.8), neuron_type=nengo.LIF(), seed=seed)
            nengo.Connection(inptA, tarFdfw4, synapse=fPre, seed=seed)
            nengo.Connection(inptB, preA2, synapse=fNMDA, seed=seed)
            nengo.Connection(inptC, tarInh, synapse=fPre, seed=seed)
            # tarInh's decoded output inhibits every tarFdfw4 neuron with gain -1e2
            nengo.Connection(tarInh, tarFdfw4.neurons, synapse=None, transform=-1e2*np.ones((N, 1)), seed=seed)
            c1 = nengo.Connection(preB, ens, synapse=fPre, solver=NoSolver(dPreB), seed=seed)
            c2 = nengo.Connection(ens, fdfw2, synapse=NMDA(), solver=NoSolver(dNeg), seed=seed)
            c3 = nengo.Connection(preA2, fdfw3, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c4 = nengo.Connection(preA, fdfw4, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c5 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            c6 = nengo.Connection(inh, fdfw4, synapse=GABA(), solver=NoSolver(dInh), seed=seed)
            learnEncoders(c2, fdfw3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            learnEncoders(c6, tarFdfw4, fS, alpha=1e3*alpha, eMax=1e3*eMax, tTrans=tTrans, inh=True)
            pFdfw2 = nengo.Probe(fdfw2.neurons, synapse=None)
            pFdfw4 = nengo.Probe(fdfw4.neurons, synapse=None)
            pTarFdfw2 = nengo.Probe(fdfw3.neurons, synapse=None)
            pTarFdfw4 = nengo.Probe(tarFdfw4.neurons, synapse=None)
        # stage 7: full test network
        # NOTE(review): preB is not connected to any Bio ensemble here, so stimB
        # is only probed (as 'inptB') and does not drive the network
        if stage==7:
            c1 = nengo.Connection(preA, fdfw, synapse=fPre, solver=NoSolver(dPreA), seed=seed)
            c2 = nengo.Connection(fdfw, ens, synapse=NMDA(), solver=NoSolver(dFdfw), seed=seed)
            c3 = nengo.Connection(ens, ens, synapse=NMDA(), solver=NoSolver(dBio), seed=seed)
            c4 = nengo.Connection(ens, fdfw, synapse=NMDA(), solver=NoSolver(dNeg), seed=seed)
            c5 = nengo.Connection(preC, inh, synapse=fPre, solver=NoSolver(dPreC), seed=seed)
            c6 = nengo.Connection(inh, fdfw, synapse=GABA(), solver=NoSolver(dInh), seed=seed)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        # pair each stage's connections with its (decoders, encoders); setWeights
        # presumably installs the resulting weights on the built connection -- TODO confirm
        if stage==1:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreB, ePreB)
            setWeights(c3, dPreC, ePreC)
        if stage==2:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreC, ePreC)
        if stage==3:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreB, ePreB)
            setWeights(c3, dFdfw, eFdfw)
        if stage==4:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dPreB, ePreB)
            setWeights(c3, dFdfw, eFdfw)
        if stage==5:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dFdfw, eFdfw)
            setWeights(c3, dPreB, ePreB)
            setWeights(c4, dBio, eBio)
            setWeights(c5, dFdfw, eFdfw)
            setWeights(c6, dPreB, ePreB)
        if stage==6:
            setWeights(c1, dPreB, ePreB)
            setWeights(c2, dNeg, eNeg)
            setWeights(c3, dPreA, ePreA)
            setWeights(c4, dPreA, ePreA)
            setWeights(c5, dPreC, ePreC)
            setWeights(c6, dInh, eInh)
        if stage==7:
            setWeights(c1, dPreA, ePreA)
            setWeights(c2, dFdfw, eFdfw)
            setWeights(c3, dBio, eBio)
            setWeights(c4, dNeg, eNeg)
            setWeights(c5, dPreC, ePreC)
            setWeights(c6, dInh, eInh)
        # presumably (re)initializes NEURON state before running -- confirm
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model) 

    # take the learned encoders (the `.e` attribute of the connections passed
    # to learnEncoders) for the stage that learns them; otherwise pass inputs through
    ePreA = c1.e if stage==1 else ePreA
    ePreB = c2.e if stage==1 else ePreB
    ePreC = c3.e if stage==1 else ePreC
    eFdfw = c3.e if stage==3 else eFdfw
    eBio = c4.e if stage==5 else eBio
    eNeg = c2.e if stage==6 else eNeg
    eInh = c6.e if stage==6 else eInh

    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        inptC=sim.data[pInptC],
        preA=sim.data[pPreA],
        preB=sim.data[pPreB],
        preC=sim.data[pPreC],
        fdfw=sim.data[pFdfw],
        ens=sim.data[pEns],
        inh=sim.data[pInh],
        tarFdfw=sim.data[pTarFdfw],
        tarEns=sim.data[pTarEns],
        tarInh=sim.data[pTarInh],
        fdfw2=sim.data[pFdfw2] if stage==6 else None,
        fdfw4=sim.data[pFdfw4] if stage==6 else None,
        tarFdfw2=sim.data[pTarFdfw2] if stage==6 else None,
        tarFdfw4=sim.data[pTarFdfw4] if stage==6 else None,
        ePreA=ePreA,
        ePreB=ePreB,
        ePreC=ePreC,
        eFdfw=eFdfw,
        eBio=eBio,
        eNeg=eNeg,
        eInh=eInh,
    )
Exemplo n.º 23
0
def go(NPre=100,
       N=100,
       N2=30,
       t=10,
       m=Uniform(30, 30),
       i=Uniform(-0.8, 0.8),
       seed=0,
       dt=0.001,
       f=DoubleExp(1e-3, 1e-1),
       fS=DoubleExp(1e-3, 1e-1),
       neuron_type=LIF(),
       d1=None,
       d2=None,
       f1=None,
       f2=None,
       e1=None,
       e2=None,
       l1=False,
       l2=False,
       stim=lambda t: [0, 0]):
    """Build and simulate a two-layer network that computes ``multiply`` of a 2-D input.

    Signal path: ``stim`` node -> ``pre`` (LIF, 2-D, radius 2) -> ``ens``
    (``neuron_type``, 2-D, radius 2) -> ``ens2`` (``neuron_type``, 1-D).
    A parallel Direct-mode path ``inpt -> tar -> tar2`` provides the target
    signal, with ``tar -> tar2`` applying ``multiply``.

    Parameters
    ----------
    NPre, N, N2 : int
        Neuron counts of ``pre``, ``ens`` and ``ens2``.
    t : float
        Duration passed to ``sim.run``.
    m, i : distributions
        ``max_rates`` for the spiking ensembles and ``intercepts`` for
        ``ens``, ``ens2`` and the learning-target ensembles.
    seed : int
        Seed of the network, simulator, ensembles and connections
        (``seed + 1`` is used on the ``ens2`` side).
    dt : float
        Simulator time step.
    f : synapse
        Synapse of the target path; default for ``f1``/``f2`` and the
        filter of the ``tarState`` probe.
    fS : synapse
        Passed to ``learnEncoders``.
    neuron_type :
        Neuron model of ``ens`` and ``ens2``. For ``Bio`` instances the
        weights are installed with ``setWeights`` and ``neuron.h.init()`` /
        ``reset_neuron`` are called around the run.
    d1, d2 : array or None
        Decoders of ``pre -> ens`` (through ``NoSolver(d1)``) and of
        ``ens -> ens2``. ``d2`` is replaced by zeros of shape ``(N, 1)``
        when it is None or all zeros (``np.any`` is False in both cases).
    f1, f2 : synapse or None
        Synapses of ``c1``/``c2``; fall back to ``f`` when falsy.
    e1, e2 : array or None
        Encoders handed to ``setWeights`` (Bio only). Replaced by the learned
        ``c1.e``/``c2.e`` after the run when ``l1``/``l2`` is set.
    l1, l2 : bool
        Attach ``learnEncoders`` to ``c1``/``c2`` against LIF target
        ensembles ``tarEns``/``tarEns2``.
    stim : callable
        Output function of the input node.

    Returns
    -------
    dict
        Unfiltered probe data ``times``, ``inpt``, ``pre``, ``ens``, ``ens2``,
        ``tar``, ``tar2``; ``tarEns`` (None unless ``l1``); ``tarEns2`` and
        the ``f``-filtered ``tarState`` (None unless ``l2``); and the encoders
        ``e1``, ``e2``.
    """

    if not f1: f1 = f
    if not f2: f2 = f
    # np.any() is False both for None and for an all-zero array.
    if not np.any(d2): d2 = np.zeros((N, 1))
    with nengo.Network(seed=seed) as model:
        # Stimulus and Nodes
        inpt = nengo.Node(stim)
        # Direct-mode ensembles carry the ideal target signals.
        tar = nengo.Ensemble(1, 2, neuron_type=nengo.Direct())
        tar2 = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
        pre = nengo.Ensemble(NPre,
                             2,
                             radius=2,
                             max_rates=m,
                             seed=seed,
                             neuron_type=LIF())
        ens = nengo.Ensemble(N,
                             2,
                             radius=2,
                             max_rates=m,
                             intercepts=i,
                             neuron_type=neuron_type,
                             seed=seed)
        ens2 = nengo.Ensemble(N2,
                              1,
                              max_rates=m,
                              intercepts=i,
                              neuron_type=neuron_type,
                              seed=seed + 1)
        nengo.Connection(inpt, pre, synapse=None, seed=seed)
        nengo.Connection(inpt, tar, synapse=f, seed=seed)
        nengo.Connection(tar,
                         tar2,
                         synapse=f,
                         function=multiply,
                         seed=seed + 1)
        # pre -> ens uses the externally supplied decoders d1 as-is.
        c1 = nengo.Connection(pre,
                              ens,
                              synapse=f1,
                              seed=seed,
                              solver=NoSolver(d1))
        if isinstance(neuron_type, Bio):
            # Bio: a decoded `multiply` connection; weights are set via setWeights below.
            c2 = nengo.Connection(ens,
                                  ens2,
                                  synapse=f2,
                                  seed=seed + 1,
                                  function=multiply)
        else:
            # Non-Bio: neuron-to-ensemble connection through the decoder transform d2.T.
            c2 = nengo.Connection(ens.neurons,
                                  ens2,
                                  synapse=f2,
                                  seed=seed + 1,
                                  transform=d2.T)
        pInpt = nengo.Probe(inpt, synapse=None)
        pPre = nengo.Probe(pre.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pEns2 = nengo.Probe(ens2.neurons, synapse=None)
        pTar = nengo.Probe(tar, synapse=None)
        pTar2 = nengo.Probe(tar2, synapse=None)
        # Encoder Learning (Bio)
        if l1:
            # LIF ensemble driven by the target whose activity c1 is trained towards.
            tarEns = nengo.Ensemble(N,
                                    2,
                                    radius=2,
                                    max_rates=m,
                                    intercepts=i,
                                    neuron_type=nengo.LIF(),
                                    seed=seed)
            nengo.Connection(tar, tarEns, synapse=None, seed=seed)
            learnEncoders(c1, tarEns, fS, alpha=3e-8)
            pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        if l2:
            tarEns2 = nengo.Ensemble(N2,
                                     1,
                                     max_rates=m,
                                     intercepts=i,
                                     neuron_type=nengo.LIF(),
                                     seed=seed + 1)
            # NOTE(review): unlike the other connections, this one has no explicit seed.
            nengo.Connection(tar2, tarEns2, synapse=None)
            #             nengo.Connection(ens.neurons, tarEns2, synapse=f2, transform=d2.T, seed=seed+1)
            learnEncoders(c2, tarEns2, fS, alpha=1e-7)
            pTarEns2 = nengo.Probe(tarEns2.neurons, synapse=None)
            pTarState = nengo.Probe(tarEns2, synapse=f)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        if isinstance(neuron_type, Bio):
            setWeights(c1, d1, e1)
            setWeights(c2, d2, e2)
            neuron.h.init()
            sim.run(t, progress_bar=True)
            reset_neuron(sim, model)
        else:
            sim.run(t, progress_bar=True)

    # Presumably learnEncoders stores the learned encoders on the connection's
    # `e` attribute -- confirm against its definition.
    e1 = c1.e if l1 else e1
    e2 = c2.e if l2 else e2

    return dict(
        times=sim.trange(),
        inpt=sim.data[pInpt],
        pre=sim.data[pPre],
        ens=sim.data[pEns],
        ens2=sim.data[pEns2],
        tar=sim.data[pTar],
        tar2=sim.data[pTar2],
        tarEns=sim.data[pTarEns] if l1 else None,
        tarEns2=sim.data[pTarEns2] if l2 else None,
        tarState=sim.data[pTarState] if l2 else None,
        e1=e1,
        e2=e2,
    )
Exemplo n.º 24
0
def run(N=3000, neuron_type=LIF(), tTrain=200, tTest=100, tTransTrain=20, tTransTest=20,
    nTrain=1, nTest=10, dt=0.001, dtSampleTrain=0.003, dtSampleTest=0.01, seed=0,
    f2=10, reg=1e-3, reg2=1e-3, evals=100, r=30, load=False, file=None,
    tauRiseMin=3e-2, tauRiseMax=6e-2, tauFallMin=2e-1, tauFallMax=3e-1):
    """Train and test two readouts of a Lorenz-attractor ensemble.

    Both readouts decode the 3-D target from the ``ens`` spikes returned by
    ``go``:

    * ``d`` with a DoubleExp filter ``f`` whose time constants are chosen by
      ``decode`` within ``[tauRiseMin, tauRiseMax]`` / ``[tauFallMin, tauFallMax]``;
    * ``d2`` with ``nengo.solvers.LstsqL2(reg=reg2)`` on Gaussian-smoothed
      spikes and targets (``sigma=f2`` samples).

    Results are saved to ``data/lorenzNew_<neuron_type>.npz``.

    Parameters
    ----------
    load : bool
        If True, read ``d``, ``d2``, ``tauRise`` and ``tauFall`` from ``file``
        instead of training. ``f2`` is not read from the file; the argument
        value is used.
    tTransTrain, tTransTest : float
        Initial transient discarded from each training / test trial.
    dtSampleTrain, dtSampleTest : float
        Sampling interval of the data returned by ``go`` in each phase.

    Returns
    -------
    numpy.ndarray
        Shape ``(nTest,)``; the error value returned by ``plotLorenz`` for
        each test trial.
    """

    print('\nNeuron Type: %s'%neuron_type)
    rng = np.random.RandomState(seed=seed)
    # Number of post-transient samples kept per training trial.
    timeStepsTrain = int((tTrain-tTransTrain)/dtSampleTrain)
    tFlat = timeStepsTrain*nTrain
    if load:
        # Read all arrays in one pass and close the archive; an NpzFile keeps
        # its file handle open until it is closed.
        with np.load(file) as saved:
            d = saved['d']
            d2 = saved['d2']
            tauRise = saved['tauRise']
            tauFall = saved['tauFall']
        f = DoubleExp(tauRise, tauFall)
    else:
        print('decoders for ens')
        spikes = np.zeros((nTrain, timeStepsTrain, N))
        spikes2 = np.zeros((nTrain, timeStepsTrain, N))
        targets = np.zeros((nTrain, timeStepsTrain, 3))
        targets2 = np.zeros((nTrain, timeStepsTrain, 3))
        for n in range(nTrain):
            # Initial condition near the attractor (a wider box previously
            # tried: x in [-15, 15], y in [-20, 20], z in [10, 35]).
            IC = np.array([rng.uniform(-5, 5), rng.uniform(-5, 5), rng.uniform(20, 25)])
            data = go(N=N, neuron_type=neuron_type, l=True, t=tTrain, r=r, dt=dt, dtSample=dtSampleTrain, seed=seed, IC=IC)
            # Keep only the last timeStepsTrain samples, dropping the transient.
            spikes[n] = data['ens'][-timeStepsTrain:]
            spikes2[n] = gaussian_filter1d(data['ens'], sigma=f2, axis=0)[-timeStepsTrain:]
            targets[n] = data['tar'][-timeStepsTrain:]
            targets2[n] = gaussian_filter1d(data['tar'], sigma=f2, axis=0)[-timeStepsTrain:]
        d, f, tauRise, tauFall, X, _, _ = decode(
            spikes, targets, nTrain, dt=dtSampleTrain, dtSample=dtSampleTrain, name="lorenzNew", evals=evals, reg=reg,
            tauRiseMin=tauRiseMin, tauRiseMax=tauRiseMax, tauFallMin=tauFallMin, tauFallMax=tauFallMax)
        spikes2 = spikes2.reshape((tFlat, N))
        targets2 = targets2.reshape((tFlat, 3))
        d2, _ = nengo.solvers.LstsqL2(reg=reg2)(spikes2, targets2)
        np.savez("data/lorenzNew_%s.npz"%neuron_type, d=d, d2=d2, tauRise=tauRise, tauFall=tauFall, f2=f2)
        # spikes2/targets2 are already Gaussian-smoothed once, exactly as d2
        # was fitted and as the test phase uses them; smoothing them a second
        # time here would misrepresent the training fit.
        A2 = spikes2[-timeStepsTrain:]
        Y2 = targets2[-timeStepsTrain:]
        X2 = np.dot(A2, d2)
        plotLorenz(X, X2, Y2, neuron_type, "train")

    print("testing")
    tStart = int(tTransTest/dtSampleTest)
    errors = np.zeros((nTest))
    rng = np.random.RandomState(seed=100+seed)
    for test in range(nTest):
        IC = np.array([rng.uniform(-5, 5), rng.uniform(-5, 5), rng.uniform(20, 25)])
        data = go(d=d, f=f, N=N, neuron_type=neuron_type, t=tTest, r=r, dt=dt, dtSample=dtSampleTest, seed=seed, IC=IC)
        A = f.filt(data['ens'], dt=dtSampleTest)
        A2 = gaussian_filter1d(data['ens'], sigma=f2, axis=0)[tStart:]
        X = np.dot(A, d)[tStart:]
        X2 = np.dot(A2, d2)[tStart:]
        Y = data['tar'][tStart:]
        Y2 = gaussian_filter1d(data['tar'], sigma=f2, axis=0)[tStart:]
        errors[test] = plotLorenz(X, X2, Y2, neuron_type, test)
    if nTest > 0:
        # Plot the last test trial's target trajectory on its own
        # (Y/Y2 only exist once at least one test trial has run).
        plotLorenz(Y, Y2, Y2, 'target', '')
    print('errors: ', errors)
    np.savez("data/lorenzNew_%s.npz"%neuron_type, d=d, d2=d2, tauRise=tauRise, tauFall=tauFall, f2=f2, errors=errors)
    return errors
Exemplo n.º 25
0
def run(NPre=300, N=100, t=20, tTrans=2, nTrain=1, nEnc=10, nTest=10, neuron_type=LIF(),
        dt=0.001, f=DoubleExp(1e-3, 1e-1), fS=DoubleExp(1e-3, 1e-1), freq=1, muFreq=1.0, sigmaFreq=0.1, reg=1e-1, tauRiseMax=5e-2, tDrive=0.2, base=False, load=False, file=None):
    """Train and test a two-layer oscillator network.

    Intermediate and final arrays are saved to
    ``data/oscillateNew_<neuron_type>.npz``.

    Training stages (replaced by the arrays in ``file`` when ``load``):

    1. readout decoders ``d1`` and filter ``f1`` for ``pre``, simulated with
       LIF neurons at a fixed step of 0.001;
    2. encoders ``e1`` for ``ens``; learned for ``Bio`` neurons, zeros otherwise;
    3. readout decoders ``d2`` and filter ``f2`` for ``ens``, fitted on data
       after the first ``tTrans``;
    4. encoders ``e2`` (see the note in that stage).

    Each of ``nTest`` test runs fits a sinusoid (``fitSinusoid``) to each
    decoded dimension after ``tTrans``. It scores the RMSE between the decoded
    state and the fit, and saves a plot to
    ``plots/oscillateNew_<neuron_type>_test<i>.pdf``.

    Returns
    -------
    numpy.ndarray
        Shape ``(nTest,)``; mean of the two per-dimension RMSEs per test.
    """

    print('\nNeuron Type: %s'%neuron_type)
    if load:
        # Read every stored array in one pass and close the archive; an
        # NpzFile keeps its file handle open until it is closed.
        with np.load(file) as saved:
            d1 = saved['d1']
            tauRise1 = saved['tauRise1']
            tauFall1 = saved['tauFall1']
            e1 = saved['e1']
            d2 = saved['d2']
            tauRise2 = saved['tauRise2']
            tauFall2 = saved['tauFall2']
            e2 = saved['e2']
        f1 = DoubleExp(tauRise1, tauFall1)
        f2 = DoubleExp(tauRise2, tauFall2)
    else:
        print('readout decoders for pre')
        # The pre stage always runs LIF neurons at a 0.001 step.
        spikes = np.zeros((nTrain, int(t/0.001), NPre))
        targets = np.zeros((nTrain, int(t/0.001), 2))
        for n in range(nTrain):
            data = go(NPre=NPre, N=N, t=t, dt=0.001, f=f, fS=fS, neuron_type=LIF(), freq=freq, phase=2*np.pi*(n/nTrain))
            spikes[n] = data['pre']
            targets[n] = f.filt(data['inpt'], dt=0.001)
        d1, f1, tauRise1, tauFall1, X, Y, error = decode(spikes, targets, nTrain, dt=0.001, tauRiseMax=tauRiseMax, name="oscillateNew")
        np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1)
        times = np.arange(0, t*nTrain, 0.001)
        plotState(times, X, Y, error, "oscillateNew", "%s_pre"%neuron_type, t*nTrain)

        if isinstance(neuron_type, Bio):
            print("ens1 encoders")
            e1 = np.zeros((NPre, N, 2))
            for n in range(nEnc):
                data = go(d1=d1, e1=e1, f1=f1, NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(n/nEnc), l1=True)
                e1 = data['e1']
                np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1)
                plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'], "oscillateNew", "ens")
        else:
            e1 = np.zeros((NPre, N, 2))

        print('readout decoders for ens')
        iTrans = int(tTrans/dt)  # first sample after the transient
        spikes = np.zeros((nTrain, int((t-tTrans)/dt), N))
        targets = np.zeros((nTrain, int((t-tTrans)/dt), 2))
        for n in range(nTrain):
            data = go(d1=d1, e1=e1, f1=f1, NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(n/nTrain))
            spikes[n] = data['ens'][iTrans:]
            # The target passes through f twice.
            targets[n] = f.filt(f.filt(data['tar'], dt=dt), dt=dt)[iTrans:]
        d2, f2, tauRise2, tauFall2, X, Y, error = decode(spikes, targets, nTrain, dt=dt, name="oscillateNew", reg=reg, tauRiseMax=tauRiseMax)
        np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2)
        times = np.arange(0, t*nTrain, dt)[:len(X)]
        plotState(times, X, Y, error, "oscillateNew", "%s_ens"%neuron_type, (t-tTrans)*nTrain)

    if load:
        # NOTE(review): ens2 encoder learning runs only on the load path and
        # resumes from the loaded e2. With load=False, e2 stays all zeros and
        # is never learned. This looks like the result of commenting out an
        # `elif isinstance(neuron_type, Bio):` branch; confirm it is intended.
        print("ens2 encoders")
        for n in range(nEnc):
            data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, NPre=NPre, N=N, t=t, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(n/nEnc), l2=True)
            e2 = data['e2']
            np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)
            plotActivity(t, dt, fS, data['times'], data['ens2'], data['tarEns2'], "oscillateNew", "ens2")
    else:
        e2 = np.zeros((N, N, 2))
        np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2)

    print("testing")
    iTrans = int(tTrans/dt)  # first sample after the transient
    errors = np.zeros((nTest))
    for test in range(nTest):
        data = go(d1=d1, e1=e1, f1=f1, d2=d2, e2=e2, f2=f2, NPre=NPre, N=N, t=t+tTrans, dt=dt, f=f, fS=fS, neuron_type=neuron_type, freq=freq, phase=2*np.pi*(test/nTest), tDrive=tDrive, test=True)
        times = data['times']
        A = f2.filt(data['ens'], dt=dt)
        X = np.dot(A, d2)
        # Curve fit to a sinusoid of arbitrary frequency, phase, magnitude
        # (and offset base0/base1), one fit per decoded dimension.
        freq0, phase0, mag0, base0 = fitSinusoid(times, X[:,0], freq, iTrans, muFreq=muFreq, sigmaFreq=sigmaFreq, base=base)
        freq1, phase1, mag1, base1 = fitSinusoid(times, X[:,1], freq, iTrans, muFreq=muFreq, sigmaFreq=sigmaFreq, base=base)
        s0 = base0+mag0*np.sin(times*2*np.pi*freq0+phase0)
        s1 = base1+mag1*np.sin(times*2*np.pi*freq1+phase1)
        # Score only the post-transient part of each dimension.
        error0 = rmse(X[iTrans:,0], s0[iTrans:])
        error1 = rmse(X[iTrans:,1], s1[iTrans:])
        errors[test] = (error0 + error1)/2
        fig, (ax, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
        ax.plot(times, X[:,0], label="estimate (dim 0)")
        ax.plot(times, s0, label="target (dim 0)")
        ax.set(ylabel='state', title='freq=%.3f, rmse=%.3f'%(freq0, error0), xlim=((0, t)), ylim=((-1.2, 1.2)))
        ax.legend(loc='upper left')
        ax2.plot(times, X[:,1], label="estimate (dim 1)")
        ax2.plot(times, s1, label="target (dim 1)")
        ax2.set(xlabel='time', ylabel='state', title='freq=%.3f, rmse=%.3f'%(freq1, error1), xlim=((0, t)), ylim=((-1.2, 1.2)))
        ax2.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/oscillateNew_%s_test%s.pdf"%(neuron_type, test))
        # Free each test figure as soon as it is saved.
        plt.close(fig)
    plt.close('all')
    print('%s errors:'%neuron_type, errors)
    np.savez("data/oscillateNew_%s.npz"%neuron_type, d1=d1, tauRise1=tauRise1, tauFall1=tauFall1, e1=e1, d2=d2, tauRise2=tauRise2, tauFall2=tauFall2, e2=e2, errors=errors)
    return errors
def run(NPre=100,
        N=100,
        t=10,
        nTrain=10,
        nTest=3,
        nAttr=5,
        nEnc=10,
        dt=0.001,
        c=None,
        fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3),
        fGABA=DoubleExp(0.5e-3, 1.5e-3),
        fS=DoubleExp(2e-2, 1e-1),
        Tff=0.3,
        reg=1e-2,
        load=(),
        file=None):
    """Train stage by stage, then test, the integrate-with-attractors network.

    Each numbered stage either reads its arrays from ``file`` (when the stage
    number is in ``load``) or recomputes them by simulating ``go`` and then
    decoding (``decodeNoF``) or learning encoders. Freshly computed results
    are written to ``data/integrateAttractors.npz``.

    Stages:

    * 0: pre readout decoders ``dPreA``-``dPreD``, then preD->ens encoders ``ePreDEns``;
    * 1: encoders preA->fdfw, preB->ens, preC->off;
    * 2: fdfw and off readout decoders;
    * 3: fdfw->ens encoders;
    * 4: ens readout decoders;
    * 5: ens->ens encoders.

    Afterwards ``nTest`` test runs (``stage=9``) are simulated and plotted to
    ``plots/integrateAttractors_testFlat<i>.pdf``.

    Parameters
    ----------
    t : float
        Simulated duration of each trial (same units as ``dt``).
    dt : float
        Time step; also passed to every ``Synapse.filt`` call.
    c : float or None
        Cutoff time. ``stimC`` steps from 0 to 1 at ``c``. The test input
        ``stimA`` is zero after ``c``, and the test RMSE is measured from
        ``c`` on. Defaults to ``t`` when None.
    Tff : float
        Feed-forward gain. The fdfw decoders are trained on ``Tff * input``
        and divided by ``Tff`` at test time.
    load : container of int
        Stage numbers whose arrays are read from ``file``.
    file : str or None
        ``.npz`` archive; required when ``load`` is non-empty.
    """

    if c is None:
        c = t
    saved = {}
    if load:
        # Snapshot the archive once, before any stage overwrites
        # data/integrateAttractors.npz, and close it. Reading it lazily per
        # stage would see the overwritten file when `file` is that path
        # (missing keys -> KeyError).
        with np.load(file) as npz:
            saved = {key: npz[key] for key in npz.files}

    if 0 in load:
        dPreA = saved['dPreA']
        dPreB = saved['dPreB']
        dPreC = saved['dPreC']
        dPreD = saved['dPreD']
    else:
        print('readout decoders for pre')
        spikesInptA = np.zeros((nTrain, int(t / dt), NPre))
        spikesInptB = np.zeros((nTrain, int(t / dt), NPre))
        spikesInptC = np.zeros((nTrain, int(t / dt), NPre))
        spikesInptD = np.zeros((nTrain, int(t / dt), NPre))
        targetsInptA = np.zeros((nTrain, int(t / dt), 1))
        targetsInptB = np.zeros((nTrain, int(t / dt), 1))
        targetsInptC = np.zeros((nTrain, int(t / dt), 1))
        targetsInptD = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, _ = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimB = stimA
            stimC = lambda t: 0 if t < c else 1
            stimD = lambda t: 1e-1
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      stimA=stimA,
                      stimB=stimB,
                      stimC=stimC,
                      stimD=stimD,
                      stage=None)
            spikesInptA[n] = data['preA']
            spikesInptB[n] = data['preB']
            spikesInptC[n] = data['preC']
            spikesInptD[n] = data['preD']
            targetsInptA[n] = fPre.filt(data['inptA'], dt=dt)
            targetsInptB[n] = fPre.filt(data['inptB'], dt=dt)
            targetsInptC[n] = fPre.filt(data['inptC'], dt=dt)
            targetsInptD[n] = fPre.filt(data['inptD'], dt=dt)
        dPreA, XA, YA, errorA = decodeNoF(spikesInptA,
                                          targetsInptA,
                                          nTrain,
                                          fPre,
                                          dt=dt,
                                          reg=reg)
        dPreB, XB, YB, errorB = decodeNoF(spikesInptB,
                                          targetsInptB,
                                          nTrain,
                                          fPre,
                                          dt=dt,
                                          reg=reg)
        dPreC, XC, YC, errorC = decodeNoF(spikesInptC,
                                          targetsInptC,
                                          nTrain,
                                          fPre,
                                          dt=dt,
                                          reg=reg)
        dPreD, XD, YD, errorD = decodeNoF(spikesInptD,
                                          targetsInptD,
                                          nTrain,
                                          fPre,
                                          dt=dt,
                                          reg=reg)
        np.savez("data/integrateAttractors.npz",
                 dPreA=dPreA,
                 dPreB=dPreB,
                 dPreC=dPreC,
                 dPreD=dPreD)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, XA, YA, errorA, "integrateAttractors", "preA",
                  t * nTrain)
        plotState(times, XB, YB, errorB, "integrateAttractors", "preB",
                  t * nTrain)
        plotState(times, XC, YC, errorC, "integrateAttractors", "preC",
                  t * nTrain)
        plotState(times, XD, YD, errorD, "integrateAttractors", "preD",
                  t * nTrain)

    # Stage 0 also covers the preD -> ens encoders.
    if 0 in load:
        ePreDEns = saved['ePreDEns']
    else:
        print("encoders for preD to ens")
        ePreDEns = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      dPreD=dPreD,
                      ePreDEns=ePreDEns,
                      stage=0)
            ePreDEns = data['ePreDEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateAttractors", "preDEns")
            np.savez("data/integrateAttractors.npz",
                     dPreA=dPreA,
                     dPreB=dPreB,
                     dPreC=dPreC,
                     dPreD=dPreD,
                     ePreDEns=ePreDEns)

    if 1 in load:
        ePreAFdfw = saved['ePreAFdfw']
        ePreBEns = saved['ePreBEns']
        ePreCOff = saved['ePreCOff']
    else:
        print("encoders for [preA, preB, preC] to [fdfw, ens, off]")
        ePreAFdfw = np.zeros((NPre, N, 1))
        ePreBEns = np.zeros((NPre, N, 1))
        ePreCOff = np.zeros((NPre, N, 1))
        for n in range(nEnc):
            stimA, _ = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimB = stimA
            stimC = lambda t: 0 if t < c else 1
            # NOTE(review): the step signal stimC defined above is not used;
            # preC is driven with stimB here, unlike the decoder stages.
            # Confirm whether stimC=stimC was intended.
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      stimA=stimA,
                      stimB=stimB,
                      stimC=stimB,
                      dPreA=dPreA,
                      dPreB=dPreB,
                      dPreC=dPreC,
                      dPreD=dPreD,
                      ePreAFdfw=ePreAFdfw,
                      ePreBEns=ePreBEns,
                      ePreCOff=ePreCOff,
                      ePreDEns=ePreDEns,
                      stage=1)
            ePreAFdfw = data['ePreAFdfw']
            ePreBEns = data['ePreBEns']
            ePreCOff = data['ePreCOff']
            plotActivity(t, dt, fS, data['times'], data['fdfw'],
                         data['tarFdfw'], "integrateAttractors", "preAFdfw")
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateAttractors", "preBEns")
            plotActivity(t, dt, fS, data['times'], data['off'], data['tarOff'],
                         "integrateAttractors", "preCOff")
            np.savez("data/integrateAttractors.npz",
                     dPreA=dPreA,
                     dPreB=dPreB,
                     dPreC=dPreC,
                     dPreD=dPreD,
                     ePreDEns=ePreDEns,
                     ePreAFdfw=ePreAFdfw,
                     ePreBEns=ePreBEns,
                     ePreCOff=ePreCOff)

    if 2 in load:
        dFdfw = saved['dFdfw']
        dOff = saved['dOff']
    else:
        print('readout decoders for fdfw and off')
        spikesFdfw = np.zeros((nTrain, int(t / dt), N))
        spikesOff = np.zeros((nTrain, int(t / dt), N))
        targetsFdfw = np.zeros((nTrain, int(t / dt), 1))
        targetsOff = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            stimC = lambda t: 0 if t < c else 1
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      stimA=stimA,
                      stimB=stimB,
                      stimC=stimC,
                      dPreA=dPreA,
                      dPreB=dPreB,
                      dPreC=dPreC,
                      dPreD=dPreD,
                      ePreAFdfw=ePreAFdfw,
                      ePreBEns=ePreBEns,
                      ePreCOff=ePreCOff,
                      ePreDEns=ePreDEns,
                      stage=2)
            spikesFdfw[n] = data['fdfw']
            spikesOff[n] = data['off']
            # Pass dt explicitly so targets stay correct when dt differs from
            # the synapse's default_dt.
            targetsFdfw[n] = fNMDA.filt(Tff * fPre.filt(data['inptA'], dt=dt), dt=dt)
            targetsOff[n] = fGABA.filt(fPre.filt(data['inptC'], dt=dt), dt=dt)
        dFdfw, X1, Y1, error1 = decodeNoF(spikesFdfw,
                                          targetsFdfw,
                                          nTrain,
                                          fNMDA,
                                          dt=dt,
                                          reg=reg)
        dOff, X2, Y2, error2 = decodeNoF(spikesOff,
                                         targetsOff,
                                         nTrain,
                                         fGABA,
                                         dt=dt,
                                         reg=reg)
        np.savez("data/integrateAttractors.npz",
                 dPreA=dPreA,
                 dPreB=dPreB,
                 dPreC=dPreC,
                 dPreD=dPreD,
                 dFdfw=dFdfw,
                 dOff=dOff,
                 ePreDEns=ePreDEns,
                 ePreAFdfw=ePreAFdfw,
                 ePreBEns=ePreBEns,
                 ePreCOff=ePreCOff)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X1, Y1, error1, "integrateAttractors", "fdfw",
                  t * nTrain)
        plotState(times, X2, Y2, error2, "integrateAttractors", "off",
                  t * nTrain)

    if 3 in load:
        eFdfwEns = saved['eFdfwEns']
    else:
        print("encoders for fdfw-to-ens")
        eFdfwEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      Tff=Tff,
                      stimA=stimA,
                      stimB=stimB,
                      dPreA=dPreA,
                      dPreB=dPreB,
                      dPreD=dPreD,
                      dFdfw=dFdfw,
                      ePreAFdfw=ePreAFdfw,
                      ePreBEns=ePreBEns,
                      ePreDEns=ePreDEns,
                      eFdfwEns=eFdfwEns,
                      stage=3)
            eFdfwEns = data['eFdfwEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateAttractors", "fdfwEns")
        np.savez("data/integrateAttractors.npz",
                 dPreA=dPreA,
                 dPreB=dPreB,
                 dPreC=dPreC,
                 dPreD=dPreD,
                 dFdfw=dFdfw,
                 dOff=dOff,
                 ePreDEns=ePreDEns,
                 ePreAFdfw=ePreAFdfw,
                 ePreBEns=ePreBEns,
                 ePreCOff=ePreCOff,
                 eFdfwEns=eFdfwEns)

    if 4 in load:
        dEns = saved['dEns']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, int(t / dt), N))
        targets = np.zeros((nTrain, int(t / dt), 1))
        for n in range(nTrain):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      stimA=stimA,
                      stimB=stimB,
                      dPreA=dPreA,
                      dPreB=dPreB,
                      dPreD=dPreD,
                      dFdfw=dFdfw,
                      ePreAFdfw=ePreAFdfw,
                      ePreBEns=ePreBEns,
                      ePreDEns=ePreDEns,
                      eFdfwEns=eFdfwEns,
                      stage=4)
            spikes[n] = data['ens']
            # inptB: stimA rounded to the nearest attractor.
            targets[n] = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(spikes,
                                      targets,
                                      nTrain,
                                      fNMDA,
                                      dt=dt,
                                      reg=reg)
        np.savez("data/integrateAttractors.npz",
                 dPreA=dPreA,
                 dPreB=dPreB,
                 dPreC=dPreC,
                 dPreD=dPreD,
                 dFdfw=dFdfw,
                 dOff=dOff,
                 dEns=dEns,
                 ePreDEns=ePreDEns,
                 ePreAFdfw=ePreAFdfw,
                 ePreBEns=ePreBEns,
                 ePreCOff=ePreCOff,
                 eFdfwEns=eFdfwEns)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, "integrateAttractors", "ens", t * nTrain)

    if 5 in load:
        eEnsEns = saved['eEnsEns']
    else:
        print("encoders from ens to ens")
        eEnsEns = np.zeros((N, N, 1))
        for n in range(nEnc):
            stimA, stimB = makeSignal(t, fPre, fNMDA, nAttr, dt=dt, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      dt=dt,
                      stimA=stimA,
                      stimB=stimB,
                      dPreA=dPreA,
                      dPreB=dPreB,
                      dPreD=dPreD,
                      dFdfw=dFdfw,
                      dEns=dEns,
                      ePreAFdfw=ePreAFdfw,
                      ePreBEns=ePreBEns,
                      ePreDEns=ePreDEns,
                      eFdfwEns=eFdfwEns,
                      eEnsEns=eEnsEns,
                      stage=5)
            eEnsEns = data['eEnsEns']
            plotActivity(t, dt, fS, data['times'], data['ens'], data['tarEns'],
                         "integrateAttractors", "ensEns")
            np.savez("data/integrateAttractors.npz",
                     dPreA=dPreA,
                     dPreB=dPreB,
                     dPreC=dPreC,
                     dPreD=dPreD,
                     dFdfw=dFdfw,
                     dOff=dOff,
                     dEns=dEns,
                     ePreDEns=ePreDEns,
                     ePreAFdfw=ePreAFdfw,
                     ePreBEns=ePreBEns,
                     ePreCOff=ePreCOff,
                     eFdfwEns=eFdfwEns,
                     eEnsEns=eEnsEns)

    print("testing")
    eOffFdfw = -1e2 * np.ones((N, N, 1))
    vals = np.linspace(-1, 1, nTest)
    att = np.linspace(-1, 1, nAttr)
    for test in range(nTest):
        # The lambdas read `test` when called; go() runs the simulation
        # within this iteration, so they see the current value.
        stimA = lambda t: vals[test] if t < c else 0
        stimB = lambda t: closestAttractor(vals[test], att)
        stimC = lambda t: 0 if t < c else 1
        DATest = lambda t: 0
        data = go(NPre=NPre,
                  N=N,
                  t=t,
                  dt=dt,
                  stimA=stimA,
                  stimB=stimB,
                  stimC=stimC,
                  dPreA=dPreA,
                  dPreC=dPreC,
                  dPreD=dPreD,
                  dFdfw=dFdfw,
                  dEns=dEns,
                  dOff=dOff,
                  ePreAFdfw=ePreAFdfw,
                  ePreCOff=ePreCOff,
                  ePreDEns=ePreDEns,
                  eFdfwEns=eFdfwEns,
                  eEnsEns=eEnsEns,
                  eOffFdfw=eOffFdfw,
                  stage=9,
                  c=c,
                  DA=DATest)
        aFdfw = fNMDA.filt(data['fdfw'], dt=dt)
        aEns = fNMDA.filt(data['ens'], dt=dt)
        aOff = fGABA.filt(data['off'], dt=dt)
        # fdfw decoders were trained on Tff * input; undo the gain for plotting.
        xhatFdfw = np.dot(aFdfw, dFdfw) / Tff
        xhatEns = np.dot(aEns, dEns)
        xhatOff = np.dot(aOff, dOff)
        xFdfw = fNMDA.filt(fPre.filt(data['inptA'], dt=dt), dt=dt)
        xEns = fNMDA.filt(fPre.filt(data['inptB'], dt=dt), dt=dt)
        # Score ens only after the cutoff, where it must hold its value.
        errorEns = rmse(xhatEns[int(c / dt):], xEns[int(c / dt):])
        fig, ax = plt.subplots()
        ax.plot(data['times'], xFdfw, alpha=0.5, label="input (filtered)")
        ax.plot(data['times'], xhatFdfw, alpha=0.5, label="fdfw")
        ax.plot(data['times'], xEns, label="target")
        ax.plot(data['times'],
                xhatEns,
                alpha=0.5,
                label="ens, rmse=%.3f" % errorEns)
        ax.plot(data['times'], xhatOff, alpha=0.5, label="off")
        ax.axvline(c, label="cutoff")
        ax.set(xlabel='time',
               ylabel='state',
               xlim=((0, t)),
               ylim=((-1.5, 1.5)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/integrateAttractors_testFlat%s.pdf" % test)
        # Release each test figure once saved so figures do not accumulate.
        plt.close(fig)
Exemplo n.º 27
0
def _lorenz_decode(data, f, f_ens, d_ens, dt_sample):
    """Return ``(tar, ens)``: ``data['x']`` filtered by *f*, and ``data['ens']``
    filtered by *f_ens* and then decoded with *d_ens*."""
    tar = f.filt(data['x'], dt=dt_sample)
    a_ens = f_ens.filt(data['ens'], dt=dt_sample)
    return tar, np.dot(a_ens, d_ens)


def _lorenz_style_3d_axis(ax):
    """Apply the shared Lorenz axis labels/limits and hide the 3D panes and grid."""
    ax.set(xlabel="x", ylabel="y", zlabel="z", xlim=((-20, 20)), ylim=((-10, 30)), zlim=((0, 40)))
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('w')
    ax.grid(False)


def _lorenz_plot_3d(tar, ens, path):
    """Save side-by-side 3D trajectories of *tar* (left) and *ens* (right) to *path*."""
    fig = plt.figure()
    ax = fig.add_subplot(121, projection='3d')
    ax2 = fig.add_subplot(122, projection='3d')
    ax.plot(*tar.T, linewidth=0.25)
    ax2.plot(*ens.T, linewidth=0.25)
    _lorenz_style_3d_axis(ax)
    _lorenz_style_3d_axis(ax2)
    fig.savefig(path)
    plt.close(fig)


def _lorenz_plot_pairwise(tar, ens, path):
    """Save x-y, y-z and x-z projections of *tar* (dashed) overlaid with *ens* to *path*."""
    fig, axes = plt.subplots(1, 3)
    names = "xyz"
    for ax, (i, j) in zip(axes, ((0, 1), (1, 2), (0, 2))):
        ax.plot(tar[:, i], tar[:, j], linestyle="--", linewidth=0.25)
        ax.plot(ens[:, i], ens[:, j], linewidth=0.25)
        ax.set(xlabel=names[i], ylabel=names[j])
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def _lorenz_plot_tent(x, spikes, d_gauss, sigma, trans, path):
    """Save the tent map (successive z maxima) of target and decoded state to *path*.

    Both *x* and *spikes* drop their first *trans* samples and are smoothed with a
    Gaussian of width *sigma* (in samples); the spikes are then decoded with *d_gauss*.
    The first detected peak of each series is discarded.
    """
    tar_gauss = gaussian_filter1d(x[trans:], sigma=sigma, axis=0)
    ens_gauss = np.dot(gaussian_filter1d(spikes[trans:], sigma=sigma, axis=0), d_gauss)

    def successive_maxima(z):
        # Pairs (max_n, max_{n+1}) of the positive peaks of z, skipping the first peak.
        peaks = find_peaks(z, height=0)[0][1:]
        values = z[peaks]
        return values[:-1], values[1:]

    tar_horz, tar_vert = successive_maxima(tar_gauss[:, 2])
    ens_horz, ens_vert = successive_maxima(ens_gauss[:, 2])
    fig, ax = plt.subplots()
    ax.scatter(tar_horz, tar_vert, alpha=0.5, color='r', label='target')
    ax.scatter(ens_horz, ens_vert, alpha=0.5, color='b', label='ens')
    ax.set(xlabel=r'$\mathrm{max}_n (z)$', ylabel=r'$\mathrm{max}_{n+1} (z)$')
    ax.legend(loc='upper right')
    fig.savefig(path)
    plt.close(fig)


def _lorenz_plot_filters(f, f_ens, d_ens, n_neurons, path):
    """Save the impulse responses of the target filter *f* and the optimized *f_ens*."""
    f_times = np.arange(0, 1, 0.0001)
    fig, ax = plt.subplots()
    ax.plot(f_times, f.impulse(len(f_times), dt=0.0001),
            label=r"$f^x, \tau_1=%.3f, \tau_2=%.3f$" % (-1. / f.poles[0], -1. / f.poles[1]))
    ax.plot(f_times, f_ens.impulse(len(f_times), dt=0.0001),
            label=r"$f^{ens}, \tau_1=%.3f, \tau_2=%.3f, d: %s/%s$"
            % (-1. / f_ens.poles[0], -1. / f_ens.poles[1], np.count_nonzero(d_ens), n_neurons))
    ax.set(xlabel='time (seconds)', ylabel='impulse response', ylim=((0, 10)))
    ax.legend(loc='upper right')
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def run(n_neurons=10000, neuron_type=LIF(), t_train=200, t=200, f=DoubleExp(1e-3, 1e-1), dt=0.001, dt_sample=0.003, tt=1.0, seed=0, smooth=30, reg=1e-1, penalty=0, df_evals=20, load_fd=False):
    """Fit (or load) Lorenz readout filters/decoders, then simulate a test run and plot it.

    If *load_fd* is a path, ``d_ens``, ``taus_ens`` and ``d_ens_gauss`` are read from
    that .npz file. Otherwise a training simulation is run, then:

    * ``df_opt`` fits the ensemble filter and decoders;
    * a Gaussian-smoothed least-squares decoder ``d_ens_gauss`` is fitted;
    * the results go to ``data/lorenz_<neuron_type>_fd.npz``;
    * training plots are written.

    In both cases a test simulation follows, and its 3D, pairwise and tent-map plots
    are saved under ``plots/``.

    Args:
        n_neurons: ensemble size (rows of the decoder matrices).
        neuron_type: neuron model; also used in output file names.
        t_train, t: durations of the training and test simulations.
        f: target filter applied to ``data['x']``.
        dt, dt_sample: simulation and sampling time steps passed to ``go``.
        tt: duration of the transient dropped before building the tent map.
        seed: simulation seed.
        smooth: Gaussian smoothing width (in samples) for the tent map.
        reg: regularization for the decoder solvers.
        penalty: passed through to ``df_opt``.
        df_evals: passed through to ``df_opt``.
        load_fd: False, or path of a previously saved ``*_fd.npz``.
    """
    d_ens = np.zeros((n_neurons, 3))
    f_ens = f
    # NOTE(review): the data below are filtered with dt=dt_sample, which suggests they
    # are sampled at dt_sample; if so, this transient index should use dt_sample. Confirm.
    trans = int(tt/dt)

    if load_fd:
        # The context manager closes the .npz file handle once the arrays are read.
        with np.load(load_fd) as load:
            d_ens = load['d_ens']
            taus_ens = load['taus_ens']
            d_ens_gauss = load['d_ens_gauss']
        f_ens = DoubleExp(taus_ens[0], taus_ens[1])
    else:
        print('Optimizing ens filters and decoders')
        data = go(d_ens, f_ens, neuron_type=neuron_type, n_neurons=n_neurons, L=True, t=t_train, f=f, dt=dt, dt_sample=dt_sample, seed=seed)

        d_ens, f_ens, taus_ens = df_opt(data['x'], data['ens'], f, df_evals=df_evals, reg=reg, penalty=penalty, dt=dt_sample, dt_sample=dt_sample, name='lorenz_%s'%neuron_type)
        all_targets_gauss = gaussian_filter1d(data['x'], sigma=smooth, axis=0)
        all_spikes_gauss = gaussian_filter1d(data['ens'], sigma=smooth, axis=0)
        d_ens_gauss = nengo.solvers.LstsqL2(reg=reg)(all_spikes_gauss, all_targets_gauss)[0]
        np.savez('data/lorenz_%s_fd.npz'%neuron_type, d_ens=d_ens, taus_ens=taus_ens, d_ens_gauss=d_ens_gauss)

        _lorenz_plot_filters(f, f_ens, d_ens, n_neurons, "plots/lorenz_%s_filters_ens.pdf"%neuron_type)
        tar, ens = _lorenz_decode(data, f, f_ens, d_ens, dt_sample)
        _lorenz_plot_3d(tar, ens, "plots/lorenz_%s_train_3D.pdf"%neuron_type)
        _lorenz_plot_pairwise(tar, ens, "plots/lorenz_%s_train_pairwise.pdf"%neuron_type)
        print('Plotting tent map')
        _lorenz_plot_tent(data['x'], data['ens'], d_ens_gauss, smooth, trans, "plots/lorenz_%s_train_tent.pdf"%(neuron_type))

    print("testing")
    data = go(d_ens, f_ens, neuron_type=neuron_type, n_neurons=n_neurons, L=False, t=t, f=f, dt=dt, dt_sample=dt_sample, seed=seed)
    tar, ens = _lorenz_decode(data, f, f_ens, d_ens, dt_sample)
    _lorenz_plot_3d(tar, ens, "plots/lorenz_%s_test_3D.pdf"%neuron_type)
    _lorenz_plot_pairwise(tar, ens, "plots/lorenz_%s_test_pairwise.pdf"%neuron_type)
    print('Plotting tent map')
    _lorenz_plot_tent(data['x'], data['ens'], d_ens_gauss, smooth, trans, "plots/lorenz_%s_test_tent.pdf"%(neuron_type))
def go(NPre=100,
       N=100,
       t=10,
       c=None,
       seed=1,
       dt=0.001,
       tTrans=0.01,
       stage=None,
       alpha=3e-7,
       eMax=1e-1,
       Tff=0.3,
       fPre=DoubleExp(1e-3, 1e-1),
       fNMDA=DoubleExp(10.6e-3, 285e-3),
       fGABA=DoubleExp(0.5e-3, 1.5e-3),
       fS=DoubleExp(1e-3, 1e-1),
       dPreA=None,
       dPreB=None,
       dPreC=None,
       dPreD=None,
       dFdfw=None,
       dEns=None,
       dOff=None,
       ePreAFdfw=None,
       ePreBEns=None,
       ePreCOff=None,
       eFdfwEns=None,
       eEnsEns=None,
       ePreDEns=None,
       eOffFdfw=None,
       stimA=lambda t: 0,
       stimB=lambda t: 0,
       stimC=lambda t: 0,
       stimD=lambda t: 0,
       DA=lambda t: 0):
    """Build and simulate the staged feedforward/recurrent Bio network; return probe data.

    The network has four pieces:

    * Four input nodes (``stimA``..``stimD``), each driving a 1-D ``pre`` ensemble
      (``preA``..``preD``).
    * Three ensembles of ``Bio`` neurons, all passed ``DA``:
      ``fdfw`` and ``ens`` ("Pyramidal"), and ``off`` ("Interneuron").
    * Three default-neuron-type target ensembles: ``tarFdfw``, ``tarEns`` and ``tarOff``.

    ``stage`` selects which extra connections are built. For some stages it also selects
    which connections get ``learnEncoders`` attached. After the simulator is built, it
    selects which connections get ``setWeights(conn, decoders, encoders)``:

    * 0: preD->ens, learned against tarEns.
    * 1: preA->fdfw, preB->ens and preC->off are learned against their targets;
      preD->ens is fixed.
    * 2: preA->fdfw and preC->off only, no learning.
    * 3: fdfw->ens (NMDA) is learned. The target for tarEns is stimB filtered by fNMDA,
      plus ``Tff`` times stimA filtered by fNMDA.
    * 4: as stage 3, but without learning or targets.
    * 5: the recurrent connection is trained feedforward. ens2 (driven by preB) projects
      to ens, learned against ens3. ens3 receives fdfw plus preB2, which sees stimB
      filtered by fNMDA. Here ``data['tarEns']`` holds ens3 spikes.
    * 9: full network. preD->ens, preA->fdfw, fdfw->ens (NMDA), recurrent ens->ens
      (NMDA), preC->off, and off->fdfw (GABA); no learning.
    * Any other value: only the input->pre connections exist.

    Args:
        NPre, N: sizes of the pre ensembles and of the main/target ensembles.
        t: simulation duration in seconds.
        c: defaults to ``t``; not otherwise used in this body.
        seed: seed for the network, the simulator and most objects.
        dt: simulator time step.
        tTrans, alpha, eMax, fS: forwarded to ``learnEncoders``. alpha and eMax are
            scaled by 3x or 10x per connection.
        Tff: transform on the stage-3 feedforward target path.
        fPre, fNMDA, fGABA: synapses used for pre connections and for target or
            stimulus filtering. fGABA is not referenced in this body.
        dPre*/dFdfw/dEns/dOff: decoders given to ``NoSolver`` and ``setWeights``.
        e*: encoder arrays given to ``setWeights``.
        stimA..stimD: node output functions.
        DA: passed to every ``Bio`` neuron type.

    Returns:
        dict with ``times``, raw input-node outputs (``inptA``..``inptD``), unfiltered
        spike data for every ``pre``/main/target ensemble, and the encoder arrays.
        Encoders learned in this stage are replaced by the connection's ``.e``;
        all others are returned as passed in.

    NOTE(review): keep statement order when editing. Nodes, probes, the Direct ensembles
    and the stage-5 preB2->ens3 input are created without explicit seeds; nengo presumably
    derives their seeds from creation order under the network seed.
    """

    if not c: c = t
    with nengo.Network(seed=seed) as model:
        # Stimulus nodes and their 1-D spiking "pre" populations (fixed 30 Hz max rate).
        inptA = nengo.Node(stimA)
        inptB = nengo.Node(stimB)
        inptC = nengo.Node(stimC)
        inptD = nengo.Node(stimD)
        preA = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preB = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preC = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        preD = nengo.Ensemble(NPre, 1, max_rates=Uniform(30, 30), seed=seed)
        # Bio-neuron populations whose input encoders are learned or set below.
        fdfw = nengo.Ensemble(N,
                              1,
                              neuron_type=Bio("Pyramidal", DA=DA),
                              seed=seed)
        ens = nengo.Ensemble(N,
                             1,
                             neuron_type=Bio("Pyramidal", DA=DA),
                             seed=seed + 1)
        off = nengo.Ensemble(N,
                             1,
                             neuron_type=Bio("Interneuron", DA=DA),
                             seed=seed + 3)
        # Target ensembles. Seeds match fdfw/ens/off; they provide the target activity
        # for learnEncoders. tarOff only responds to positive inputs
        # (encoders +1, intercepts 0.2-0.8).
        tarFdfw = nengo.Ensemble(N,
                                 1,
                                 max_rates=Uniform(30, 30),
                                 intercepts=Uniform(-0.8, 0.8),
                                 seed=seed)
        tarEns = nengo.Ensemble(N,
                                1,
                                max_rates=Uniform(30, 30),
                                intercepts=Uniform(-0.8, 0.8),
                                seed=seed + 1)
        tarOff = nengo.Ensemble(N,
                                1,
                                max_rates=Uniform(30, 30),
                                intercepts=Uniform(0.2, 0.8),
                                encoders=Choice([[1]]),
                                seed=seed + 3)
        # Unfiltered stimulus -> pre connections; cB's synapse is overridden in stages 3/4.
        cA = nengo.Connection(inptA, preA, synapse=None, seed=seed)
        cB = nengo.Connection(inptB, preB, synapse=None, seed=seed)
        cC = nengo.Connection(inptC, preC, synapse=None, seed=seed)
        cD = nengo.Connection(inptD, preD, synapse=None, seed=seed)
        # All probes record raw (unfiltered) data; filtering is left to the caller.
        pInptA = nengo.Probe(inptA, synapse=None)
        pInptB = nengo.Probe(inptB, synapse=None)
        pInptC = nengo.Probe(inptC, synapse=None)
        pInptD = nengo.Probe(inptD, synapse=None)
        pPreA = nengo.Probe(preA.neurons, synapse=None)
        pPreB = nengo.Probe(preB.neurons, synapse=None)
        pPreC = nengo.Probe(preC.neurons, synapse=None)
        pPreD = nengo.Probe(preD.neurons, synapse=None)
        pFdfw = nengo.Probe(fdfw.neurons, synapse=None)
        pTarFdfw = nengo.Probe(tarFdfw.neurons, synapse=None)
        pEns = nengo.Probe(ens.neurons, synapse=None)
        pTarEns = nengo.Probe(tarEns.neurons, synapse=None)
        pOff = nengo.Probe(off.neurons, synapse=None)
        pTarOff = nengo.Probe(tarOff.neurons, synapse=None)
        if stage == 0:
            # Learn preD -> ens encoders against tarEns driven by the same preD.
            nengo.Connection(preD, tarEns, synapse=fPre, seed=seed)
            c0 = nengo.Connection(preD,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreD),
                                  seed=seed)
            learnEncoders(c0,
                          tarEns,
                          fS,
                          alpha=10 * alpha,
                          eMax=10 * eMax,
                          tTrans=tTrans)
        if stage == 1:
            # Learn pre -> {fdfw, ens, off} encoders; targets see the fPre-filtered stimuli.
            nengo.Connection(inptA, tarFdfw, synapse=fPre, seed=seed)
            nengo.Connection(inptB, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(inptC, tarOff, synapse=fPre, seed=seed)
            c0 = nengo.Connection(preD,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreD),
                                  seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPreA),
                                  seed=seed)
            c2 = nengo.Connection(preB,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreB),
                                  seed=seed)
            c3 = nengo.Connection(preC,
                                  off,
                                  synapse=fPre,
                                  solver=NoSolver(dPreC),
                                  seed=seed)
            learnEncoders(c1,
                          tarFdfw,
                          fS,
                          alpha=3 * alpha,
                          eMax=3 * eMax,
                          tTrans=tTrans)
            learnEncoders(c2,
                          tarEns,
                          fS,
                          alpha=10 * alpha,
                          eMax=10 * eMax,
                          tTrans=tTrans)
            learnEncoders(c3,
                          tarOff,
                          fS,
                          alpha=3 * alpha,
                          eMax=3 * eMax,
                          tTrans=tTrans)
        if stage == 2:
            # No learning: drive fdfw and off from pre with previously learned encoders.
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPreA),
                                  seed=seed)
            c2 = nengo.Connection(preC,
                                  off,
                                  synapse=fPre,
                                  solver=NoSolver(dPreC),
                                  seed=seed)
        if stage == 3:
            # Stimulus B reaches preB through fNMDA instead of unfiltered.
            cB.synapse = fNMDA
            # Direct-mode ensembles assemble the tarEns target:
            # (stimB filtered by fNMDA) + Tff * (stimA filtered by fNMDA).
            ff = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            fb = nengo.Ensemble(1, 1, neuron_type=nengo.Direct())
            nengo.Connection(inptA, ff, synapse=fPre, seed=seed)
            nengo.Connection(inptB, fb, synapse=fNMDA, seed=seed)
            nengo.Connection(fb, tarEns, synapse=fPre, seed=seed)
            nengo.Connection(ff,
                             tarEns,
                             synapse=fNMDA,
                             transform=Tff,
                             seed=seed)
            c0 = nengo.Connection(preD,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreD),
                                  seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPreA),
                                  seed=seed)
            c2 = nengo.Connection(preB,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreB),
                                  seed=seed)
            c3 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            # Only the fdfw -> ens (NMDA) encoders are learned in this stage.
            learnEncoders(c3,
                          tarEns,
                          fS,
                          alpha=3 * alpha,
                          eMax=3 * eMax,
                          tTrans=tTrans)
        if stage == 4:
            # Same wiring as stage 3, but with no targets and no learning.
            cB.synapse = fNMDA
            c0 = nengo.Connection(preD,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreD),
                                  seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPreA),
                                  seed=seed)
            c2 = nengo.Connection(preB,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreB),
                                  seed=seed)
            c3 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
        if stage == 5:
            # Learn the ens2 -> ens encoders (later reused as the recurrent ens -> ens
            # weights in stage 9) against ens3. ens3 is a copy of ens driven by fdfw
            # plus fNMDA-filtered stimulus B via preB2.
            preB2 = nengo.Ensemble(NPre,
                                   1,
                                   max_rates=Uniform(30, 30),
                                   seed=seed)
            ens2 = nengo.Ensemble(N,
                                  1,
                                  neuron_type=Bio("Pyramidal", DA=DA),
                                  seed=seed + 1)
            ens3 = nengo.Ensemble(N,
                                  1,
                                  neuron_type=Bio("Pyramidal", DA=DA),
                                  seed=seed + 1)
            nengo.Connection(inptB, preB2, synapse=fNMDA, seed=seed)
            c0a = nengo.Connection(preD,
                                   ens,
                                   synapse=fPre,
                                   solver=NoSolver(dPreD),
                                   seed=seed)
            c0b = nengo.Connection(preD,
                                   ens2,
                                   synapse=fPre,
                                   solver=NoSolver(dPreD),
                                   seed=seed)
            c0c = nengo.Connection(preD,
                                   ens3,
                                   synapse=fPre,
                                   solver=NoSolver(dPreD),
                                   seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPreA),
                                  seed=seed)
            c2 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c3 = nengo.Connection(preB,
                                  ens2,
                                  synapse=fPre,
                                  solver=NoSolver(dPreB),
                                  seed=seed)
            c4 = nengo.Connection(ens2,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dEns),
                                  seed=seed)
            c5 = nengo.Connection(fdfw,
                                  ens3,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c6 = nengo.Connection(preB2,
                                  ens3,
                                  synapse=fPre,
                                  solver=NoSolver(dPreB),
                                  seed=seed)
            learnEncoders(c4, ens3, fS, alpha=alpha, eMax=eMax, tTrans=tTrans)
            # Replace the tarEns probe: in this stage the returned 'tarEns' is ens3 spikes.
            pTarEns = nengo.Probe(ens3.neurons, synapse=None)
        if stage == 9:
            # Full test network: feedforward, recurrent ens -> ens (NMDA) and
            # off -> fdfw (GABA).
            c0 = nengo.Connection(preD,
                                  ens,
                                  synapse=fPre,
                                  solver=NoSolver(dPreD),
                                  seed=seed)
            c1 = nengo.Connection(preA,
                                  fdfw,
                                  synapse=fPre,
                                  solver=NoSolver(dPreA),
                                  seed=seed)
            c2 = nengo.Connection(fdfw,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dFdfw),
                                  seed=seed)
            c3 = nengo.Connection(ens,
                                  ens,
                                  synapse=NMDA(),
                                  solver=NoSolver(dEns),
                                  seed=seed)
            c6 = nengo.Connection(preC,
                                  off,
                                  synapse=fPre,
                                  solver=NoSolver(dPreC),
                                  seed=seed)
            c7 = nengo.Connection(off,
                                  fdfw,
                                  synapse=GABA(),
                                  solver=NoSolver(dOff),
                                  seed=seed)

    with nengo.Simulator(model, seed=seed, dt=dt, progress_bar=False) as sim:
        # Install (decoder, encoder) weights on the connections built for this stage.
        # This must happen after the simulator is built, because the connection
        # handles (c0..c7) are stage-specific.
        if stage == 0:
            setWeights(c0, dPreD, ePreDEns)
        if stage == 1:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreB, ePreBEns)
            setWeights(c3, dPreC, ePreCOff)
        if stage == 2:
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreC, ePreCOff)
        if stage == 3:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreB, ePreBEns)
            setWeights(c3, dFdfw, eFdfwEns)
        if stage == 4:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dPreB, ePreBEns)
            setWeights(c3, dFdfw, eFdfwEns)
        if stage == 5:
            setWeights(c0a, dPreD, ePreDEns)
            setWeights(c0b, dPreD, ePreDEns)
            setWeights(c0c, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dFdfw, eFdfwEns)
            setWeights(c3, dPreB, ePreBEns)
            setWeights(c4, dEns, eEnsEns)
            setWeights(c5, dFdfw, eFdfwEns)
            setWeights(c6, dPreB, ePreBEns)
        if stage == 9:
            setWeights(c0, dPreD, ePreDEns)
            setWeights(c1, dPreA, ePreAFdfw)
            setWeights(c2, dFdfw, eFdfwEns)
            setWeights(c3, dEns, eEnsEns)
            setWeights(c6, dPreC, ePreCOff)
            setWeights(c7, dOff, eOffFdfw)

        # Presumably (re)initializes the NEURON backend used by the Bio neurons
        # before the run, and resets it afterwards; see reset_neuron.
        neuron.h.init()
        sim.run(t, progress_bar=True)
        reset_neuron(sim, model)

    # Encoders learned in this stage are read back from their connection's `.e`;
    # all others are returned as passed in.
    ePreDEns = c0.e if stage == 0 else ePreDEns
    ePreAFdfw = c1.e if stage == 1 else ePreAFdfw
    ePreBEns = c2.e if stage == 1 else ePreBEns
    ePreCOff = c3.e if stage == 1 else ePreCOff
    eFdfwEns = c3.e if stage == 3 else eFdfwEns
    eEnsEns = c4.e if stage == 5 else eEnsEns

    return dict(
        times=sim.trange(),
        inptA=sim.data[pInptA],
        inptB=sim.data[pInptB],
        inptC=sim.data[pInptC],
        inptD=sim.data[pInptD],
        preA=sim.data[pPreA],
        preB=sim.data[pPreB],
        preC=sim.data[pPreC],
        preD=sim.data[pPreD],
        fdfw=sim.data[pFdfw],
        ens=sim.data[pEns],
        off=sim.data[pOff],
        tarFdfw=sim.data[pTarFdfw],
        tarEns=sim.data[pTarEns],
        tarOff=sim.data[pTarOff],
        ePreAFdfw=ePreAFdfw,
        ePreBEns=ePreBEns,
        ePreCOff=ePreCOff,
        ePreDEns=ePreDEns,
        eFdfwEns=eFdfwEns,
        eEnsEns=eEnsEns,
        eOffFdfw=eOffFdfw,
    )
Exemplo n.º 29
0
def _nearest_bin(values, bins):
    """Return, for each entry of *values*, the index of the closest entry of *bins*.

    Gives the same result as ``np.abs(bins - v).argmin()`` per value: ties go to the
    lower index. Bins are scanned one at a time, so memory stays O(len(values))
    instead of O(len(values) * len(bins)).
    """
    nearest = np.zeros(len(values), dtype=int)
    bestDist = np.abs(bins[0] - values)
    for b in range(1, len(bins)):
        dist = np.abs(bins[b] - values)
        closer = dist < bestDist  # strict '<' keeps the first minimum, like argmin
        nearest[closer] = b
        bestDist = np.where(closer, dist, bestDist)
    return nearest


def _binned_stats(nearest, rates, nBins):
    """Return per-bin ``(mean, ci_low, ci_high)`` of *rates* grouped by bin index *nearest*.

    The confidence bounds are the 95% interval from ``sns.utils.ci``. Samples keep
    their time order within each bin. An empty bin gets ``np.mean`` of an empty
    array, which is NaN.
    """
    means = np.zeros(nBins)
    ciLow = np.zeros(nBins)
    ciHigh = np.zeros(nBins)
    for b in range(nBins):
        members = rates[nearest == b]
        means[b] = np.mean(members)
        ciLow[b], ciHigh[b] = sns.utils.ci(members, which=95)
    return means, ciLow, ciHigh


def run(t=60, tTrans=30, dt=1e-4, maxRate=30, intercept=-0.25, f=DoubleExp(1e-3, 3e-2), fS=DoubleExp(1e-2, 1e-1), nBins=21):
    """Train pre readout decoders, simulate LIF/Wilson/bio responses and plot tuning curves.

    The function does four things in order:

    1. Fit decoders for the pre population with ``decode``. The simulation step is
       fixed at 0.001 here.
    2. Run the encoder-training simulation at *dt* (``l1=True``).
    3. Save ``plots/tuneNew.pdf``, showing input and filtered activity over the
       first *tTrans* seconds.
    4. Save ``plots/tuneNew2.pdf``. Samples after *tTrans* are binned by the nearest
       of *nBins* points in [-1, 1]. The plot shows the mean activity of the first
       neuron of each model per bin, with a 95% CI, against the target max rate and
       intercept.
    """
    print('training readout decoders for pre')
    stim = makeSignal(t, f, dt=0.001, seed=0)
    data = go(t=t, m=Uniform(maxRate, maxRate), i=Uniform(intercept, intercept), dt=0.001, f=f, fS=fS, stim=stim)
    spikes = np.array([data['pre']])
    targets = np.array([f.filt(data['inpt'], dt=0.001)])
    d1, f1, _, _, _, _, _ = decode(spikes, targets, 1, dt=0.001, name="tuneNew")

    print("training encoders")
    e1 = np.zeros((100, 1, 1))
    stim = makeSignal(t, f, dt=dt, seed=0)
    data = go(d1=d1, e1=e1, f1=f1, t=t, dt=dt, f=f, fS=fS, stim=stim, l1=True)
    x = fS.filt(data['inpt'], dt=dt)
    times = data['times']
    aLif = fS.filt(data['lif'], dt=dt)
    aWilson = fS.filt(data['wilson'], dt=dt)
    aBio = fS.filt(data['bio'], dt=dt)

    fig, (ax, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True)
    ax.plot(times, x, label="input")
    ax.axhline(intercept, color='k', alpha=0.5, label="intercept")
    ax2.plot(times, aLif, alpha=0.5, label='LIF')
    ax2.plot(times, aWilson, alpha=0.5, label='Wilson')
    ax2.plot(times, aBio, alpha=0.5, label='bio')
    ax.set(xlim=((0, tTrans)), ylim=((-1, 1)), ylabel=r"$\mathbf{x}$(t)")
    ax2.set(xlim=((0, tTrans)), ylim=((0, 40)), xlabel='time (s)', ylabel=r"$a(t)$")
    ax.legend(loc='upper left')
    ax2.legend(loc='upper left')
    plt.savefig('plots/tuneNew.pdf')
    plt.close('all')

    # Tuning curves use only samples after tTrans. The filtered signals above are
    # reused rather than re-filtered, since filtering is deterministic.
    start = int(tTrans/dt)
    bins = np.linspace(-1, 1, nBins)
    # x has a single state dimension here; flatten (T, 1) -> (T,) for binning.
    nearest = _nearest_bin(np.ravel(x[start:]), bins)
    # Only the first neuron of each model is used, as in a single-neuron tuning curve.
    mLif, ci1Lif, ci2Lif = _binned_stats(nearest, aLif[start:, 0], nBins)
    mWilson, ci1Wilson, ci2Wilson = _binned_stats(nearest, aWilson[start:, 0], nBins)
    mBio, ci1Bio, ci2Bio = _binned_stats(nearest, aBio[start:, 0], nBins)

    fig, ax = plt.subplots()
    ax.plot(bins, mLif, label="LIF")
    ax.fill_between(bins, ci1Lif, ci2Lif, alpha=0.1)
    ax.plot(bins, mWilson, label="Wilson")
    ax.fill_between(bins, ci1Wilson, ci2Wilson, alpha=0.1)
    ax.plot(bins, mBio, label="Bio")
    ax.fill_between(bins, ci1Bio, ci2Bio, alpha=0.1)
    ax.axhline(maxRate, color='k', linestyle="--", label="target max rate")
    ax.axvline(intercept, color='k', label="target intercept")
    ax.set(xlim=((-1, 1)), ylim=((0, maxRate+1)), xlabel=r"$\mathbf{x}$", ylabel="firing rate (Hz)")
    ax.legend()
    fig.savefig("plots/tuneNew2.pdf")
    plt.close(fig)
Exemplo n.º 30
0
def _make_step_stim(value, cutoff):
    """Return a one-argument stimulus: ``value`` while time < ``cutoff``, else 0."""
    # Factory binds value/cutoff eagerly (avoids late binding of loop variables)
    # and keeps the single-argument call signature of the original lambda.
    return lambda time: value if time < cutoff else 0


def run(NPre=100,
        N=100,
        t=10,
        c=None,
        nTrain=10,
        nTest=5,
        dt=0.001,
        neuron_type=LIF(),
        fPre=DoubleExp(1e-3, 1e-1),
        fNMDA=DoubleExp(10.6e-3, 285e-3),
        reg=1e-2,
        load=None,
        file=None,
        neg=True):
    """Train readout decoders stage by stage, then run and plot test trials.

    Stages (each either loaded from ``file`` or trained with ``decodeNoF``):
      0 -> ``dPre``  (decodes ``data['pre']`` against fPre-filtered ``data['inpt']``)
      1 -> ``dDiff`` (decodes ``data['diff']`` against fNMDA(fPre(``data['inpt']``)))
      2 -> ``dEns`` and ``dNeg = -dEns`` (decodes ``data['ens']`` against
           fNMDA(fPre(``data['intg']``)), simulated with ``train=True, Tff=0.3``)
    Trained stages are saved to ``data/diffMemory_<neuron_type>.npz`` and their
    fits plotted with ``plotState``. Each test trial is saved as
    ``plots/diffMemory_<neuron_type>_test<i>.pdf``.

    Parameters:
        NPre: width of ``data['pre']`` returned by ``go`` (pre population size).
        N: width of ``data['diff']`` / ``data['ens']`` returned by ``go``.
        t: duration of each simulation, in the same time units as ``dt``.
        c: cutoff time of the step stimulus used when ``neg`` is True;
            defaults to ``t`` when None.
        nTrain: number of training signals (``makeSignal`` seeds 0..nTrain-1).
        nTest: number of test trials.
        dt: simulation timestep.
        neuron_type: forwarded to ``go``; its ``%s`` form names output files.
        fPre, fNMDA: filters (objects with a ``filt`` method) used to build
            targets and filter activities.
        reg: regularization passed to ``decodeNoF``.
        load: collection of stage indices (0, 1, 2) to load from ``file``
            instead of training. None means "train every stage".
        file: path to an ``.npz`` file holding the saved decoders; required
            when ``load`` is non-empty.
        neg: if True, test with step inputs of amplitude
            ``np.linspace(-1, 1, nTest)`` and the ``dNeg`` feedback; otherwise
            test with ``makeSignal`` inputs and no ``dNeg``.

    Raises:
        ValueError: if ``load`` is non-empty but ``file`` is None.
    """
    if c is None:  # explicit c=0 is a valid cutoff and must not be overridden
        c = t
    load = () if load is None else load
    if load and file is None:
        raise ValueError("`file` must be given when `load` is non-empty")
    saveName = "data/diffMemory_%s.npz" % neuron_type
    plotName = "diffMemory_%s" % neuron_type
    nSteps = int(t / dt)

    if 0 in load:
        with np.load(file) as saved:
            dPre = saved['dPre']
    else:
        print('readout decoders for pre')
        spikesInpt = np.zeros((nTrain, nSteps, NPre))
        targetsInpt = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      fPre=fPre,
                      fNMDA=fNMDA,
                      neuron_type=neuron_type,
                      stim=stim)
            spikesInpt[n] = data['pre']
            targetsInpt[n] = fPre.filt(data['inpt'], dt=dt)
        dPre, X, Y, error = decodeNoF(spikesInpt,
                                      targetsInpt,
                                      nTrain,
                                      fPre,
                                      reg=reg)
        np.savez(saveName, dPre=dPre)
        # Use dt (not a hard-coded 0.001) so the axis matches the decoded series.
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, plotName, "pre", t * nTrain)

    if 1 in load:
        with np.load(file) as saved:
            dDiff = saved['dDiff']
    else:
        print('readout decoders for diff')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      fPre=fPre,
                      fNMDA=fNMDA,
                      neuron_type=neuron_type,
                      stim=stim,
                      dPre=dPre)
            spikes[n] = data['diff']
            targets[n] = fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt)
        dDiff, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, reg=reg)
        np.savez(saveName, dPre=dPre, dDiff=dDiff)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, plotName, "diff", t * nTrain)

    if 2 in load:
        with np.load(file) as saved:
            dEns = saved['dEns']
            dNeg = saved['dNeg']
    else:
        print('readout decoders for ens')
        spikes = np.zeros((nTrain, nSteps, N))
        targets = np.zeros((nTrain, nSteps, 1))
        for n in range(nTrain):
            stim = makeSignal(t, fPre, fNMDA, seed=n)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      fPre=fPre,
                      fNMDA=fNMDA,
                      neuron_type=neuron_type,
                      stim=stim,
                      dPre=dPre,
                      dDiff=dDiff,
                      train=True,
                      Tff=0.3)
            spikes[n] = data['ens']
            targets[n] = fNMDA.filt(fPre.filt(data['intg'], dt=dt), dt=dt)
        dEns, X, Y, error = decodeNoF(spikes, targets, nTrain, fNMDA, reg=reg)
        dNeg = -np.array(dEns)
        np.savez(saveName, dPre=dPre, dDiff=dDiff, dEns=dEns, dNeg=dNeg)
        times = np.arange(0, t * nTrain, dt)
        plotState(times, X, Y, error, plotName, "ens", t * nTrain)

    print('testing')
    vals = np.linspace(-1, 1, nTest)
    for test in range(nTest):
        if neg:
            stim = _make_step_stim(vals[test], c)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      fPre=fPre,
                      fNMDA=fNMDA,
                      neuron_type=neuron_type,
                      stim=stim,
                      dPre=dPre,
                      dDiff=dDiff,
                      dEns=dEns,
                      dNeg=dNeg,
                      c=c)
        else:
            stim = makeSignal(t, fPre, fNMDA, seed=test)
            data = go(NPre=NPre,
                      N=N,
                      t=t,
                      fPre=fPre,
                      fNMDA=fNMDA,
                      neuron_type=neuron_type,
                      stim=stim,
                      dPre=dPre,
                      dDiff=dDiff,
                      dEns=dEns,
                      dNeg=None,
                      c=None,
                      Tff=0.3)
        aEns = fNMDA.filt(fNMDA.filt(fPre.filt(data['ens'], dt=dt), dt=dt),
                          dt=dt)
        xhatEns = np.dot(aEns, dEns)
        # Target: the integral signal passed through the same filter cascade.
        x = fNMDA.filt(fNMDA.filt(fPre.filt(data['intg'], dt=dt), dt=dt),
                       dt=dt)
        error = rmse(xhatEns, x)
        fig, ax = plt.subplots()
        if neg:
            u2 = fNMDA.filt(fNMDA.filt(fPre.filt(data['inpt'], dt=dt), dt=dt),
                            dt=dt)
            ax.plot(data['times'], u2, alpha=0.5, label="input (delayed)")
            ax.axvline(c, label="cutoff")
        else:
            ax.plot(data['times'], x, alpha=0.5, label="integral")
        ax.plot(data['times'], xhatEns, label="ens")
        ax.set(xlabel='time',
               ylabel='state',
               title="rmse=%.3f" % error,
               xlim=((0, t)),
               ylim=((-1, 1)))
        ax.legend(loc='upper left')
        sns.despine()
        fig.savefig("plots/diffMemory_%s_test%s.pdf" % (neuron_type, test))
        plt.close(fig)  # release the figure; one is created per test trial