Esempio n. 1
0
    #---------------------------------------------------------------------------

    return trial

# Performance measure: two-alternative forced choice (the scorer is taken from
# the tasktools module)
performance = tasktools.performance_2afc

# Stop training once the recent psychometric performance exceeds 85%
def terminate(performance_history):
    """Return whether the mean of the last five performance values is above 85.

    performance_history is a sequence of performance values; only its final
    five entries (or fewer, if it is shorter) are averaged. The comparison is
    strict, so a mean of exactly 85 does not terminate training.
    """
    window = performance_history[-5:]
    recent_mean = np.mean(window)
    return recent_mean > 85

# Validation dataset size: 100 * (nconditions + 1)
# NOTE(review): the "+ 1" presumably accounts for catch trials -- confirm
# against generate_trial
n_validation = 100*(nconditions + 1)

#///////////////////////////////////////////////////////////////////////////////

if __name__ == '__main__':
    # Assemble the model from the module-level task definitions and train it;
    # the same 'savefile.pkl' path is handed to RNN below
    model = Model(
        Nin=Nin,
        N=N,
        Nout=Nout,
        ei=ei,
        Cout=Cout,
        generate_trial=generate_trial,
        performance=performance,
        terminate=terminate,
        n_validation=n_validation
    )
    model.train('savefile.pkl')

    # Run the trained network with 16*3.2% = 51.2% coherence for choice 1
    trial_func = generate_trial
    trial_args = dict(name='test', catch=False, coh=16, left_right=1)
    rnn = RNN('savefile.pkl', dict(dt=0.5))
    info = rnn.run(inputs=(trial_func, trial_args))
# Termination criterion: training stops once the mean of the last three
# performance values reaches TARGET_PERFORMANCE
TARGET_PERFORMANCE = 90


def terminate(performance_history):
    """Return whether the recent mean performance has reached the target.

    Averages the final three entries of performance_history (or fewer, if it
    is shorter) and compares the result against TARGET_PERFORMANCE. The
    comparison is inclusive, so a mean equal to the target terminates training.
    """
    latest = performance_history[-3:]
    latest_mean = np.mean(latest)
    return latest_mean >= TARGET_PERFORMANCE

# Validation dataset size: 100 * (nconditions + 1)
# NOTE(review): the "+ 1" presumably accounts for catch trials -- confirm
# against generate_trial
n_validation = 100*(nconditions + 1)

if __name__ == '__main__':
    from pycog import Model

    # Assemble the model from the module-level definitions and train it with a
    # fixed seed; recover=False is passed through to Model.train
    model = Model(
        N=N,
        Nin=Nin,
        Nout=Nout,
        ei=ei,
        Crec=Crec,
        Cout=Cout,
        generate_trial=generate_trial,
        n_validation=n_validation,
        performance=performance,
        terminate=terminate
    )
    model.train('workingMemory_savefile.pkl', seed=100, recover=False)

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    from pycog import RNN
    from pycog.figtools import Figure

    # Dedicated, seeded random state handed to rnn.run below
    rng = np.random.RandomState(1066)
    rnn = RNN('workingMemory_savefile.pkl', dict(dt=2))

    # Single non-catch test trial with pair (15, 30)
    trial_args = dict(name='test', catch=False, pair=(15, 30))

    info = rnn.run(inputs=(generate_trial, trial_args), rng=rng)
Esempio n. 3
0
if __name__ == '__main__':
    from pycog import Model

    # Assemble the model from the module-level definitions and train it with a
    # fixed seed; recover=False is passed through to Model.train
    model = Model(N=N, Nout=Nout, ei=ei, tau=tau, dt=dt,
                  train_brec=train_brec, train_bout=train_bout,
                  var_rec=var_rec, generate_trial=generate_trial, mode=mode,
                  n_validation=n_validation, min_error=min_error)
    model.train('savefile.pkl', seed=100, recover=False)

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    from pycog import RNN
    from pycog.figtools import Figure

    # Create the RNN from the save file, passing dt and var_rec values, and
    # run it for T = 2 * period
    rnn = RNN('savefile.pkl', dict(dt=0.5, var_rec=0.01 ** 2))
    info = rnn.run(T=2 * period)

    fig = Figure()
    plot = fig.add()

    # Plot rnn.z[0] against time expressed in units of tau
    plot.plot(rnn.t / tau, rnn.z[0], color=Figure.colors('blue'))
Esempio n. 4
0
#=========================================================================================
# Train
#=========================================================================================

elif action == 'train':
    from pycog import Model

    # Model specification, given by `modelfile`
    model = Model(modelfile=modelfile)

    # Avoid locks on the cluster: use a compile directory under `theanopath`
    # whose name combines `name` with the current time in whole seconds
    compiledir = join(theanopath, '{}-{}'.format(name, int(time.time())))

    # Train; `savefile`, the seed, the compile directory and `gpus` are all
    # forwarded to Model.train
    model.train(savefile, seed=seed, compiledir=compiledir, gpus=gpus)

#=========================================================================================
# Test resting state
#=========================================================================================

elif action == 'restingstate':
    import numpy as np

    from pycog import RNN
    from pycog.figtools import Figure

    # Create RNN
    if 'init' in args:
        print("* Initial network.")
        base, ext = os.path.splitext(savefile)