Ejemplo n.º 1
0
def rnn_run(savefile, parameters, col, value):
    """Run a network after lowering one column of its recurrent weights.

    Loads the RNN stored in ``savefile`` (``parameters`` overrides its saved
    settings), subtracts ``value`` from ``Wrec[:, col]`` for the first
    ``int(0.8 * N)`` rows, runs one trial of ``generate_trial`` with a fixed
    seed, and returns the first ``Nout`` output traces, each divided by its
    own maximum.

    Relies on the module-level names ``N``, ``Nout`` and ``generate_trial``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(Nout, number of time points)``.
    """
    rnn = RNN(savefile, parameters)

    # Only the first 80% of rows are modified; presumably these are the
    # excitatory units -- TODO confirm against the model's `ei` layout.
    n_rows = int(0.8 * N)
    rnn.Wrec[:n_rows, col] -= value

    # No trial arguments; fixed seed so the run is reproducible.
    rnn.run(inputs=(generate_trial, {}), seed=200)

    # Scale each output trace by its own peak. As before, a trace whose
    # maximum is 0 produces nan/inf entries.
    outputs = np.asarray(rnn.z[:Nout], dtype=float)
    return outputs / np.max(outputs, axis=1, keepdims=True)
from __future__ import division

import numpy as np

import imp 

from pycog import tasktools

from pycog          import RNN
from pycog.figtools import Figure
   
# Load the task definition in romo.py (working directory) as module `m`.
m = imp.load_source('model', 'romo.py')

# Fixed seed so the sample trial below is reproducible.
rng = np.random.RandomState(1066)
# Trained network, re-run with an integration step of dt = 2.
rnn  = RNN('romo_savefile.pkl', {'dt': 2})

# A single non-catch test trial with stimulus pair (34, 26) and gt_lt = '>'.
trial_args = {'name':  'test', 'catch': False, 'fpair': (34, 26), 'gt_lt': '>'}
   
info = rnn.run(inputs=(m.generate_trial, trial_args), rng=rng)

fig  = Figure()
plot = fig.add()

# (start, end) of the two stimulus epochs, in the same units as rnn.t.
epochs = info['epochs']
f1_start, f1_end = epochs['f1']
f2_start, f2_end = epochs['f2']
# Plots are aligned to f1 onset.
t0   = f1_start
tmin = 0
tmax = f2_end

# Time relative to f1 onset, scaled by 1e-3 (presumably ms -> s).
t     = 1e-3*(rnn.t-t0)
Ejemplo n.º 3
0
#=========================================================================================
# Paths
#=========================================================================================

# Directory of this script, its parent (expected to contain `examples/`),
# and the output directory for figures.
here     = get_here(__file__)
base     = get_parent(here)
figspath = join(here, 'figs')

# Task definition and trained-network file of the rdm_dense example.
modelfile = join(base, 'examples', 'models', 'rdm_dense.py')
savefile  = join(base, 'examples', 'work', 'data', 'rdm_dense', 'rdm_dense.pkl')

#=========================================================================================

m = imp.load_source('model', modelfile)

# Trained network re-run with dt = 0.5; loading messages suppressed.
rnn = RNN(savefile, {'dt': 0.5}, verbose=False)

# One non-catch test trial at coh = 16 with in_out = 1, fixed seed.
trial_func = m.generate_trial
trial_args = {
    'name':   'test',
    'catch':  False,
    'coh':    16,
    'in_out': 1
    }
info = rnn.run(inputs=(trial_func, trial_args), seed=10)

colors = ['orange', 'purple']

# NOTE(review): DT is not used inside this excerpt; its meaning depends on
# the code that follows.
DT = 15

# Inputs
Ejemplo n.º 4
0
def run_trials(p, args):
    """
    Run trials.

    Runs ``args[0]`` trials per condition (100 when ``args[0]`` is missing or
    not an integer) of ``p['model']``. It cycles through every
    (``fpairs``, ``gt_lts``) combination and pickles each trial's downsampled
    time course to ``get_trialsfile(p)``. Ctrl-C stops the loop early; the
    trials completed so far are still saved. Finally the psychometric
    function is computed from the saved file.

    Parameters
    ----------
    p : dict
        Settings; reads 'model', 'seed', 'savefile', 'dt' and 'dt_save'.
    args : sequence
        Action arguments; ``args[0]`` is the number of trials per condition.

    """
    # Model
    m = p['model']

    # Number of trials per condition. Catch only parse failures: a bare
    # `except:` would also swallow KeyboardInterrupt and SystemExit.
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)

    # Trials
    w = len(str(ntrials))  # width of the trial counter in the progress line
    trials = []
    backspaces = 0         # length of the previous progress line (erased with '\b')
    try:
        for i in xrange(ntrials):
            # Condition: cycle through all (fpair, gt_lt) combinations
            b = i % m.nconditions
            k0, k1 = tasktools.unravel_index(b, (len(m.fpairs), len(m.gt_lts)))
            fpair = m.fpairs[k0]
            gt_lt = m.gt_lts[k1]

            # Trial
            trial_func = m.generate_trial
            trial_args = {
                'name': 'test',
                'catch': False,
                'fpair': fpair,
                'gt_lt': gt_lt
            }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type, derived from the frequencies actually presented
            relation = '>' if info['f1'] > info['f2'] else '<'
            s = ("Trial {:>{}}/{}: {:>2} {} {:>2}".format(
                i + 1, w, ntrials, info['f1'], relation, info['f2']))
            sys.stdout.write(backspaces * '\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save, keeping every `step`-th sample. Clamp to 1 so that
            # dt_save < dt keeps every sample instead of raising
            # "slice step cannot be zero".
            dt = rnn.t[1] - rnn.t[0]
            step = max(1, int(p['dt_save'] / dt))
            trial = {
                't': rnn.t[::step],
                'u': rnn.u[:, ::step],
                'r': rnn.r[:, ::step],
                'z': rnn.z[:, ::step],
                'info': info
            }
            trials.append(trial)
    except KeyboardInterrupt:
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename) * 1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(
        THIS, filename, size))

    # Psychometric function
    psychometric_function(filename)
Ejemplo n.º 5
0
def run_trials(p, args):
    """
    Run trials.

    Runs ``args[0]`` trials per condition (100 when ``args[0]`` is missing or
    not an integer) of ``p['model']``. Condition ``b`` uses
    ``m.intensity_range[b]``. Each trial's downsampled time course is pickled
    to ``get_trialsfile(p)``. Ctrl-C stops the loop early; the trials
    completed so far are still saved.

    Parameters
    ----------
    p : dict
        Settings; reads 'model', 'seed', 'savefile', 'dt' and 'dt_save'.
    args : sequence
        Action arguments; ``args[0]`` is the number of trials per condition.

    """
    # Model
    m = p['model']

    # ntrials: Number of trials for each condition. Catch only parse
    # failures; a bare `except:` would also swallow KeyboardInterrupt.
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100

    ntrials *= m.nconditions

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)

    # Trials
    w = len(str(ntrials))  # width of the trial counter in the progress line
    trials = []
    backspaces = 0         # length of the previous progress line (erased with '\b')
    try:
        for i in xrange(ntrials):
            # Cycle through all conditions
            b = i % m.nconditions
            intensity = m.intensity_range[b]

            # Trial
            trial_func = m.generate_trial
            trial_args = {
                'name': 'test',
                'catch': False,
                'intensity': intensity,
            }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type
            s = ("Trial {:>{}}/{}: intensity: {:>+3}".format(
                i + 1, w, ntrials, info['intensity']))
            sys.stdout.write(backspaces * '\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save, keeping every `step`-th sample. Clamp to 1 so that
            # dt_save < dt keeps every sample instead of raising
            # "slice step cannot be zero".
            dt = rnn.t[1] - rnn.t[0]
            step = max(1, int(p['dt_save'] / dt))
            trial = {
                't': rnn.t[::step],
                'u': rnn.u[:, ::step],
                'r': rnn.r[:, ::step],
                'z': rnn.z[:, ::step],
                'info': info
            }
            trials.append(trial)
    except KeyboardInterrupt:
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename) * 1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(
        THIS, filename, size))
Ejemplo n.º 6
0
# Presumably the number of validation trials (it is consumed outside this excerpt).
n_validation = 100
#n_gradient = 1
#mode         = 'continuous'

if __name__ == '__main__':
    from pycog import RNN
    from pycog.figtools import Figure

    rng = np.random.RandomState(1234)  # Added by Alfred

    # Trained network to analyse; the commented-out paths are alternative runs.
    #    savefile = 'examples/work/data/delay_react/delay_react.pkl'
    #    savefile = 'examples/work/data/run_57000_lr1em3_1_1000_50/delay_react.pkl'
    savefile = 'examples/work/data/run_10000_lr1em3_1_1_100_10/delay_react.pkl'
    #    savefile = 'examples/work/data/run_52000_lr1em3_1_100_100/delay_react.pkl'

    # Reload with dt = 0.5 and recurrent noise variance 0.01**2.
    rnn = RNN(savefile, {'dt': 0.5, 'var_rec': 0.01**2})
    trial_args = {}

    # One trial with a fixed seed.
    # NOTE(review): the `rng` created above is not used for this run.
    info1 = rnn.run(inputs=(generate_trial, trial_args), seed=200)
    # Keep the unperturbed outputs (presumably for comparison with later runs).
    Z0 = rnn.z

    #    signal_time
    #    delay = 500
    #    width = 20
    #    magnitude = 4
    #    Y = np.zeros((len(t), Nout)) # Output matrix

    #    for i in range(Nout):
    #        for tt in range(len(t)):
    #            Y[tt][i] = np.exp( -(tt - (signal_time + delay / Nout * (i + 1)))**2 / (2 * width**2)) * magnitude
Ejemplo n.º 7
0
                  train_bout=train_bout,
                  var_rec=var_rec,
                  generate_trial=generate_trial,
                  mode=mode,
                  n_validation=n_validation,
                  min_error=min_error)
    # Train and write the network to savefile.pkl; recover=False presumably
    # means "do not resume from an existing savefile" -- confirm.
    model.train('savefile.pkl', seed=100, recover=False)

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    from pycog import RNN
    from pycog.figtools import Figure

    # Reload with dt = 0.5 and recurrent noise variance 0.01**2, then run for
    # T = 2 * period without any inputs.
    rnn = RNN('savefile.pkl', {'dt': 0.5, 'var_rec': 0.01**2})
    info = rnn.run(T=2 * period)

    fig = Figure()
    plot = fig.add()

    # First output against time measured in units of tau.
    plot.plot(rnn.t / tau, rnn.z[0], color=Figure.colors('blue'))
    plot.xlim(rnn.t[0] / tau, rnn.t[-1] / tau)
    plot.ylim(0, 2)

    # NOTE(review): Python 2 print statements (syntax errors under Python 3).
    print rnn.t[0]
    print rnn.t[-1]
    # Reference curve 0.9 * (t / (2 * period))**2, drawn in orange.
    plot.plot((rnn.t / tau)[:], (0.9 * np.power(rnn.t / (2 * period), 2))[:],
              color=Figure.colors('orange'))

    plot.xlabel(r'$t/\tau$')
Ejemplo n.º 8
0
from __future__ import division

import numpy as np

import imp

from pycog import tasktools

from pycog import RNN
from pycog.figtools import Figure

# Load the task definition in romo.py (working directory) as module `m`.
m = imp.load_source('model', 'romo.py')

# Fixed seed so the sample trial below is reproducible.
rng = np.random.RandomState(1066)
# Trained network, re-run with an integration step of dt = 2.
rnn = RNN('romo_savefile.pkl', {'dt': 2})

# A single non-catch test trial with stimulus pair (34, 26) and gt_lt = '>'.
trial_args = {'name': 'test', 'catch': False, 'fpair': (34, 26), 'gt_lt': '>'}

info = rnn.run(inputs=(m.generate_trial, trial_args), rng=rng)

fig = Figure()
plot = fig.add()

# (start, end) of the two stimulus epochs, in the same units as rnn.t.
epochs = info['epochs']
f1_start, f1_end = epochs['f1']
f2_start, f2_end = epochs['f2']
# Plots are aligned to f1 onset.
t0 = f1_start
tmin = 0
tmax = f2_end

# Time relative to f1 onset, scaled by 1e-3 (presumably ms -> s).
t = 1e-3 * (rnn.t - t0)
Ejemplo n.º 9
0
#=========================================================================================
# Paths
#=========================================================================================

# Directory of this script, its parent (expected to contain `examples/`),
# and the output directory for figures.
here     = get_here(__file__)
base     = get_parent(here)
figspath = join(here, 'figs')

#-----------------------------------------------------------------------------------------
# Load RNNs to compare
#-----------------------------------------------------------------------------------------

datapath = join(base, 'examples', 'work', 'data')

# Trained 'mante' network.
savefile = join(datapath, 'mante', 'mante.pkl')
rnn1     = RNN(savefile, verbose=True)

# Trained 'mante_areas' network (reuses the `savefile` name).
savefile = join(datapath, 'mante_areas', 'mante_areas.pkl')
rnn2     = RNN(savefile, verbose=True)

# Load model
modelfile = join(base, 'examples', 'models', 'mante_areas.py')
m         = imp.load_source('model', modelfile)

#=========================================================================================
# Figure setup
#=========================================================================================

# Figure width and height (units not shown here) and their aspect ratio.
w   = 7.5
h   = 3.8
r   = w/h
if __name__ == '__main__':
    from pycog import Model

    # Train a network on the task defined at module level in this example.
    model = Model(N=N, Nout=Nout, ei=ei, tau=tau, dt=dt,
                  train_brec=train_brec, train_bout=train_bout, var_rec=var_rec,
                  generate_trial=generate_trial,
                  mode=mode, n_validation=n_validation, min_error=min_error)
    model.train('savefile.pkl', seed=100, recover=False)

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    from pycog          import RNN
    from pycog.figtools import Figure

    # Reload with dt = 0.5 and recurrent noise variance 0.01**2, then run for
    # T = 20 * period without any inputs.
    rnn  = RNN('savefile.pkl', {'dt': 0.5, 'var_rec': 0.01**2})
    info = rnn.run(T=20*period)

    fig  = Figure()
    plot = fig.add()

    # First output against time measured in units of tau.
    plot.plot(rnn.t/tau, rnn.z[0], color=Figure.colors('blue'))
    plot.xlim(rnn.t[0]/tau, rnn.t[-1]/tau)
    plot.ylim(-1, 1)

    plot.xlabel(r'$t/\tau$')
    # Raw string: '\s' is not a valid escape sequence (DeprecationWarning from
    # Python 3.6, SyntaxWarning from 3.12). The label text itself is unchanged.
    plot.ylabel(r'$\sin t$')

    fig.save(path='.', name='sinewave')
Ejemplo n.º 11
0
        desc = model
    plots[model].text_upper_center(desc, dy=0.05, fontsize=6.5)

#=========================================================================================
# Plot performance
#=========================================================================================

# Colors for the target level, the actual performance, and (presumably) other seeds.
clr_target = Figure.colors('red')
clr_actual = '0.2'
clr_seeds = '0.8'

for model, _ in models:
    plot = plots[model]

    # Skip models whose network cannot be loaded; this assumes RNN signals a
    # failed load by raising SystemExit -- TODO confirm.
    try:
        rnn = RNN(get_savefile(model), verbose=True)
    except SystemExit:
        continue

    xall = []

    # First element of each costs_history entry (presumably the trial count),
    # rescaled by 1e4.
    # NOTE(review): under Python 2 without `from __future__ import division`
    # this `/` is integer division; confirm which is intended.
    ntrials = [int(costs[0]) for costs in rnn.costs_history]
    ntrials = np.asarray(ntrials, dtype=int) / int(1e4)
    # Last value of the second element of each costs_history entry.
    performance = [costs[1][-1] for costs in rnn.costs_history]

    # Because the network is run continuously, the first validation run is meaningless.
    if 'lee' in model:
        ntrials = ntrials[1:]
        performance = performance[1:]

    # Get target performance
plot.arrow(pos[0], pos[1], alen, 0, **arrowp)
plot.xlim(-r_screen, r_screen)
plot.ylim(-r_screen, r_screen)

#=========================================================================================
# PCA
#=========================================================================================

plot = plots['pca']

#-----------------------------------------------------------------------------------------
# PCA analysis
#-----------------------------------------------------------------------------------------

# RNN
rnn = RNN(savefile, {'dt': 0.5}, verbose=False)

# Run each sequence separately
rnn.p['mode'] = None

# Turn off noise (input and recurrent noise variances set to zero)
rnn.p['var_in']  = 0
rnn.p['var_rec'] = 0

# Downsampling interval for stored trials (presumably in rnn.t units) and the
# per-sequence results (presumably keyed by sequence number).
dt_save = 2
trials  = {}
# Sequences are numbered 1 .. m.nseq inclusive.
for seq in xrange(1, 1+m.nseq):
    print('Sequence #{}, noiseless'.format(seq))

    # Trial
    trial_func = m.generate_trial
Ejemplo n.º 13
0
plot = plots['R']
plot.text_upper_center('Pos. tuned during $f_1$, neg. during $f_2$', dy=0.1, fontsize=7)

plot = plots['<']
plot.xlabel('Time from $f_1$ onset (sec)')
plot.ylabel('Input (a.u.)')

plot = plots['sig']
plot.ylabel('Prop. sig. tuned units')

#=========================================================================================
# Sample inputs
#=========================================================================================

# Fixed seed so the sample trial is reproducible; network re-run with dt = 2.
rng = np.random.RandomState(1066)
rnn = RNN(savefile, {'dt': 2}, verbose=False)

# One non-catch test trial with frequency pair (34, 26) and gt_lt = '>'.
trial_func = m.generate_trial
trial_args = {
    'name':  'test',
    'catch': False,
    'fpair': (34, 26),
    'gt_lt': '>'
    }
info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

# Stimulus durations: (start, end) of each stimulus epoch, in rnn.t units.
epochs = info['epochs']
f1_start, f1_end = epochs['f1']
f2_start, f2_end = epochs['f2']
# Time origin for the plots: onset of f1.
t0   = f1_start
def run_trials(p, args):
    """
    Run trials.

    Runs ``args[0]`` trials per condition (100 when ``args[0]`` is missing or
    not an integer) of ``p['model']``, cycling through ``m.pairs``, and
    pickles each trial's downsampled time course to ``get_trialsfile(p)``.
    Ctrl-C stops the loop early; the trials completed so far are still saved.
    Finally the psychometric function is computed from the saved file.

    Parameters
    ----------
    p : dict
        Settings; reads 'model', 'seed', 'savefile', 'dt' and 'dt_save'.
    args : sequence
        Action arguments; ``args[0]`` is the number of trials per condition.

    """
    # Model
    m = p['model']

    # Number of trials per condition. Catch only parse failures; a bare
    # `except:` would also swallow KeyboardInterrupt and SystemExit.
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)

    # Trials
    w = len(str(ntrials))  # width of the trial counter in the progress line
    trials = []
    backspaces = 0         # length of the previous progress line (erased with '\b')
    try:
        for i in xrange(ntrials):
            # Condition: iterate through all conditions (only the pair index is used)
            b     = i % m.nconditions
            k0, _ = tasktools.unravel_index(b, (len(m.pairs), 1))
            pair  = m.pairs[k0]

            # Trial
            trial_func = m.generate_trial
            trial_args = {
                'name':  'test',
                'catch': False,
                'pair': pair,
                }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type
            s = ("Trial {:>{}}/{}: {:>2} {:>2}"
                 .format(i+1, w, ntrials, info['f1'], info['f2']))
            sys.stdout.write(backspaces*'\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save, keeping every `step`-th sample. Clamp to 1 so that
            # dt_save < dt keeps every sample instead of raising
            # "slice step cannot be zero".
            dt    = rnn.t[1] - rnn.t[0]
            step  = max(1, int(p['dt_save']/dt))
            trial = {
                't':    rnn.t[::step],
                'u':    rnn.u[:,::step],
                'r':    rnn.r[:,::step],
                'z':    rnn.z[:,::step],
                'info': info
                }
            trials.append(trial)
    except KeyboardInterrupt:
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename)*1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(THIS, filename, size))

    # Psychometric function
    psychometric_function(filename)
Ejemplo n.º 15
0
def run_trials(p, args):
    """
    Run trials.

    Runs ``args[0]`` trials (100 when ``args[0]`` is missing or not an
    integer) for each of ``m.nconditions + 1`` conditions of ``p['model']``:
    - a zero-coherence condition, whose ``in_out`` is drawn from
      ``m.in_outs`` with ``rng``;
    - every (``cohs``, ``in_outs``) combination.

    Each trial's time course is downsampled and all trials are written with
    ``dump`` to ``get_trialsfile(p)``. Ctrl-C stops the loop early; the
    trials completed so far are still saved. Finally the psychometric
    function is computed from the saved file.

    If ``p`` contains 'ant_level', every negative entry of ``rnn.Wrec`` is
    multiplied by ``1 - p['ant_level']`` before any trial is run.

    Parameters
    ----------
    p : dict
        Settings; reads 'model', 'seed', 'savefile', 'dt', 'dt_save' and,
        optionally, 'ant_level'.
    args : sequence
        Action arguments; ``args[0]`` is the number of trials per condition.

    """
    # Model
    m = p['model']

    # Number of trials per condition (+1 for the zero-coherence condition).
    # Catch only parse failures; a bare `except:` would also swallow
    # KeyboardInterrupt and SystemExit.
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions + 1

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)
    if 'ant_level' in p:
        # Scale down all negative recurrent weights.
        rnn.Wrec[np.where(rnn.Wrec<0)] *= (1 - p['ant_level'])

    # Trials
    trials = []
    try:
        for i in xrange(ntrials):
            b = i % (m.nconditions + 1)
            if b == 0:
                # Zero-coherence condition
                coh    = 0
                in_out = rng.choice(m.in_outs)
            else:
                # All other conditions
                k1, k2 = tasktools.unravel_index(b-1, (len(m.cohs), len(m.in_outs)))
                coh    = m.cohs[k1]
                in_out = m.in_outs[k2]

            # Trial
            trial_func = m.generate_trial
            trial_args = {
                'name':   'test',
                'catch':  False,
                'coh':    coh,
                'in_out': in_out
                }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Report progress every 100 trials. An integer modulo (and a float
            # numerator for the percentage) gives the same result whether `/`
            # is Python 2 integer division or true division.
            if i % 100 == 0:
                print("We are {:.2f}% complete".format(100.0*i/ntrials))
                sys.stdout.flush()

            # Save, keeping every `step`-th sample. Clamp to 1 so that
            # dt_save < dt keeps every sample instead of raising
            # "slice step cannot be zero".
            dt    = rnn.t[1] - rnn.t[0]
            step  = max(1, int(p['dt_save']/dt))
            trial = {
                't':    rnn.t[::step],
                'u':    rnn.u[:,::step],
                'r':    rnn.r[:,::step],
                'z':    rnn.z[:,::step],
                'info': info
                }
            trials.append(trial)
    except KeyboardInterrupt:
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    dump(filename, trials)
    size = os.path.getsize(filename)*1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(THIS, filename, size))

    # Compute the psychometric function
    psychometric_function(filename)
    # NOTE(review): these lines look like the `__main__` section of a separate
    # working-memory example pasted in here. At this indentation Python treats
    # them as part of the run_trials body above, so they run right after
    # psychometric_function(filename). Confirm the intent.
    from pycog import Model
    
    # Train a network on the module-level task definition.
    model = Model(N=N, Nin=Nin, Nout=Nout, ei=ei, Crec=Crec, Cout=Cout,
                  generate_trial=generate_trial, 
                  n_validation=n_validation, performance=performance, terminate=terminate)
    model.train('workingMemory_savefile.pkl', seed=100, recover=False)

   #-------------------------------------------------------------------------------------
   # Plot
   #-------------------------------------------------------------------------------------

    from pycog          import RNN
    from pycog.figtools import Figure

    # Fixed seed so the sample trial is reproducible; network re-run with dt = 2.
    rng = np.random.RandomState(1066)
    rnn  = RNN('workingMemory_savefile.pkl', {'dt': 2})

    # One non-catch test trial with pair (15, 30).
    trial_args = {'name':  'test', 'catch': False, 'pair': (15, 30)}

    info = rnn.run(inputs=(generate_trial, trial_args), rng=rng)

    fig  = Figure()
    plot = fig.add()
    
    # (start, end) of each stimulus epoch; plots are aligned to f1 onset.
    epochs = info['epochs']
    f1_start, f1_end = epochs['f1']
    f2_start, f2_end = epochs['f2']
    t0   = f1_start
    tmin = 0
    tmax = f2_end
    
def run_trials(p, args):
    """
    Run trials.

    Runs ``args[0]`` trials per condition (100 when ``args[0]`` is missing or
    not an integer) of ``p['model']``. It cycles through every combination of
    (coh_m, left_right_m, coh_c, left_right_c, context) and pickles each
    trial's downsampled time course to ``get_trialsfile(p)``. Ctrl-C stops the
    loop early; the trials completed so far are still saved. Finally the
    psychometric function is computed from the saved file.

    Parameters
    ----------
    p : dict
        Settings; reads 'model', 'seed', 'savefile', 'dt' and 'dt_save'.
    args : sequence
        Action arguments; ``args[0]`` is the number of trials per condition.

    """
    # Model
    m = p['model']

    # Number of trials per condition. Catch only parse failures; a bare
    # `except:` would also swallow KeyboardInterrupt and SystemExit.
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)

    w = len(str(ntrials))  # width of the trial counter in the progress line
    trials = []
    backspaces = 0         # length of the previous progress line (erased with '\b')
    try:
        for i in xrange(ntrials):
            # Condition: one index per factor, in the order listed below
            k = tasktools.unravel_index(i % m.nconditions,
                                        (len(m.cohs), len(m.left_rights),
                                         len(m.cohs), len(m.left_rights),
                                         len(m.contexts)))
            coh_m        = m.cohs[k[0]]
            left_right_m = m.left_rights[k[1]]
            coh_c        = m.cohs[k[2]]
            left_right_c = m.left_rights[k[3]]
            context      = m.contexts[k[4]]

            # Trial
            trial_func = m.generate_trial
            trial_args = {
                'name':         'test',
                'catch':        False,
                'coh_m':        coh_m,
                'left_right_m': left_right_m,
                'coh_c':        coh_c,
                'left_right_c': left_right_c,
                'context':      context
                }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type: context plus the signed m and c coherences
            s = ("Trial {:>{}}/{}: ({}) m{:>+3}, c{:>+3}"
                 .format(i+1, w, ntrials, info['context'],
                         info['left_right_m']*info['coh_m'],
                         info['left_right_c']*info['coh_c']))
            sys.stdout.write(backspaces*'\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save, keeping every `step`-th sample. Clamp to 1 so that
            # dt_save < dt keeps every sample instead of raising
            # "slice step cannot be zero".
            dt    = rnn.t[1] - rnn.t[0]
            step  = max(1, int(p['dt_save']/dt))
            trial = {
                't':    rnn.t[::step],
                'u':    rnn.u[:,::step],
                'r':    rnn.r[:,::step],
                'z':    rnn.z[:,::step],
                'info': info,
                }
            trials.append(trial)
    except KeyboardInterrupt:
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename)*1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(THIS, filename, size))

    # Compute the psychometric function
    psychometric_function(filename)
Ejemplo n.º 18
0
Archivo: do.py Proyecto: grahamas/pycog
#=========================================================================================
# Test resting state
#=========================================================================================

# Run the network without an `inputs` argument and report/plot its outputs.
elif action == 'restingstate':
    import numpy as np

    from pycog          import RNN
    from pycog.figtools import Figure

    # Create RNN. With 'init' in args, load the '_init' variant of savefile
    # (announced as "Initial network.").
    if 'init' in args:
        print("* Initial network.")
        base, ext = os.path.splitext(savefile)
        savefile_init = base + '_init' + ext
        rnn = RNN(savefile_init, {'dt': dt}, verbose=True)
    else:
        rnn = RNN(savefile, {'dt': dt}, verbose=True)
    # Presumably the first argument is the run duration T in rnn.t units -- TODO confirm.
    rnn.run(3e3, seed=seed)

    # Summary: mean and standard deviation over all outputs and time points
    mean = np.mean(rnn.z)
    std  = np.std(rnn.z)
    print("Mean output: {:.6f}".format(mean))
    print("Std. output: {:.6f}".format(std))

    fig  = Figure()
    plot = fig.add()

    colors = [Figure.colors('blue'), Figure.colors('orange')]
    # One iteration per output; NOTE(review): the loop body is missing from this excerpt.
    for i in xrange(rnn.z.shape[0]):
rdm.psychometric_function(dense_trialsfile, plot, ms=5)

# Dale, fixed
plot = plots['Cpsy']
rdm.psychometric_function(fixed_trialsfile, plot, ms=5)

#=========================================================================================
# Connection matrices
#=========================================================================================

# For each network (panels A, B, C): plot its input, recurrent and output
# weights with units reordered by the integer indices stored in sortbyfile.
for rnn, sortbyfile, s, dprimefile in zip([rnn_nodale, rnn_dense, rnn_fixed],
                                          [sortby_nodale, sortby_dense, sortby_fixed],
                                          ['A', 'B', 'C'],
                                          [dprime_nodale, dprime_dense, dprime_fixed]):
    idx = np.loadtxt(sortbyfile, dtype=int)
    RNN.plot_connection_matrix(plots[s+'in'], rnn.Win[idx,:],
                               smap_exc_in, smap_inh_in)
    RNN.plot_connection_matrix(plots[s+'rec'], rnn.Wrec[idx,:][:,idx],
                               smap_exc_rec, smap_inh_rec)
    RNN.plot_connection_matrix(plots[s+'out'], rnn.Wout[:,idx],
                               smap_exc_out, smap_inh_out)

    # Indices i where dprime goes from positive (at i-1) to negative (at i).
    dprime = np.loadtxt(dprimefile)
    transitions = []
    for i in xrange(1, len(dprime)):
        if dprime[i-1] > 0 and dprime[i] < 0:
            transitions.append(i)

    # Panel A only: mark the first transition with a '|' at y = -1.53.
    # NOTE(review): transitions[0] raises IndexError if dprime never changes
    # sign from positive to negative.
    plot = plots[s+'rec']
    if s == 'A':
        n = transitions[0]
        plot.text(n-0.5, -1.53, '|', ha='center', va='center', fontsize=8)
Ejemplo n.º 20
0
    #---------------------------------------------------------------------------

    return trial

# Performance measure: two-alternative forced choice (passed to Model as
# `performance`).
performance = tasktools.performance_2afc

# Terminate training when psychometric performance exceeds 85%
def terminate(performance_history):
    """Return whether training should stop.

    Averages the most recent five entries of ``performance_history`` (fewer if
    the history is shorter) and compares the mean against 85, presumably a
    percentage. The comparison is strict, so a mean of exactly 85 does not stop
    training.
    """
    window, threshold = 5, 85
    recent_mean = np.mean(performance_history[-window:])
    return recent_mean > threshold

# Validation dataset size: 100 trials for each of nconditions + 1 conditions
# (presumably nconditions plus one extra condition -- confirm in the task).
n_validation = 100*(nconditions + 1)

#///////////////////////////////////////////////////////////////////////////////

if __name__ == '__main__':
    # Train model and write it to savefile.pkl
    model = Model(Nin=Nin, N=N, Nout=Nout, ei=ei, Cout=Cout,
                  generate_trial=generate_trial,
                  performance=performance, terminate=terminate,
                  n_validation=n_validation)
    model.train('savefile.pkl')

    # Run the trained network with 16*3.2% = 51.2% coherence for choice 1
    # (trial_args: coh = 16, left_right = 1; single non-catch test trial, dt = 0.5).
    rnn        = RNN('savefile.pkl', {'dt': 0.5})
    trial_func = generate_trial
    trial_args = {'name': 'test', 'catch': False, 'coh': 16, 'left_right': 1}
    info       = rnn.run(inputs=(trial_func, trial_args))
Ejemplo n.º 21
0
# Pickled trial files of the three variants compared here
# (rdm_nodale, rdm_dense, rdm_fixed), all under paper.scratchpath.
nodale_trialsfile = join(paper.scratchpath, 'rdm_nodale', 'trials',
                         'rdm_nodale_trials.pkl')
dense_trialsfile = join(paper.scratchpath, 'rdm_dense', 'trials',
                        'rdm_dense_trials.pkl')
fixed_trialsfile = join(paper.scratchpath, 'rdm_fixed', 'trials',
                        'rdm_fixed_trials.pkl')

#-----------------------------------------------------------------------------------------
# Load RNNs to compare
#-----------------------------------------------------------------------------------------

datapath = join(base, 'examples', 'work', 'data')

# For each variant: the trained network, its dprime text file, and its
# selectivity text file (presumably a unit ordering; `savefile` is reused).
savefile = join(datapath, 'rdm_nodale', 'rdm_nodale.pkl')
rnn_nodale = RNN(savefile, verbose=True)
dprime_nodale = join(datapath, 'rdm_nodale', 'rdm_nodale_dprime.txt')
sortby_nodale = join(datapath, 'rdm_nodale', 'rdm_nodale_selectivity.txt')

savefile = join(datapath, 'rdm_dense', 'rdm_dense.pkl')
rnn_dense = RNN(savefile, verbose=True)
dprime_dense = join(datapath, 'rdm_dense', 'rdm_dense_dprime.txt')
sortby_dense = join(datapath, 'rdm_dense', 'rdm_dense_selectivity.txt')

savefile = join(datapath, 'rdm_fixed', 'rdm_fixed.pkl')
rnn_fixed = RNN(savefile, verbose=True)
dprime_fixed = join(datapath, 'rdm_fixed', 'rdm_fixed_dprime.txt')
sortby_fixed = join(datapath, 'rdm_fixed', 'rdm_fixed_selectivity.txt')

#=========================================================================================
# Figure setup
Ejemplo n.º 22
0
def do(action, args, p):
    """
    Manage tasks.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort_stim_onset':

        sort_trials(get_trialsfile(p), get_sortedfile_stim_onset(p))

    #-------------------------------------------------------------------------------------
    # activate state
    #-------------------------------------------------------------------------------------

    # TODO plot multiple units in the same figure
    # TODO replace units name with real neurons

    elif action == 'activatestate':

        # Model
        m = p['model']

        # Intensity
        try:
            intensity = float(args[0])
        except:
            intensity = 1

        # Plot unit
        try:
            unit = int(args[1])
            if unit == -1:
                unit = None
        except:
            unit = None

        # Create RNN
        if 'init' in args:
            print("* Initial network.")
            base, ext = os.path.splitext(p['savefile'])
            savefile_init = base + '_init' + ext
            rnn = RNN(savefile_init, {'dt': p['dt']}, verbose=True)
        else:
            rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=True)

        trial_func = p['model'].generate_trial
        trial_args = {
            'name': 'test',
            'catch': False,
            'intensity': intensity,
        }
        info = rnn.run(inputs=(trial_func, trial_args), seed=p['seed'])

        # Summary
        mean = np.mean(rnn.z)
        std = np.std(rnn.z)
        print("Intensity: {:.6f}".format(intensity))
        print("Mean output: {:.6f}".format(mean))
        print("Std. output: {:.6f}".format(std))

        # Figure setup
        x = 0.12
        y = 0.12
        w = 0.80
        h = 0.80
        dashes = [3.5, 1.5]

        t_forward = 1e-3 * np.array(info['epochs']['forward'])
        t_stimulus = 1e-3 * np.array(info['epochs']['stimulus'])
        t_reversal = 1e-3 * np.array(info['epochs']['reversal'])

        fig = Figure(w=4,
                     h=3,
                     axislabelsize=7,
                     labelpadx=5,
                     labelpady=5,
                     thickness=0.6,
                     ticksize=3,
                     ticklabelsize=6,
                     ticklabelpad=2)
        plots = {
            'in': fig.add([x, y + 0.72 * h, w, 0.3 * h]),
            'out': fig.add([x, y, w, 0.65 * h]),
        }

        plot = plots['in']
        plot.ylabel('Input', labelpad=7, fontsize=6.5)

        plot = plots['out']
        plot.xlabel('Time (sec)', labelpad=6.5)
        plot.ylabel('Output', labelpad=7, fontsize=6.5)

        # -----------------------------------------------------------------------------------------
        # Input
        # -----------------------------------------------------------------------------------------

        plot = plots['in']
        plot.axis_off('bottom')

        plot.plot(1e-3 * rnn.t, rnn.u[0], color=Figure.colors('red'), lw=0.5)
        plot.lim('y', rnn.u[0])
        plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])

        # -----------------------------------------------------------------------------------------
        # Output
        # -----------------------------------------------------------------------------------------

        plot = plots['out']

        # Outputs
        colors = [Figure.colors('orange'), Figure.colors('blue')]
        if unit is None:
            plot.plot(1e-3 * rnn.t,
                      rnn.z[0],
                      color=colors[0],
                      label='Forward module')
            plot.plot(1e-3 * rnn.t,
                      rnn.z[1],
                      color=colors[1],
                      label='Reversal module')
            plot.lim('y', np.ravel(rnn.z), lower=0)
        else:
            plot.plot(1e-3 * rnn.t,
                      rnn.r[unit],
                      color=colors[1],
                      label='unit ' + str(unit))
            plot.lim('y', np.ravel(rnn.r[unit]))

        plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])

        # Legend
        props = {'prop': {'size': 7}}
        plot.legend(bbox_to_anchor=(1.1, 1.6), **props)

        plot.vline(t_forward[-1],
                   color='0.2',
                   linestyle='--',
                   lw=1,
                   dashes=dashes)
        plot.vline(t_reversal[0],
                   color='0.2',
                   linestyle='--',
                   lw=1,
                   dashes=dashes)

        # Epochs
        plot.text(np.mean(t_forward),
                  plot.get_ylim()[1],
                  'forward',
                  ha='center',
                  va='center',
                  fontsize=7)
        plot.text(np.mean(t_stimulus),
                  plot.get_ylim()[1],
                  'stimulus',
                  ha='center',
                  va='center',
                  fontsize=7)
        plot.text(np.mean(t_reversal),
                  plot.get_ylim()[1],
                  'reversal',
                  ha='center',
                  va='center',
                  fontsize=7)

        if 'init' in args:
            savename = p['name'] + '_' + action + '_init'
        else:
            savename = p['name'] + '_' + action

        if unit is not None:
            savename += '_unit_' + str(unit)

        fig.save(path=p['figspath'], name=savename)
        fig.close()

    # -------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset
    # -------------------------------------------------------------------------------------

    elif action == 'units_stim_onset':
        from glob import glob

        try:
            lower_bon = float(args[0])
        except:
            lower_bon = None

        try:
            higher_bon = float(args[1])
        except:
            higher_bon = None

        # Remove existing files
        unitpath = join(p['figspath'], 'units')
        filenames = glob(join(unitpath, p['name'] + '_stim_onset_unit*'))
        for filename in filenames:
            os.remove(filename)
            print("Removed {}".format(filename))

        # Load sorted trials
        sortedfile = get_sortedfile_stim_onset(p)
        with open(sortedfile) as f:
            t, sorted_trials = pickle.load(f)

        rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=True)
        trial_func = p['model'].generate_trial
        trial_args = {
            'name': 'test',
            'catch': False,
        }
        info = rnn.run(inputs=(trial_func, trial_args), seed=p['seed'])

        t_stimulus = np.array(info['epochs']['stimulus'])
        stimulus_d = t_stimulus[1] - t_stimulus[0]

        for i in xrange(p['model'].N):
            # Check if the unit does anything
            # active = False
            # for r in sorted_trials.values():
            #     if is_active(r[i]):
            #         active = True
            #         break
            # if not active:
            #     continue

            dashes = [3.5, 1.5]

            fig = Figure()
            plot = fig.add()

            # -----------------------------------------------------------------------------
            # Plot
            # -----------------------------------------------------------------------------

            plot_unit(i, sortedfile, plot, tmin=lower_bon, tmax=higher_bon)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            props = {
                'prop': {
                    'size': 8
                },
                'handletextpad': 1.02,
                'labelspacing': 0.6
            }
            plot.legend(bbox_to_anchor=(0.18, 1), **props)

            plot.vline(0, color='0.2', linestyle='--', lw=1, dashes=dashes)
            plot.vline(stimulus_d,
                       color='0.2',
                       linestyle='--',
                       lw=1,
                       dashes=dashes)

            # Epochs
            plot.text(-np.mean((0, stimulus_d)),
                      plot.get_ylim()[1],
                      'forward',
                      ha='center',
                      va='center',
                      fontsize=7)
            plot.text(np.mean((0, stimulus_d)),
                      plot.get_ylim()[1],
                      'stimulus',
                      ha='center',
                      va='center',
                      fontsize=7)
            plot.text(3 * np.mean((0, stimulus_d)),
                      plot.get_ylim()[1],
                      'reversal',
                      ha='center',
                      va='center',
                      fontsize=7)

            # -----------------------------------------------------------------------------

            fig.save(path=unitpath,
                     name=p['name'] + '_stim_onset_unit{:03d}'.format(i))
            fig.close()

    #-------------------------------------------------------------------------------------
    # Selectivity
    #-------------------------------------------------------------------------------------

    elif action == 'selectivity':

        try:
            lower = float(args[0])
        except:
            lower = None

        try:
            higher = float(args[1])
        except:
            higher = None

        # Model
        m = p['model']

        trialsfile = get_trialsfile(p)
        dprime = get_choice_selectivity(trialsfile,
                                        lower_bon=lower,
                                        higher_bon=higher)

        def get_first(x, p):
            return x[:int(p * len(x))]

        psig = 0.25
        units = np.arange(len(dprime))
        try:
            idx = np.argsort(abs(dprime[m.EXC]))[::-1]
            exc = get_first(units[m.EXC][idx], psig)

            idx = np.argsort(abs(dprime[m.INH]))[::-1]
            inh = get_first(units[m.INH][idx], psig)

            idx = np.argsort(dprime[exc])[::-1]
            units_exc = list(exc[idx])

            idx = np.argsort(dprime[inh])[::-1]
            units_inh = list(units[inh][idx])

            units = units_exc + units_inh
            dprime = dprime[units]
        except AttributeError:
            idx = np.argsort(abs(dprime))[::-1]
            all = get_first(units[idx], psig)

            idx = np.argsort(dprime[all])[::-1]
            units = list(units[all][idx])
            dprime = dprime[units]

        # Save d'
        filename = get_dprimefile(p)
        np.savetxt(filename, dprime)
        print("[ {}.do ] d\' saved to {}".format(THIS, filename))

        # Save selectivity
        filename = get_selectivityfile(p)
        np.savetxt(filename, units, fmt='%d')
        print("[ {}.do ] Choice selectivity saved to {}".format(
            THIS, filename))

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
def run_trials(p, args):
    """
    Run test trials of the trained network and save them to disk.

    Trial ``i`` uses condition index ``b = i % (m.nconditions + 1)``.
    ``b == 0`` is a zero-coherence trial whose ``in_out`` is drawn at random
    from ``m.in_outs``. Every other ``b`` is mapped onto a
    (``m.cohs``, ``m.in_outs``) pair with ``tasktools.unravel_index``.

    Parameters
    ----------
    p : dict
        Analysis parameters. Reads ``'model'``, ``'seed'``, ``'savefile'``,
        ``'dt'`` and ``'dt_save'``, and is passed to ``get_trialsfile``.
    args : sequence
        Command-line arguments. ``args[0]``, when present and parseable as an
        int, is the number of trials per condition (default 100).

    Notes
    -----
    Progress is written to stdout. A KeyboardInterrupt stops trial
    generation early; the trials collected so far are still pickled to
    ``get_trialsfile(p)``. ``psychometric_function`` is then run on that
    file.
    """
    # Model
    m = p['model']

    # Number of trials: args[0] may be missing or non-numeric
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions + 1

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)
    trial_func = m.generate_trial

    # Trials
    w = len(str(ntrials))
    trials = []
    backspaces = 0
    try:
        for i in xrange(ntrials):
            b = i % (m.nconditions + 1)
            if b == 0:
                # Zero-coherence condition
                coh    = 0
                in_out = rng.choice(m.in_outs)
            else:
                # All other conditions
                k1, k2 = tasktools.unravel_index(b-1, (len(m.cohs), len(m.in_outs)))
                coh    = m.cohs[k1]
                in_out = m.in_outs[k2]

            # Trial
            trial_args = {
                'name':   'test',
                'catch':  False,
                'coh':    coh,
                'in_out': in_out
                }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type, overwriting the previous progress line
            if coh == 0:
                s = "Trial {:>{}}/{}: {:>3}".format(i+1, w, ntrials, info['coh'])
            else:
                s = ("Trial {:>{}}/{}: {:>+3}"
                     .format(i+1, w, ntrials, info['in_out']*info['coh']))
            sys.stdout.write(backspaces*'\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save a copy downsampled to ~dt_save. Round to the nearest
            # stride: plain truncation turns e.g. 3.9999999 into 3 when dt
            # carries float error. Never allow a zero stride, because a zero
            # slice step raises ValueError.
            dt    = rnn.t[1] - rnn.t[0]
            step  = max(1, int(p['dt_save']/dt + 0.5))
            trial = {
                't':    rnn.t[::step],
                'u':    rnn.u[:,::step],
                'r':    rnn.r[:,::step],
                'z':    rnn.z[:,::step],
                'info': info
                }
            trials.append(trial)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends generation but keeps what was collected
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename)*1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(THIS, filename, size))

    # Compute the psychometric function
    psychometric_function(filename)
def run_trials(p, args):
    """
    Run test trials of the trained network and save them to disk.

    Trial ``i`` uses condition index ``b = i % m.nconditions``, which
    ``tasktools.unravel_index`` maps onto a (``m.modalities``, ``m.freqs``)
    pair.

    Parameters
    ----------
    p : dict
        Analysis parameters. Reads ``'model'``, ``'seed'``, ``'savefile'``,
        ``'dt'`` and ``'dt_save'``, and is passed to ``get_trialsfile``.
    args : sequence
        Command-line arguments. ``args[0]``, when present and parseable as an
        int, is the number of trials per condition (default 100).

    Notes
    -----
    Progress is written to stdout. A KeyboardInterrupt stops trial
    generation early; the trials collected so far are still pickled to
    ``get_trialsfile(p)``. ``psychometric_function`` is then run on that
    file.
    """
    # Model
    m = p['model']

    # Number of trials: args[0] may be missing or non-numeric
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)
    trial_func = m.generate_trial

    # Progress-line label per modality; any other modality is shown as 'va'
    modality_labels = {'v': 'v ', 'a': ' a'}

    # Trials
    w = len(str(ntrials))
    trials = []
    backspaces = 0
    try:
        for i in xrange(ntrials):
            b        = i % m.nconditions
            k1, k2   = tasktools.unravel_index(b, (len(m.modalities), len(m.freqs)))
            modality = m.modalities[k1]
            freq     = m.freqs[k2]

            # Trial
            trial_args = {
                'name':     'test',
                'catch':    False,
                'modality': modality,
                'freq':     freq
                }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type, overwriting the previous progress line
            label = modality_labels.get(info['modality'], 'va')
            s = "Trial {:>{}}/{}: {}|{:>2}".format(i+1, w, ntrials, label, info['freq'])
            sys.stdout.write(backspaces*'\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save a copy downsampled to ~dt_save. Round to the nearest
            # stride so float error in dt cannot shrink it. Never allow a
            # zero stride, because a zero slice step raises ValueError.
            dt    = rnn.t[1] - rnn.t[0]
            step  = max(1, int(p['dt_save']/dt + 0.5))
            trial = {
                't':    rnn.t[::step],
                'u':    rnn.u[:,::step],
                'r':    rnn.r[:,::step],
                'z':    rnn.z[:,::step],
                'info': info
                }
            trials.append(trial)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends generation but keeps what was collected
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename)*1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(THIS, filename, size))

    # Compute the psychometric function
    psychometric_function(filename)
plot = plots['modality']
plot.text_upper_center('Modality selectivity', dy=0.1, fontsize=7)

plot = plots['mixed']
plot.text_upper_center('Mixed selectivity', dy=0.1, fontsize=7)

#=========================================================================================
# Sample inputs
#=========================================================================================

freq0      = int(np.ceil(m.boundary))
boundary_v = m.baseline_in + m.scale_v_p(m.boundary)
boundary_a = m.baseline_in + m.scale_a_p(m.boundary)

rng = np.random.RandomState(1215)
rnn = RNN(savefile, {'dt': 0.5}, verbose=False)
trials = []
for i in xrange(3):
    trial_func = m.generate_trial
    trial_args = {
        'name':     'test',
        'catch':    False,
        'modality': ['v', 'a', 'va'][i],
        'freq':     freq0,
        }
    info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

    dt    = rnn.t[1] - rnn.t[0]
    step  = int(5/dt)
    trial = {
        't':    rnn.t[::step],
Ejemplo n.º 26
0
def run_trials(p, args):
    """
    Run test trials of the trained network and save them to disk.

    Trial ``i`` uses condition index ``i % m.nconditions``.
    ``tasktools.unravel_index`` maps that index onto a 5-tuple of indices
    into (``m.cohs``, ``m.left_rights``, ``m.cohs``, ``m.left_rights``,
    ``m.contexts``). The tuple gives ``coh_m``, ``left_right_m``, ``coh_c``,
    ``left_right_c`` and ``context``. NOTE(review): this presumably assumes
    ``m.nconditions`` equals the product of those lengths; confirm in the
    model file.

    Parameters
    ----------
    p : dict
        Analysis parameters. Reads ``'model'``, ``'seed'``, ``'savefile'``,
        ``'dt'`` and ``'dt_save'``, and is passed to ``get_trialsfile``.
    args : sequence
        Command-line arguments. ``args[0]``, when present and parseable as an
        int, is the number of trials per condition (default 100).

    Notes
    -----
    Progress is written to stdout. A KeyboardInterrupt stops trial
    generation early; the trials collected so far are still pickled to
    ``get_trialsfile(p)``. ``psychometric_function`` is then run on that
    file.
    """
    # Model
    m = p['model']

    # Number of trials: args[0] may be missing or non-numeric
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)
    trial_func = m.generate_trial

    # The condition grid is the same for every trial
    grid_shape = (len(m.cohs), len(m.left_rights), len(m.cohs),
                  len(m.left_rights), len(m.contexts))

    w = len(str(ntrials))
    trials = []
    backspaces = 0
    try:
        for i in xrange(ntrials):
            # Condition
            k = tasktools.unravel_index(i % m.nconditions, grid_shape)
            coh_m = m.cohs[k[0]]
            left_right_m = m.left_rights[k[1]]
            coh_c = m.cohs[k[2]]
            left_right_c = m.left_rights[k[3]]
            context = m.contexts[k[4]]

            # Trial
            trial_args = {
                'name': 'test',
                'catch': False,
                'coh_m': coh_m,
                'left_right_m': left_right_m,
                'coh_c': coh_c,
                'left_right_c': left_right_c,
                'context': context
            }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type, overwriting the previous progress line
            s = ("Trial {:>{}}/{}: ({}) m{:>+3}, c{:>+3}".format(
                i + 1, w, ntrials, info['context'],
                info['left_right_m'] * info['coh_m'],
                info['left_right_c'] * info['coh_c']))
            sys.stdout.write(backspaces * '\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save a copy downsampled to ~dt_save. Round to the nearest
            # stride so float error in dt cannot shrink it. Never allow a
            # zero stride, because a zero slice step raises ValueError.
            dt = rnn.t[1] - rnn.t[0]
            step = max(1, int(p['dt_save'] / dt + 0.5))
            trial = {
                't': rnn.t[::step],
                'u': rnn.u[:, ::step],
                'r': rnn.r[:, ::step],
                'z': rnn.z[:, ::step],
                'info': info,
            }
            trials.append(trial)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends generation but keeps what was collected
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename) * 1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(
        THIS, filename, size))

    # Compute the psychometric function
    psychometric_function(filename)
Ejemplo n.º 27
0
    rnn_zs = np.zeros([Nout, len(rnn.z[0])])
    for j in range(Nout):
        rnn_zs[j, :] = rnn.z[j] / np.max(rnn.z[j])
    return rnn_zs


if __name__ == '__main__':
    from pycog import RNN
    from pycog.figtools import Figure

    rng = np.random.RandomState(1234)  # Added by Alfred
    savefile = 'examples/work/data/delay_react/delay_react.pkl'
    #    savefile = 'examples/work/data/run_10000_lr1em3_1_1_100_10/delay_react.pkl'
    parameters = {'dt': 0.5, 'var_rec': 0.01**2}

    rnn = RNN(savefile, parameters)
    trial_args = {}
    info = rnn.run(inputs=(generate_trial, trial_args), seed=200)
    Z0 = rnn.z

    rnn_zs0 = np.zeros([Nout, len(rnn.z[0])])
    for j in range(Nout):
        rnn_zs0[j, :] = rnn.z[j] / np.max(rnn.z[j])

    n_values = 21
    n_cols = int(0.2 * N)

    shifts = np.zeros([n_cols, n_values, 2])
    values = np.linspace(-0.2, 0.2, n_values)

    for i in range(n_cols):
Ejemplo n.º 28
0
def run_trials(p, args):
    """
    Run test trials of the trained network and save them to disk.

    Trial ``i`` uses condition index ``b = i % (m.nconditions + 1)``.
    ``b == 0`` is a zero-coherence trial whose ``in_out`` is drawn at random
    from ``m.in_outs``. Every other ``b`` is mapped onto a
    (``m.cohs``, ``m.in_outs``) pair with ``tasktools.unravel_index``.

    Parameters
    ----------
    p : dict
        Analysis parameters. Reads ``'model'``, ``'seed'``, ``'savefile'``,
        ``'dt'`` and ``'dt_save'``, and is passed to ``get_trialsfile``.
    args : sequence
        Command-line arguments. ``args[0]``, when present and parseable as an
        int, is the number of trials per condition (default 100).

    Notes
    -----
    Progress is written to stdout. A KeyboardInterrupt stops trial
    generation early; the trials collected so far are still pickled to
    ``get_trialsfile(p)``. ``psychometric_function`` is then run on that
    file.
    """
    # Model
    m = p['model']

    # Number of trials: args[0] may be missing or non-numeric
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 100
    ntrials *= m.nconditions + 1

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)
    trial_func = m.generate_trial

    # Trials
    w = len(str(ntrials))
    trials = []
    backspaces = 0
    try:
        for i in xrange(ntrials):
            b = i % (m.nconditions + 1)
            if b == 0:
                # Zero-coherence condition
                coh = 0
                in_out = rng.choice(m.in_outs)
            else:
                # All other conditions
                k1, k2 = tasktools.unravel_index(b - 1,
                                                 (len(m.cohs), len(m.in_outs)))
                coh = m.cohs[k1]
                in_out = m.in_outs[k2]

            # Trial
            trial_args = {
                'name': 'test',
                'catch': False,
                'coh': coh,
                'in_out': in_out
            }
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type, overwriting the previous progress line
            if coh == 0:
                s = "Trial {:>{}}/{}: {:>3}".format(i + 1, w, ntrials,
                                                    info['coh'])
            else:
                s = ("Trial {:>{}}/{}: {:>+3}".format(
                    i + 1, w, ntrials, info['in_out'] * info['coh']))
            sys.stdout.write(backspaces * '\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Save a copy downsampled to ~dt_save. Round to the nearest
            # stride so float error in dt cannot shrink it. Never allow a
            # zero stride, because a zero slice step raises ValueError.
            dt = rnn.t[1] - rnn.t[0]
            step = max(1, int(p['dt_save'] / dt + 0.5))
            trial = {
                't': rnn.t[::step],
                'u': rnn.u[:, ::step],
                'r': rnn.r[:, ::step],
                'z': rnn.z[:, ::step],
                'info': info
            }
            trials.append(trial)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends generation but keeps what was collected
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename) * 1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(
        THIS, filename, size))

    # Compute the psychometric function
    psychometric_function(filename)
Ejemplo n.º 29
0
#=========================================================================================
# Test resting state
#=========================================================================================

elif action == 'restingstate':
    import numpy as np

    from pycog import RNN
    from pycog.figtools import Figure

    # Create RNN
    if 'init' in args:
        print("* Initial network.")
        base, ext = os.path.splitext(savefile)
        savefile_init = base + '_init' + ext
        rnn = RNN(savefile_init, {'dt': dt}, verbose=True)
    else:
        rnn = RNN(savefile, {'dt': dt}, verbose=True)
    rnn.run(3e3, seed=seed)

    # Summary
    mean = np.mean(rnn.z)
    std = np.std(rnn.z)
    print("Mean output: {:.6f}".format(mean))
    print("Std. output: {:.6f}".format(std))

    fig = Figure()
    plot = fig.add()

    colors = [Figure.colors('blue'), Figure.colors('orange')]
    for i in xrange(rnn.z.shape[0]):
Ejemplo n.º 30
0
        trial['outputs'] = Y

    return trial


# NOTE(review): module-level settings. The names suggest they are read by the
# pycog trainer: min_error as the error threshold at which training may stop,
# and n_validation as the number of validation trials. Confirm against the
# pycog documentation before relying on this.
min_error = 0.1

n_validation = 100

if __name__ == '__main__':
    from pycog import RNN
    from pycog.figtools import Figure

    rnn = RNN('work/data/multi_sequence8mod/multi_sequence8mod.pkl', {
        'dt': 0.5,
        'var_rec': 0.01**2,
        'var_in': np.array([0.003**2])
    })
    trial_args = {}
    info = rnn.run(inputs=(generate_trial, trial_args), seed=7423)

    fig = Figure()
    plot = fig.add()

    colors = [
        'red', 'green', 'yellow', 'orange', 'purple', 'cyan', 'magenta', 'pink'
    ]

    plot.plot(rnn.t / tau, rnn.u[0], color=Figure.colors('blue'))
    for i in range(Nout):
        plot.plot(rnn.t / tau,
def run_trials(p, args):
    """
    Run test trials of the trained network and save them to disk.

    Trials come in blocks of ``m.nseq``. The first block presents sequences
    in order ``1..m.nseq``. Each later block uses a fresh random permutation
    drawn from the seeded RNG. The 1-based sequence number is passed to the
    model as ``'seq'``.

    Parameters
    ----------
    p : dict
        Analysis parameters. Reads ``'model'``, ``'seed'``, ``'savefile'``,
        ``'dt'`` and ``'dt_save'``, and is passed to ``get_trialsfile``.
    args : sequence
        Command-line arguments. ``args[0]``, when present and parseable as an
        int, is the number of trials per sequence (default 1).

    Notes
    -----
    Progress is written to stdout. A KeyboardInterrupt stops trial
    generation early; the trials collected so far are still pickled to
    ``get_trialsfile(p)``.
    """
    # Model
    m = p['model']

    # Number of trials: args[0] may be missing or non-numeric
    try:
        ntrials = int(args[0])
    except (IndexError, TypeError, ValueError):
        ntrials = 1
    ntrials *= m.nseq

    # RNN
    rng = np.random.RandomState(p['seed'])
    rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=False)
    trial_func = m.generate_trial

    # Trials
    w = len(str(ntrials))
    trials = []
    backspaces = 0
    try:
        for i in xrange(ntrials):
            b = i % m.nseq
            if b == 0:
                # Start of a block: fixed order first, then random permutations
                if not trials:
                    seqs = range(m.nseq)
                else:
                    seqs = rng.permutation(m.nseq)

            # Sequence number (1-based)
            seq = seqs[b] + 1

            # Trial
            trial_args = {'name': 'test', 'seq': seq}
            info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

            # Display trial type, overwriting the previous progress line
            s = "Trial {:>{}}/{}: Sequence #{}".format(i+1, w, ntrials, info['seq'])
            sys.stdout.write(backspaces*'\b' + s)
            sys.stdout.flush()
            backspaces = len(s)

            # Add a copy downsampled to ~dt_save. Round to the nearest
            # stride so float error in dt cannot shrink it. Never allow a
            # zero stride, because a zero slice step raises ValueError.
            dt    = rnn.t[1] - rnn.t[0]
            step  = max(1, int(p['dt_save']/dt + 0.5))
            trial = {
                't':    rnn.t[::step],
                'u':    rnn.u[:,::step],
                'r':    rnn.r[:,::step],
                'z':    rnn.z[:,::step],
                'info': info
                }
            trials.append(trial)
    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends generation but keeps what was collected
        pass
    print("")

    # Save all
    filename = get_trialsfile(p)
    with open(filename, 'wb') as f:
        pickle.dump(trials, f, pickle.HIGHEST_PROTOCOL)
    size = os.path.getsize(filename)*1e-9
    print("[ {}.run_trials ] Trials saved to {} ({:.1f} GB)".format(THIS, filename, size))