Exemplo n.º 1
0
    'coh':    16,
    'in_out': 1
    }
info = rnn.run(inputs=(trial_func, trial_args), seed=10)

colors = ['orange', 'purple']

# Only every DT-th sample of each trace is drawn
DT = 15

# One small figure per trace: first the inputs (rnn.u), then the outputs
# (rnn.z). Trace i is drawn in colors[i] and saved under a 1-based name.
for name_format, trace_attr in (('fig1a_input{}', 'u'),
                                ('fig1a_output{}', 'z')):
    for i, clr in enumerate(colors):
        fig = Figure(w=2, h=1)
        plot = fig.add(None, 'none')

        plot.plot(rnn.t[slice(None, None, DT)],
                  getattr(rnn, trace_attr)[i][slice(None, None, DT)],
                  color=Figure.colors(clr),
                  lw=1)
        plot.ylim(0, 1.5)

        fig.save(path=figspath, name=name_format.format(i + 1))
        fig.close()
Exemplo n.º 2
0
def do(action, args, p):
    """
    Manage tasks.

    Parameters
    ----------
    action : str
        Task to run. Recognized values are 'trials', 'psychometric',
        'sort_stim_onset', 'sort_response', 'units_stim_onset',
        'units_response' and 'selectivity'. Any other value only prints an
        "Unrecognized action" message.
    args : sequence
        Extra arguments. Forwarded to `run_trials` for 'trials'; for
        'psychometric', the presence of the item 'threshold' is forwarded as
        `threshold=True` to `psychometric_function`.
    p : dict
        Run parameters. This function reads p['figspath'], p['name'] and
        p['model'] directly and hands `p` to the file-name helpers
        (`get_trialsfile`, `get_sortedfile_*`, `get_dprimefile`,
        `get_selectivityfile`).

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        threshold = 'threshold' in args

        fig = Figure()
        plot = fig.add()

        #---------------------------------------------------------------------------------
        # Plot
        #---------------------------------------------------------------------------------

        trialsfile = get_trialsfile(p)
        psychometric_function(trialsfile, plot, threshold=threshold)

        plot.xlabel(r'Percent coherence toward $T_\text{in}$')
        plot.ylabel(r'Percent $T_\text{in}$')

        #---------------------------------------------------------------------------------

        fig.save(path=p['figspath'], name=p['name'] + '_' + action)
        fig.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort_stim_onset':
        sort_trials_stim_onset(get_trialsfile(p), get_sortedfile_stim_onset(p))

    elif action == 'sort_response':
        sort_trials_response(get_trialsfile(p), get_sortedfile_response(p))

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset or to response
    #-------------------------------------------------------------------------------------

    elif action in ('units_stim_onset', 'units_response'):
        from glob import glob

        # The two alignments share all of the plotting code; they differ only
        # in the figure file names and in which sorted-trials file is read.
        if action == 'units_stim_onset':
            stale_pattern = '_stim_onset_unit*'
            name_format = '_stim_onset_unit{:03d}'
            get_sortedfile = get_sortedfile_stim_onset
        else:
            stale_pattern = '_response_unit*'
            name_format = '_response_unit_{:03d}'
            get_sortedfile = get_sortedfile_response

        # Remove existing files
        filenames = glob(join(p['figspath'], p['name'] + stale_pattern))
        for filename in filenames:
            os.remove(filename)
            print("Removed {}".format(filename))

        # Load sorted trials. Pickle data must be read in binary mode.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use sorted-trials files produced by this project.
        sortedfile = get_sortedfile(p)
        with open(sortedfile, 'rb') as f:
            _, sorted_trials = pickle.load(f)

        for i in xrange(p['model'].N):
            # Skip units that are not active in any of the sorted conditions
            if not any(is_active(r[i]) for r in sorted_trials.values()):
                continue

            fig = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------
            # Plot
            #-----------------------------------------------------------------------------

            plot_unit(i, sortedfile, plot)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            props = {
                'prop': {
                    'size': 8
                },
                'handletextpad': 1.02,
                'labelspacing': 0.6
            }
            plot.legend(bbox_to_anchor=(0.18, 1), **props)

            #-----------------------------------------------------------------------------

            fig.save(path=p['figspath'],
                     name=p['name'] + name_format.format(i))
            fig.close()

    #-------------------------------------------------------------------------------------
    # Selectivity
    #-------------------------------------------------------------------------------------

    elif action == 'selectivity':
        # Model
        m = p['model']

        trialsfile = get_trialsfile(p)
        dprime = get_choice_selectivity(trialsfile)

        def get_first(x, p):
            # Leading fraction `p` of the sequence `x`
            return x[:int(p * len(x))]

        # Fraction of units (ranked by |d'|) that is kept
        psig = 0.25
        units = np.arange(len(dprime))
        try:
            # Keep the top fraction by |d'| within the m.EXC and m.INH groups
            # separately, then order each kept group by signed d' (descending).
            idx = np.argsort(abs(dprime[m.EXC]))[::-1]
            exc = get_first(units[m.EXC][idx], psig)

            idx = np.argsort(abs(dprime[m.INH]))[::-1]
            inh = get_first(units[m.INH][idx], psig)

            idx = np.argsort(dprime[exc])[::-1]
            units_exc = list(exc[idx])

            idx = np.argsort(dprime[inh])[::-1]
            units_inh = list(units[inh][idx])

            units = units_exc + units_inh
            dprime = dprime[units]
        except AttributeError:
            # Fallback when the model does not provide EXC/INH: rank all
            # units together.
            idx = np.argsort(abs(dprime))[::-1]
            top = get_first(units[idx], psig)

            idx = np.argsort(dprime[top])[::-1]
            units = list(units[top][idx])
            dprime = dprime[units]

        # Save d'
        filename = get_dprimefile(p)
        np.savetxt(filename, dprime)
        print("[ {}.do ] d\' saved to {}".format(THIS, filename))

        # Save selectivity
        filename = get_selectivityfile(p)
        np.savetxt(filename, units, fmt='%d')
        print("[ {}.do ] Choice selectivity saved to {}".format(
            THIS, filename))

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
Exemplo n.º 3
0
def do(action, args, p):
    """
    Manage tasks.

    Parameters
    ----------
    action : str
        Task to run. Recognized values are 'trials', 'psychometric', 'sort'
        and 'units'. Any other value only prints an "Unrecognized action"
        message.
    args : sequence
        Extra arguments; only forwarded to `run_trials` for 'trials'.
    p : dict
        Run parameters. This function reads p['figspath'], p['name'] and
        p['model'] (its `N`, `fmin` and `fmax` attributes) directly and hands
        `p` to `get_trialsfile` / `get_sortedfile`.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        fig = Figure()
        plot = fig.add()

        #---------------------------------------------------------------------------------
        # Plot
        #---------------------------------------------------------------------------------

        trialsfile = get_trialsfile(p)
        psychometric_function(trialsfile, plot)

        plot.xlabel('$f_1$ (Hz)')
        plot.ylabel('$f_2$ (Hz)')

        #---------------------------------------------------------------------------------

        fig.save(path=p['figspath'], name=p['name'] + '_' + action)
        fig.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        sort_trials(get_trialsfile(p), get_sortedfile(p))

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Remove existing files
        print("[ {}.do ]".format(THIS))
        filenames = glob('{}_unit*'.format(join(p['figspath'], p['name'])))
        for filename in filenames:
            os.remove(filename)
            print("  Removed {}".format(filename))

        # Load sorted trials. Pickle data must be read in binary mode.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use sorted-trials files produced by this project.
        sortedfile = get_sortedfile(p)
        with open(sortedfile, 'rb') as f:
            _, sorted_trials = pickle.load(f)

        # Colormap spanning the model's [fmin, fmax] range, passed to plot_unit
        cmap = mpl.cm.jet
        norm = mpl.colors.Normalize(vmin=p['model'].fmin, vmax=p['model'].fmax)
        smap = mpl.cm.ScalarMappable(norm, cmap)

        for i in xrange(p['model'].N):
            # Skip units that are not active in any of the sorted conditions
            if not any(is_active(condition_averaged[i])
                       for condition_averaged in sorted_trials.values()):
                continue

            fig = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------
            # Plot
            #-----------------------------------------------------------------------------

            plot_unit(i, sortedfile, plot, smap)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            props = {
                'prop': {
                    'size': 8
                },
                'handletextpad': 1.02,
                'labelspacing': 0.6
            }
            plot.legend(bbox_to_anchor=(0.18, 1), **props)

            #-----------------------------------------------------------------------------

            fig.save(path=p['figspath'],
                     name=p['name'] + '_unit{:03d}'.format(i))
            fig.close()

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action '{}'.".format(THIS, action))
Exemplo n.º 4
0
def do(action, args, p):
    """
    Manage tasks.

    Parameters
    ----------
    action : str
        Task to run. Recognized values are 'trials', 'psychometric', 'sort',
        'units', 'regress' and 'inactivation'. Any other value only prints an
        "Unrecognized action" message.
    args : sequence
        Extra arguments. Forwarded to `run_trials` ('trials') and
        `run_trials_inact` ('inactivation'). For 'inactivation', args[1] is
        the number of unit indices added per step (default 10) and args[2]
        the number of steps (default 120 / step size); a missing or
        non-integer entry falls back to the default.
    p : dict
        Run parameters. This function reads p['figspath'], p['name'] and
        p['seed'] directly and hands `p` to the file-name helpers and trial
        runners.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        #---------------------------------------------------------------------------------
        # Figure setup
        #---------------------------------------------------------------------------------

        w = 6.5
        h = 3
        fig = Figure(w=w,
                     h=h,
                     axislabelsize=7.5,
                     labelpadx=6,
                     labelpady=7.5,
                     thickness=0.6,
                     ticksize=3,
                     ticklabelsize=6.5,
                     ticklabelpad=2)

        # From here on w, h are the size of each panel, as passed to fig.add
        w = 0.39
        h = 0.7

        L = 0.09
        R = L + w + 0.1
        y = 0.2

        plots = {'m': fig.add([L, y, w, h]), 'c': fig.add([R, y, w, h])}

        #---------------------------------------------------------------------------------
        # Labels
        #---------------------------------------------------------------------------------

        # Raw strings: '\%' is not a valid escape sequence in a normal string
        # literal; the resulting text (backslash + percent sign) is unchanged.
        plot = plots['m']
        plot.xlabel(r'Motion coherence (\%)')
        plot.ylabel(r'Choice to right (\%)')

        plot = plots['c']
        plot.xlabel(r'Color coherence (\%)')
        plot.ylabel(r'Choice to green (\%)')

        #---------------------------------------------------------------------------------
        # Plot
        #---------------------------------------------------------------------------------

        trialsfile = get_trialsfile(p)
        psychometric_function(trialsfile, plots)

        # Legend
        prop = {
            'prop': {
                'size': 7
            },
            'handlelength': 1.2,
            'handletextpad': 1.1,
            'labelspacing': 0.5
        }
        plots['m'].legend(bbox_to_anchor=(0.41, 1), **prop)

        #---------------------------------------------------------------------------------

        fig.save(path=p['figspath'], name=p['name'] + '_' + action)
        fig.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        sort_trials(trialsfile, sortedfile)

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Remove existing files
        print("[ {}.do ]".format(THIS))
        filenames = glob('{}_unit*'.format(join(p['figspath'], p['name'])))
        for filename in filenames:
            os.remove(filename)
            print("  Removed {}".format(filename))

        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)

        units = get_active_units(trialsfile)
        for unit in units:
            #-----------------------------------------------------------------------------
            # Figure setup
            #-----------------------------------------------------------------------------

            w = 2.5
            h = 6
            fig = Figure(w=w,
                         h=h,
                         axislabelsize=7.5,
                         labelpadx=6,
                         labelpady=7.5,
                         thickness=0.6,
                         ticksize=3,
                         ticklabelsize=6.5,
                         ticklabelpad=2)

            # Four stacked panels of equal size, top (y0) to bottom (y3)
            w = 0.55
            x0 = 0.3

            h = 0.17
            dy = h + 0.06
            y0 = 0.77
            y1 = y0 - dy
            y2 = y1 - dy
            y3 = y2 - dy

            plots = {
                'choice': fig.add([x0, y0, w, h]),
                'motion_choice': fig.add([x0, y1, w, h]),
                'colour_choice': fig.add([x0, y2, w, h]),
                'context_choice': fig.add([x0, y3, w, h])
            }

            #-----------------------------------------------------------------------------
            # Plot
            #-----------------------------------------------------------------------------

            plot_unit(unit, sortedfile, plots, sortby_fontsize=7)
            plots['context_choice'].xlabel('Time (ms)')

            #-----------------------------------------------------------------------------

            fig.save(path=p['figspath'],
                     name=p['name'] + '_unit{:03d}'.format(unit))
            fig.close()
        print("[ {}.do ] {} units processed.".format(THIS, len(units)))

    #-------------------------------------------------------------------------------------
    # Regress
    #-------------------------------------------------------------------------------------

    elif action == 'regress':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        betafile = get_betafile(p)
        regress(trialsfile, sortedfile, betafile)

    #-------------------------------------------------------------------------------------
    # Inactivation
    #-------------------------------------------------------------------------------------

    elif action == 'inactivation':

        # load Weight value info
        import cPickle as pickle

        # Only the leading n_ne x n_ne block of the weight difference is used
        # (presumably the excitatory-to-excitatory block -- TODO confirm).
        n_ne = 120

        # Step size; fall back to the default when the argument is missing
        # or not an integer.
        try:
            w_inact_bin = int(args[1])
        except (IndexError, TypeError, ValueError):
            w_inact_bin = 10

        # Number of steps; by default enough steps to cover all n_ne units
        try:
            w_inact_dep = int(args[2])
        except (IndexError, TypeError, ValueError):
            w_inact_dep = int(n_ne / w_inact_bin)

        # NOTE(review): absolute, machine-specific paths -- this branch only
        # works on the original author's machine. They should be derived from
        # `p` once the intended location is confirmed.
        fn = 'C:\\Users\\sasak\\OneDrive\\workspace\\pycog\\examples\\work\\data\\mante\\mante_init.pkl'
        with open(fn, 'rb') as f:
            data = pickle.load(f)
        Wrec_init = data['best']['params'][1]

        fn = 'C:\\Users\\sasak\\OneDrive\\workspace\\pycog\\examples\\work\\data\\mante\\mante.pkl'
        with open(fn, 'rb') as f:
            data = pickle.load(f)
        Wrec_last = data['best']['params'][1]

        Wrec_diff = Wrec_init - Wrec_last
        Wee_diff = Wrec_diff[:n_ne, :n_ne]

        def run_cumulative(order):
            # Step k passes the first (k + 1) * w_inact_bin entries of `order`
            for i in xrange(w_inact_dep):
                id_inact = order[:(i + 1) * w_inact_bin]
                run_trials_inact(p, id_inact, args)

        run_trials_inact(p, None, args)  # full condition

        # Rank units by the mean absolute change of their row in Wee_diff
        Neu_rec_diff = np.mean(np.abs(Wee_diff), axis=1)
        ascending_id = np.argsort(Neu_rec_diff)

        # sort descending order
        run_cumulative(ascending_id[::-1])

        # sort ascending order
        run_cumulative(ascending_id)

        # shuffled order
        rng = np.random.RandomState(p['seed'])
        #rng = np.random.RandomState(data['params']['seed'])
        # shuffle() works in place, so it needs a real list
        sorted_id = list(range(n_ne))
        rng.shuffle(sorted_id)
        run_cumulative(sorted_id)

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
def do(action, args, p):
    """
    Manage tasks.

    Parameters
    ----------
    action : str
        Task to run. Recognized values are 'trials', 'psychometric', 'sort'
        and 'units'. Any other value only prints an "Unrecognized action"
        message.
    args : sequence
        Extra arguments; only forwarded to `run_trials` for 'trials'.
    p : dict
        Run parameters. This function reads p['figspath'], p['name'] and
        p['model'].N directly and hands `p` to `get_trialsfile` /
        `get_sortedfile`.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        fig  = Figure()
        plot = fig.add()

        #---------------------------------------------------------------------------------
        # Plot
        #---------------------------------------------------------------------------------

        trialsfile = get_trialsfile(p)
        psychometric_function(trialsfile, plot)

        plot.xlabel(r'Rate (events/sec)')
        plot.ylabel(r'Percent high')

        #---------------------------------------------------------------------------------

        fig.save(path=p['figspath'], name=p['name']+'_'+action)
        fig.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        sort_trials(trialsfile, sortedfile)

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Remove existing files
        print("[ {}.do ]".format(THIS))
        filenames = glob('{}_unit*'.format(join(p['figspath'], p['name'])))
        for filename in filenames:
            os.remove(filename)
            print("  Removed {}".format(filename))

        # Load sorted trials. Pickle data must be read in binary mode.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use sorted-trials files produced by this project.
        sortedfile = get_sortedfile(p)
        with open(sortedfile, 'rb') as f:
            _, sorted_trials = pickle.load(f)

        for i in xrange(p['model'].N):
            # Skip units that are not active in any of the sorted conditions
            if not any(is_active(r[i]) for r in sorted_trials.values()):
                continue

            fig  = Figure()
            plot = fig.add()

            #---------------------------------------------------------------------------------
            # Plot
            #---------------------------------------------------------------------------------

            plot_unit(i, sortedfile, plot)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            #---------------------------------------------------------------------------------

            fig.save(path=p['figspath'], name=p['name']+'_unit{:03d}'.format(i))
            fig.close()

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
Exemplo n.º 6
0
def plot_trial(trial_info, trial, figspath, name):
    """
    Plot one trial as four stacked panels and save the figure.

    From top to bottom the panels show the observables, the action
    probabilities, the selected actions and the rewards.

    Parameters
    ----------
    trial_info : sequence of 8 items
        Unpacked as (U, Z, A, R, M, init, states_0, perf); only the first five
        are used. Index 0 is taken along the second axis of each, and the sum
        of the resulting M gives the number of time steps plotted.
    trial : dict
        Reads trial['time'], trial['gt_lt'] and trial['fpair'].
    figspath, name
        Passed to `fig.save` as `path` and `name`.

    """
    U, Z, A, R, M, _init, _states_0, _perf = trial_info
    U, Z, A = U[:,0,:], Z[:,0,:], A[:,0,:]
    R, M = R[:,0], M[:,0]
    t = int(np.sum(M))

    #-------------------------------------------------------------------------------------
    # Figure setup
    #-------------------------------------------------------------------------------------

    w = 0.65
    h = 0.18
    x = 0.17
    dy = h + 0.05

    # Bottom edges of the panels, lowest panel first
    bottoms = [0.08]
    for _ in range(3):
        bottoms.append(bottoms[-1] + dy)

    fig   = Figure(h=6)
    plots = {}
    # Axes are created from the top panel downwards
    for panel, bottom in zip(('observables', 'policy', 'actions', 'rewards'),
                             reversed(bottoms)):
        plots[panel] = fig.add([x, bottom, w, h])

    time        = trial['time']
    dt          = time[1] - time[0]
    act_time    = time[:t]
    obs_time    = time[:t-1] + dt
    reward_time = act_time + dt
    xlim        = (0, max(time))

    def legend_props():
        # Fresh keyword dict for every legend call
        return {'prop': {'size': 7}, 'handlelength': 1.2,
                'handletextpad': 1.2, 'labelspacing': 0.8}

    channel_colors = ('blue', 'orange', 'purple')

    #-------------------------------------------------------------------------------------
    # Observables
    #-------------------------------------------------------------------------------------

    plot = plots['observables']
    observable_labels = ('Keydown', r'$f_\text{pos}$', r'$f_\text{neg}$')
    for k, (clr, label) in enumerate(zip(channel_colors, observable_labels)):
        # Markers first, then the connecting line that carries the legend label
        plot.plot(obs_time, U[:t-1,k], 'o', ms=5, mew=0, mfc=Figure.colors(clr))
        plot.plot(obs_time, U[:t-1,k], lw=1.25, color=Figure.colors(clr), label=label)
    plot.xlim(*xlim)
    plot.ylim(0, 1)
    plot.ylabel('Observables')

    # Annotate with the pair ordered as (f1, f2); the stored pair is swapped
    # unless the trial is a '>' trial.
    swap = trial['gt_lt'] != '>'
    f1, f2 = trial['fpair']
    if swap:
        f1, f2 = f2, f1
    plot.text_upper_right(str((f1, f2)))

    plot.legend(bbox_to_anchor=(1.2, 0.8), **legend_props())

    #-------------------------------------------------------------------------------------
    # Policy
    #-------------------------------------------------------------------------------------

    plot = plots['policy']
    policy_labels = ('Keydown', '$f_1 > f_2$', '$f_1 < f_2$')
    for k, (clr, label) in enumerate(zip(channel_colors, policy_labels)):
        plot.plot(act_time, Z[:t,k], 'o', ms=5, mew=0, mfc=Figure.colors(clr))
        plot.plot(act_time, Z[:t,k], lw=1.25, color=Figure.colors(clr),
                  label=label)
    plot.xlim(*xlim)
    plot.ylim(0, 1)
    plot.ylabel('Action probabilities')

    plot.legend(bbox_to_anchor=(1.27, 0.8), **legend_props())

    #-------------------------------------------------------------------------------------
    # Actions
    #-------------------------------------------------------------------------------------

    plot = plots['actions']
    # Index of the largest entry of A at each plotted time step
    actions = list(map(np.argmax, A[:t]))
    plot.plot(act_time, actions, 'o', ms=5, mew=0, mfc=Figure.colors('red'))
    plot.plot(act_time, actions, lw=1.25, color=Figure.colors('red'))
    plot.xlim(*xlim)
    plot.ylim(0, 2)
    plot.yticks([0, 1, 2])
    plot.yticklabels(['Keydown', '$f_1 > f_2$', '$f_1 < f_2$'])
    plot.ylabel('Action')

    #-------------------------------------------------------------------------------------
    # Rewards
    #-------------------------------------------------------------------------------------

    plot = plots['rewards']
    plot.plot(reward_time, R[:t], 'o', ms=5, mew=0, mfc=Figure.colors('red'))
    plot.plot(reward_time, R[:t], lw=1.25, color=Figure.colors('red'))
    plot.xlim(*xlim)
    plot.ylim(R_TERMINATE, R_CORRECT)
    plot.xlabel('Time (ms)')
    plot.ylabel('Reward')

    #-------------------------------------------------------------------------------------

    fig.save(path=figspath, name=name)
    fig.close()
Exemplo n.º 7
0
def do(action, args, p):
    """
    Manage tasks.

    Dispatch on `action` and run the corresponding analysis step.

    Parameters
    ----------
    action : str
        'trials'            -- call `run_trials(p, args)`.
        'sort_stim_onset'   -- pass the trials file and the stimulus-onset
                               sorted file of `p` to `sort_trials`.
        'activatestate'     -- run one 'test' trial through the RNN, print a
                               summary of the outputs and save a figure of the
                               input and of either the two outputs or a single
                               unit's activity.
        'units_stim_onset'  -- save one figure per unit under
                               `<figspath>/units`, after deleting the figures
                               of a previous run.
        'selectivity'       -- rank units by d' and save both the d' values
                               and the unit ordering with `np.savetxt`.
        Any other value only prints an "Unrecognized action" message.
    args : sequence
        Extra options whose meaning depends on `action`:
        'activatestate'    -- `args[0]` intensity (float, default 1),
                              `args[1]` unit index (int; -1, missing or
                              invalid means "plot the outputs"); the string
                              'init' anywhere in `args` selects the
                              `<savefile>_init` network.
        'units_stim_onset' -- `args[0]`, `args[1]` are forwarded to
                              `plot_unit` as `tmin` / `tmax` (None if absent).
        'selectivity'      -- `args[0]`, `args[1]` are forwarded to
                              `get_choice_selectivity` as `lower_bon` /
                              `higher_bon` (None if absent).
    p : dict
        Run parameters. Keys read directly here: 'model', 'savefile', 'dt',
        'seed', 'figspath' and 'name'; `p` is also handed to the `get_*file`
        helpers and to `run_trials`.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    def parse_arg(index, cast, default=None):
        # Optional positional argument: a missing entry (LookupError), an
        # unparseable one (ValueError) or one of the wrong type (TypeError)
        # falls back to `default`.  Catching only these keeps Ctrl-C and
        # SystemExit working, which a bare `except:` would have swallowed.
        try:
            return cast(args[index])
        except (LookupError, TypeError, ValueError):
            return default

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort_stim_onset':

        sort_trials(get_trialsfile(p), get_sortedfile_stim_onset(p))

    #-------------------------------------------------------------------------------------
    # activate state
    #-------------------------------------------------------------------------------------

    # TODO plot multiple units in the same figure
    # TODO replace units name with real neurons

    elif action == 'activatestate':

        # Model
        model = p['model']

        # Intensity
        intensity = parse_arg(0, float, 1)

        # Plot unit (-1 is the explicit way of asking for the outputs)
        unit = parse_arg(1, int)
        if unit == -1:
            unit = None

        # Create RNN
        use_init = 'init' in args
        if use_init:
            print("* Initial network.")
            base, ext = os.path.splitext(p['savefile'])
            savefile = base + '_init' + ext
        else:
            savefile = p['savefile']
        rnn = RNN(savefile, {'dt': p['dt']}, verbose=True)

        trial_func = model.generate_trial
        trial_args = {
            'name': 'test',
            'catch': False,
            'intensity': intensity,
        }
        info = rnn.run(inputs=(trial_func, trial_args), seed=p['seed'])

        # Summary
        mean = np.mean(rnn.z)
        std = np.std(rnn.z)
        print("Intensity: {:.6f}".format(intensity))
        print("Mean output: {:.6f}".format(mean))
        print("Std. output: {:.6f}".format(std))

        # Figure setup
        x = 0.12
        y = 0.12
        w = 0.80
        h = 0.80
        dashes = [3.5, 1.5]

        # Epoch boundaries are scaled by 1e-3 like the time axis below
        t_forward = 1e-3 * np.array(info['epochs']['forward'])
        t_stimulus = 1e-3 * np.array(info['epochs']['stimulus'])
        t_reversal = 1e-3 * np.array(info['epochs']['reversal'])

        fig = Figure(w=4,
                     h=3,
                     axislabelsize=7,
                     labelpadx=5,
                     labelpady=5,
                     thickness=0.6,
                     ticksize=3,
                     ticklabelsize=6,
                     ticklabelpad=2)
        plots = {
            'in': fig.add([x, y + 0.72 * h, w, 0.3 * h]),
            'out': fig.add([x, y, w, 0.65 * h]),
        }

        plot = plots['in']
        plot.ylabel('Input', labelpad=7, fontsize=6.5)

        plot = plots['out']
        plot.xlabel('Time (sec)', labelpad=6.5)
        plot.ylabel('Output', labelpad=7, fontsize=6.5)

        # -----------------------------------------------------------------------------------------
        # Input
        # -----------------------------------------------------------------------------------------

        plot = plots['in']
        plot.axis_off('bottom')

        plot.plot(1e-3 * rnn.t, rnn.u[0], color=Figure.colors('red'), lw=0.5)
        plot.lim('y', rnn.u[0])
        plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])

        # -----------------------------------------------------------------------------------------
        # Output
        # -----------------------------------------------------------------------------------------

        plot = plots['out']

        # Outputs
        colors = [Figure.colors('orange'), Figure.colors('blue')]
        if unit is None:
            plot.plot(1e-3 * rnn.t,
                      rnn.z[0],
                      color=colors[0],
                      label='Forward module')
            plot.plot(1e-3 * rnn.t,
                      rnn.z[1],
                      color=colors[1],
                      label='Reversal module')
            plot.lim('y', np.ravel(rnn.z), lower=0)
        else:
            plot.plot(1e-3 * rnn.t,
                      rnn.r[unit],
                      color=colors[1],
                      label='unit ' + str(unit))
            plot.lim('y', np.ravel(rnn.r[unit]))

        plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])

        # Legend
        props = {'prop': {'size': 7}}
        plot.legend(bbox_to_anchor=(1.1, 1.6), **props)

        # Dashed markers at the end of 'forward' and the start of 'reversal'
        for t_mark in (t_forward[-1], t_reversal[0]):
            plot.vline(t_mark,
                       color='0.2',
                       linestyle='--',
                       lw=1,
                       dashes=dashes)

        # Epochs: a label centred on each epoch, at the top of the y-range
        for t_epoch, label in ((t_forward, 'forward'),
                               (t_stimulus, 'stimulus'),
                               (t_reversal, 'reversal')):
            plot.text(np.mean(t_epoch),
                      plot.get_ylim()[1],
                      label,
                      ha='center',
                      va='center',
                      fontsize=7)

        savename = p['name'] + '_' + action
        if use_init:
            savename += '_init'
        if unit is not None:
            savename += '_unit_' + str(unit)

        fig.save(path=p['figspath'], name=savename)
        fig.close()

    # -------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset
    # -------------------------------------------------------------------------------------

    elif action == 'units_stim_onset':
        from glob import glob

        # Optional time window, forwarded to plot_unit as tmin / tmax
        lower_bon = parse_arg(0, float)
        higher_bon = parse_arg(1, float)

        # Remove existing files
        unitpath = join(p['figspath'], 'units')
        filenames = glob(join(unitpath, p['name'] + '_stim_onset_unit*'))
        for filename in filenames:
            os.remove(filename)
            print("Removed {}".format(filename))

        # Load sorted trials.  The contents are only needed by the disabled
        # activity filter inside the loop below.
        sortedfile = get_sortedfile_stim_onset(p)
        with open(sortedfile) as f:
            t, sorted_trials = pickle.load(f)

        # One 'test' trial is run only to read the stimulus epoch boundaries
        rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=True)
        trial_func = p['model'].generate_trial
        trial_args = {
            'name': 'test',
            'catch': False,
        }
        info = rnn.run(inputs=(trial_func, trial_args), seed=p['seed'])

        t_stimulus = np.array(info['epochs']['stimulus'])
        stimulus_d = t_stimulus[1] - t_stimulus[0]

        # Loop-invariant plot settings
        dashes = [3.5, 1.5]
        props = {
            'prop': {
                'size': 8
            },
            'handletextpad': 1.02,
            'labelspacing': 0.6
        }

        for i in xrange(p['model'].N):
            # Check if the unit does anything
            # active = False
            # for r in sorted_trials.values():
            #     if is_active(r[i]):
            #         active = True
            #         break
            # if not active:
            #     continue

            fig = Figure()
            plot = fig.add()

            # -----------------------------------------------------------------------------
            # Plot
            # -----------------------------------------------------------------------------

            plot_unit(i, sortedfile, plot, tmin=lower_bon, tmax=higher_bon)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            plot.legend(bbox_to_anchor=(0.18, 1), **props)

            # Dashed markers at 0 and at the stimulus duration
            plot.vline(0, color='0.2', linestyle='--', lw=1, dashes=dashes)
            plot.vline(stimulus_d,
                       color='0.2',
                       linestyle='--',
                       lw=1,
                       dashes=dashes)

            # Epochs: labels at -1/2, +1/2 and +3/2 of the stimulus duration
            half_stimulus = np.mean((0, stimulus_d))
            for x_text, label in ((-half_stimulus, 'forward'),
                                  (half_stimulus, 'stimulus'),
                                  (3 * half_stimulus, 'reversal')):
                plot.text(x_text,
                          plot.get_ylim()[1],
                          label,
                          ha='center',
                          va='center',
                          fontsize=7)

            # -----------------------------------------------------------------------------

            fig.save(path=unitpath,
                     name=p['name'] + '_stim_onset_unit{:03d}'.format(i))
            fig.close()

    #-------------------------------------------------------------------------------------
    # Selectivity
    #-------------------------------------------------------------------------------------

    elif action == 'selectivity':

        # Optional bounds, forwarded to get_choice_selectivity
        lower = parse_arg(0, float)
        higher = parse_arg(1, float)

        # Model
        m = p['model']

        trialsfile = get_trialsfile(p)
        dprime = get_choice_selectivity(trialsfile,
                                        lower_bon=lower,
                                        higher_bon=higher)

        def get_first(x, fraction):
            # Leading `fraction` of x, rounded down
            return x[:int(fraction * len(x))]

        psig = 0.25
        units = np.arange(len(dprime))
        try:
            # Per population: keep the top `psig` fraction by |d'|, then order
            # the kept units by signed d', largest first.
            idx = np.argsort(abs(dprime[m.EXC]))[::-1]
            exc = get_first(units[m.EXC][idx], psig)

            idx = np.argsort(abs(dprime[m.INH]))[::-1]
            inh = get_first(units[m.INH][idx], psig)

            idx = np.argsort(dprime[exc])[::-1]
            units_exc = list(exc[idx])

            idx = np.argsort(dprime[inh])[::-1]
            units_inh = list(units[inh][idx])

            units = units_exc + units_inh
            dprime = dprime[units]
        except AttributeError:
            # Model without EXC / INH attributes: rank all units together
            idx = np.argsort(abs(dprime))[::-1]
            top = get_first(units[idx], psig)

            idx = np.argsort(dprime[top])[::-1]
            units = list(units[top][idx])
            dprime = dprime[units]

        # Save d'
        filename = get_dprimefile(p)
        np.savetxt(filename, dprime)
        print("[ {}.do ] d\' saved to {}".format(THIS, filename))

        # Save selectivity
        filename = get_selectivityfile(p)
        np.savetxt(filename, units, fmt='%d')
        print("[ {}.do ] Choice selectivity saved to {}".format(
            THIS, filename))

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
def do(action, args, p):
    """
    Manage tasks.

    Dispatch on `action`:

    'trials'       -- call `run_trials(p, args)`.
    'psychometric' -- draw the psychometric function of the trials file and
                      save it as `<name>_psychometric` in `p['figspath']`.
    'sort'         -- pass the trials file and the sorted file of `p` to
                      `sort_trials`.
    'units'        -- delete the `<name>_unit*` figures of a previous run and
                      save one figure per active unit.
    Anything else prints an "Unrecognized action" message.

    `args` is only forwarded to `run_trials`.  Keys of `p` read directly here
    are 'figspath', 'name' and 'model' (attributes `fmin`, `fmax`, `N`).

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        figure = Figure()
        axes = figure.add()

        psychometric_function(get_trialsfile(p), axes)

        axes.xlabel('$f_1$ (Hz)')
        axes.ylabel('$f_2$ (Hz)')

        figure.save(path=p['figspath'], name=p['name']+'_'+action)
        figure.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        sort_trials(get_trialsfile(p), get_sortedfile(p))

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Clear out the figures left by a previous run
        print("[ {}.do ]".format(THIS))
        stale_pattern = '{}_unit*'.format(join(p['figspath'], p['name']))
        for stale in glob(stale_pattern):
            os.remove(stale)
            print("  Removed {}".format(stale))

        # Sorted trials; only the second element of the pickled pair is used
        sortedfile = get_sortedfile(p)
        with open(sortedfile) as f:
            _, sorted_trials = pickle.load(f)

        # Scalar-to-colour mapping spanning the model's fmin..fmax range
        model = p['model']
        smap = mpl.cm.ScalarMappable(
            mpl.colors.Normalize(vmin=model.fmin, vmax=model.fmax),
            mpl.cm.jet)

        legend_props = {'prop': {'size': 8}, 'handletextpad': 1.02, 'labelspacing': 0.6}

        for i in xrange(model.N):
            # Skip units that are inactive in every condition
            if not any(is_active(averaged[i])
                       for averaged in sorted_trials.values()):
                continue

            figure = Figure()
            axes = figure.add()

            plot_unit(i, sortedfile, axes, smap)

            axes.xlabel('Time (ms)')
            axes.ylabel('Firing rate (a.u.)')
            axes.legend(bbox_to_anchor=(0.18, 1), **legend_props)

            figure.save(path=p['figspath'], name=p['name']+'_unit{:03d}'.format(i))
            figure.close()

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action '{}'.".format(THIS, action))
Exemplo n.º 9
0
def do(action, args, p):
    """
    Manage tasks.

    Dispatch on `action`:

    'trials'       -- call `run_trials(p, args)`.
    'psychometric' -- draw the psychometric function of the trials file and
                      save it as `<name>_psychometric` in `p['figspath']`.
    'sort'         -- pass the trials file and the sorted file of `p` to
                      `sort_trials`.
    'units'        -- delete the `<name>_unit*` figures of a previous run and
                      save one figure per active unit.
    Anything else prints an "Unrecognized action" message.

    `args` is only forwarded to `run_trials`.  Keys of `p` read directly here
    are 'figspath', 'name' and 'model' (attribute `N`).

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        figure = Figure()
        axes = figure.add()

        psychometric_function(get_trialsfile(p), axes)

        axes.xlabel(r'Rate (events/sec)')
        axes.ylabel(r'Percent high')

        figure.save(path=p['figspath'], name=p['name'] + '_' + action)
        figure.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        unsorted_path = get_trialsfile(p)
        sorted_path = get_sortedfile(p)
        sort_trials(unsorted_path, sorted_path)

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Clear out the figures left by a previous run
        print("[ {}.do ]".format(THIS))
        stale_pattern = '{}_unit*'.format(join(p['figspath'], p['name']))
        for stale in glob(stale_pattern):
            os.remove(stale)
            print("  Removed {}".format(stale))

        # Sorted trials; only the second element of the pickled pair is used
        sortedfile = get_sortedfile(p)
        with open(sortedfile) as f:
            _, sorted_trials = pickle.load(f)

        for i in xrange(p['model'].N):
            # Skip units that are inactive in every condition
            if not any(is_active(rates[i])
                       for rates in sorted_trials.values()):
                continue

            figure = Figure()
            axes = figure.add()

            plot_unit(i, sortedfile, axes)

            axes.xlabel('Time (ms)')
            axes.ylabel('Firing rate (a.u.)')

            unit_name = p['name'] + '_unit{:03d}'.format(i)
            figure.save(path=p['figspath'], name=unit_name)
            figure.close()

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
def do(action, args, p):
    """
    Manage tasks.

    Dispatch on `action`:

    'trials'           -- call `run_trials(p, args)`.
    'psychometric'     -- draw the psychometric function of the trials file;
                          the string 'threshold' in `args` is forwarded as
                          `threshold=True`.
    'sort_stim_onset'  -- call `sort_trials_stim_onset` with the trials file
                          and the stimulus-onset sorted file.
    'sort_response'    -- call `sort_trials_response` with the trials file
                          and the response sorted file.
    'units_stim_onset' -- delete old `<name>_stim_onset_unit*` figures and
                          save one figure per active unit.
    'units_response'   -- delete old `<name>_response_unit*` figures and
                          save one figure per active unit.
    'selectivity'      -- rank units by d' and save the d' values and the
                          unit ordering with `np.savetxt`.
    Anything else prints an "Unrecognized action" message.

    Keys of `p` read directly here are 'figspath', 'name' and 'model'.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        # Optional flag, forwarded to the plotting routine
        threshold = 'threshold' in args

        figure = Figure()
        axes = figure.add()

        psychometric_function(get_trialsfile(p), axes, threshold=threshold)

        axes.xlabel(r'Percent coherence toward $T_\text{in}$')
        axes.ylabel(r'Percent $T_\text{in}$')

        figure.save(path=p['figspath'], name=p['name']+'_'+action)
        figure.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort_stim_onset':
        sort_trials_stim_onset(get_trialsfile(p), get_sortedfile_stim_onset(p))

    elif action == 'sort_response':
        sort_trials_response(get_trialsfile(p), get_sortedfile_response(p))

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset
    #-------------------------------------------------------------------------------------

    elif action == 'units_stim_onset':
        from glob import glob

        # Clear out the figures left by a previous run
        for stale in glob(join(p['figspath'], p['name'] + '_stim_onset_unit*')):
            os.remove(stale)
            print("Removed {}".format(stale))

        # Sorted trials; only the second element of the pickled pair is used
        sortedfile = get_sortedfile_stim_onset(p)
        with open(sortedfile) as f:
            _, sorted_trials = pickle.load(f)

        legend_props = {'prop': {'size': 8}, 'handletextpad': 1.02, 'labelspacing': 0.6}

        for i in xrange(p['model'].N):
            # Skip units that are inactive in every condition
            if not any(is_active(rates[i])
                       for rates in sorted_trials.values()):
                continue

            figure = Figure()
            axes = figure.add()

            plot_unit(i, sortedfile, axes)

            axes.xlabel('Time (ms)')
            axes.ylabel('Firing rate (a.u.)')
            axes.legend(bbox_to_anchor=(0.18, 1), **legend_props)

            figure.save(path=p['figspath'],
                        name=p['name']+'_stim_onset_unit{:03d}'.format(i))
            figure.close()

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to response
    #-------------------------------------------------------------------------------------

    elif action == 'units_response':
        from glob import glob

        # Clear out the figures left by a previous run
        for stale in glob(join(p['figspath'], p['name'] + '_response_unit*')):
            os.remove(stale)
            print("Removed {}".format(stale))

        # Sorted trials; only the second element of the pickled pair is used
        sortedfile = get_sortedfile_response(p)
        with open(sortedfile) as f:
            _, sorted_trials = pickle.load(f)

        legend_props = {'prop': {'size': 8}, 'handletextpad': 1.02, 'labelspacing': 0.6}

        for i in xrange(p['model'].N):
            # Skip units that are inactive in every condition
            if not any(is_active(rates[i])
                       for rates in sorted_trials.values()):
                continue

            figure = Figure()
            axes = figure.add()

            plot_unit(i, sortedfile, axes)

            axes.xlabel('Time (ms)')
            axes.ylabel('Firing rate (a.u.)')
            axes.legend(bbox_to_anchor=(0.18, 1), **legend_props)

            figure.save(path=p['figspath'],
                        name=p['name']+'_response_unit_{:03d}'.format(i))
            figure.close()

    #-------------------------------------------------------------------------------------
    # Selectivity
    #-------------------------------------------------------------------------------------

    elif action == 'selectivity':
        # Model
        m = p['model']

        all_dprime = get_choice_selectivity(get_trialsfile(p))

        psig = 0.25
        all_units = np.arange(len(all_dprime))

        def rank(members):
            # Keep the top `psig` fraction of `members` by |d'|, then order
            # the kept units by signed d', largest first.
            pool = all_units[members]
            by_magnitude = np.argsort(abs(all_dprime[members]))[::-1]
            kept = pool[by_magnitude][:int(psig*len(pool))]
            by_value = np.argsort(all_dprime[kept])[::-1]
            return list(kept[by_value])

        try:
            # Excitatory units first, then inhibitory ones
            units = rank(m.EXC) + rank(m.INH)
        except AttributeError:
            # Model without EXC / INH attributes: rank all units together
            units = rank(slice(None))
        dprime = all_dprime[units]

        # Save d'
        filename = get_dprimefile(p)
        np.savetxt(filename, dprime)
        print("[ {}.do ] d\' saved to {}".format(THIS, filename))

        # Save selectivity
        filename = get_selectivityfile(p)
        np.savetxt(filename, units, fmt='%d')
        print("[ {}.do ] Choice selectivity saved to {}".format(THIS, filename))

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
Exemplo n.º 11
0
def do(action, args, p):
    """
    Manage tasks.

    Dispatch on `action`:

    'trials'       -- call `run_trials(p, args)`.
    'psychometric' -- draw the psychometric functions into two side-by-side
                      axes ('m' and 'c') and save the figure as
                      `<name>_psychometric` in `p['figspath']`.
    'sort'         -- pass the trials file and the sorted file of `p` to
                      `sort_trials`.
    'units'        -- delete the `<name>_unit*` figures of a previous run and
                      save a four-panel figure for every unit returned by
                      `get_active_units`.
    'regress'      -- call `regress` with the trials, sorted and beta files.
    Anything else prints an "Unrecognized action" message.

    `args` is only forwarded to `run_trials`.  Keys of `p` read directly here
    are 'figspath' and 'name'.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        #---------------------------------------------------------------------------------
        # Figure setup
        #---------------------------------------------------------------------------------

        fig = Figure(w=6.5,
                     h=3,
                     axislabelsize=7.5,
                     labelpadx=6,
                     labelpady=7.5,
                     thickness=0.6,
                     ticksize=3,
                     ticklabelsize=6.5,
                     ticklabelpad=2)

        # Axes rectangles: two panels of equal size, side by side
        w = 0.39
        h = 0.7

        L = 0.09
        R = L + w + 0.1
        y = 0.2

        plots = {'m': fig.add([L, y, w, h]), 'c': fig.add([R, y, w, h])}

        #---------------------------------------------------------------------------------
        # Labels
        #---------------------------------------------------------------------------------

        # Raw strings: '\%' must reach the renderer with its backslash, and
        # in a normal literal it is an invalid escape sequence.
        plot = plots['m']
        plot.xlabel(r'Motion coherence (\%)')
        plot.ylabel(r'Choice to right (\%)')

        plot = plots['c']
        plot.xlabel(r'Color coherence (\%)')
        plot.ylabel(r'Choice to green (\%)')

        #---------------------------------------------------------------------------------
        # Plot
        #---------------------------------------------------------------------------------

        trialsfile = get_trialsfile(p)
        psychometric_function(trialsfile, plots)

        # Legend
        prop = {
            'prop': {
                'size': 7
            },
            'handlelength': 1.2,
            'handletextpad': 1.1,
            'labelspacing': 0.5
        }
        plots['m'].legend(bbox_to_anchor=(0.41, 1), **prop)

        #---------------------------------------------------------------------------------

        fig.save(path=p['figspath'], name=p['name'] + '_' + action)
        fig.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        sort_trials(trialsfile, sortedfile)

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Remove existing files
        print("[ {}.do ]".format(THIS))
        filenames = glob('{}_unit*'.format(join(p['figspath'], p['name'])))
        for filename in filenames:
            os.remove(filename)
            print("  Removed {}".format(filename))

        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)

        # Axes rectangles for the four stacked panels (identical for every
        # unit, so computed once); each panel sits `dy` below the previous one
        w = 0.55
        x0 = 0.3

        h = 0.17
        dy = h + 0.06
        y0 = 0.77
        y1 = y0 - dy
        y2 = y1 - dy
        y3 = y2 - dy

        units = get_active_units(trialsfile)
        for unit in units:
            #-----------------------------------------------------------------------------
            # Figure setup
            #-----------------------------------------------------------------------------

            fig = Figure(w=2.5,
                         h=6,
                         axislabelsize=7.5,
                         labelpadx=6,
                         labelpady=7.5,
                         thickness=0.6,
                         ticksize=3,
                         ticklabelsize=6.5,
                         ticklabelpad=2)

            plots = {
                'choice': fig.add([x0, y0, w, h]),
                'motion_choice': fig.add([x0, y1, w, h]),
                'colour_choice': fig.add([x0, y2, w, h]),
                'context_choice': fig.add([x0, y3, w, h])
            }

            #-----------------------------------------------------------------------------
            # Plot
            #-----------------------------------------------------------------------------

            plot_unit(unit, sortedfile, plots, sortby_fontsize=7)
            # Only the bottom panel carries the time-axis label
            plots['context_choice'].xlabel('Time (ms)')

            #-----------------------------------------------------------------------------

            fig.save(path=p['figspath'],
                     name=p['name'] + '_unit{:03d}'.format(unit))
            fig.close()
        print("[ {}.do ] {} units processed.".format(THIS, len(units)))

    #-------------------------------------------------------------------------------------
    # Regress
    #-------------------------------------------------------------------------------------

    elif action == 'regress':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        betafile = get_betafile(p)
        regress(trialsfile, sortedfile, betafile)

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
Exemplo n.º 12
0
def plot_trial(trial_info, trial, figspath, name):
    """
    Plot a single trial as four stacked panels and save the figure.

    From top to bottom the panels show the observables, the action
    probabilities, the selected actions, and the rewards.

    Parameters
    ----------
    trial_info : sequence of 8 items
        Unpacked as (U, Z, A, R, M, init, states_0, perf). Only the first
        five are used; index 0 is taken along the second axis of each.
        NOTE(review): the second axis is presumably the trial/batch axis -
        confirm against the caller.
    trial : mapping
        Must provide 'time', 'gt_lt' and 'fpair'.
    figspath : passed through to ``fig.save`` as ``path``.
    name : passed through to ``fig.save`` as ``name``.

    """
    inputs, policy, chosen, rewards, mask, _init, _states_0, _perf = trial_info
    inputs  = inputs[:,0,:]
    policy  = policy[:,0,:]
    chosen  = chosen[:,0,:]
    rewards = rewards[:,0]
    mask    = mask[:,0]

    # Number of time steps to draw, taken from the sum of the mask.
    n_steps = int(np.sum(mask))

    #-------------------------------------------------------------------------------------
    # Layout: four equally sized axes, stacked from the bottom up
    #-------------------------------------------------------------------------------------

    width  = 0.65
    height = 0.18
    left   = 0.17
    step   = height + 0.05

    bottoms = [0.08]
    for _ in range(3):
        bottoms.append(bottoms[-1] + step)

    fig   = Figure(h=6)
    plots = {}
    # Axes are created top-down so the creation order matches the panel order.
    for key, bottom in zip(('observables', 'policy', 'actions', 'rewards'),
                           reversed(bottoms)):
        plots[key] = fig.add([left, bottom, width, height])

    #-------------------------------------------------------------------------------------
    # Time axes
    #-------------------------------------------------------------------------------------

    time        = trial['time']
    dt          = time[1] - time[0]
    act_time    = time[:n_steps]
    obs_time    = time[:n_steps-1] + dt
    reward_time = act_time + dt
    xlim        = (0, max(time))

    def legend_props():
        # A fresh dict per legend, exactly as if it were written out twice.
        return {'prop': {'size': 7}, 'handlelength': 1.2,
                'handletextpad': 1.2, 'labelspacing': 0.8}

    def draw_channels(plot, x, data, channels):
        # Each channel is drawn as bare markers followed by a labelled line.
        for column, colorname, label in channels:
            y = data[:, column]
            plot.plot(x, y, 'o', ms=5, mew=0, mfc=Figure.colors(colorname))
            plot.plot(x, y, lw=1.25, color=Figure.colors(colorname), label=label)

    #-------------------------------------------------------------------------------------
    # Observables
    #-------------------------------------------------------------------------------------

    plot = plots['observables']
    draw_channels(plot, obs_time, inputs[:n_steps-1],
                  [(0, 'blue',   'Keydown'),
                   (1, 'orange', r'$f_\text{pos}$'),
                   (2, 'purple', r'$f_\text{neg}$')])
    plot.xlim(*xlim)
    plot.ylim(0, 1)
    plot.ylabel('Observables')

    # Annotate with the frequency pair, swapped unless the trial is a '>' trial.
    keep_order = trial['gt_lt'] == '>'
    first, second = trial['fpair']
    if not keep_order:
        first, second = second, first
    plot.text_upper_right(str((first, second)))

    plot.legend(bbox_to_anchor=(1.2, 0.8), **legend_props())

    #-------------------------------------------------------------------------------------
    # Policy
    #-------------------------------------------------------------------------------------

    plot = plots['policy']
    draw_channels(plot, act_time, policy[:n_steps],
                  [(0, 'blue',   'Keydown'),
                   (1, 'orange', '$f_1 > f_2$'),
                   (2, 'purple', '$f_1 < f_2$')])
    plot.xlim(*xlim)
    plot.ylim(0, 1)
    plot.ylabel('Action probabilities')

    plot.legend(bbox_to_anchor=(1.27, 0.8), **legend_props())

    #-------------------------------------------------------------------------------------
    # Actions
    #-------------------------------------------------------------------------------------

    plot = plots['actions']
    # Index of the largest entry in each row of the action array.
    actions = list(map(np.argmax, chosen[:n_steps]))
    plot.plot(act_time, actions, 'o', ms=5, mew=0, mfc=Figure.colors('red'))
    plot.plot(act_time, actions, lw=1.25, color=Figure.colors('red'))
    plot.xlim(*xlim)
    plot.ylim(0, 2)
    plot.yticks([0, 1, 2])
    plot.yticklabels(['Keydown', '$f_1 > f_2$', '$f_1 < f_2$'])
    plot.ylabel('Action')

    #-------------------------------------------------------------------------------------
    # Rewards
    #-------------------------------------------------------------------------------------

    plot = plots['rewards']
    reward_trace = rewards[:n_steps]
    plot.plot(reward_time, reward_trace, 'o', ms=5, mew=0, mfc=Figure.colors('red'))
    plot.plot(reward_time, reward_trace, lw=1.25, color=Figure.colors('red'))
    plot.xlim(*xlim)
    plot.ylim(R_TERMINATE, R_CORRECT)
    plot.xlabel('Time (ms)')
    plot.ylabel('Reward')

    #-------------------------------------------------------------------------------------

    fig.save(path=figspath, name=name)
    fig.close()
def do(action, args, p):
    """
    Manage tasks.

    Dispatches on `action`:

    'trials'       - run trials via ``run_trials(p, args)``.
    'psychometric' - plot the psychometric functions into a two-panel figure
                     and save it as ``p['name'] + '_psychometric'``.
    'sort'         - sort the trials file into the sorted file.
    'units'        - delete existing ``<name>_unit*`` figures, then plot and
                     save one four-panel figure per active unit.
    'regress'      - run the regression on the trials and sorted files.

    Any other value prints an "Unrecognized action" message.

    Parameters
    ----------
    action : str
        Name of the task to perform.
    args : passed through to ``run_trials``; unused by the other actions.
    p : mapping
        Must provide 'figspath' and 'name' for the plotting actions, and is
        handed to the ``get_*file`` helpers to locate data files.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    # Styling shared by every figure created in this function.
    fig_style = dict(axislabelsize=7.5, labelpadx=6, labelpady=7.5,
                     thickness=0.6, ticksize=3, ticklabelsize=6.5, ticklabelpad=2)

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Psychometric function
    #-------------------------------------------------------------------------------------

    elif action == 'psychometric':
        #---------------------------------------------------------------------------------
        # Figure setup
        #---------------------------------------------------------------------------------

        fig = Figure(w=6.5, h=3, **fig_style)

        # Axes geometry, as given to fig.add: [left, bottom, width, height].
        w = 0.39
        h = 0.7

        L = 0.09
        R = L + w + 0.1
        y = 0.2

        plots = {'m': fig.add([L, y, w, h]),
                 'c': fig.add([R, y, w, h])}

        #---------------------------------------------------------------------------------
        # Labels
        #---------------------------------------------------------------------------------

        # Raw strings: "\%" is not a valid Python escape sequence, so a plain
        # literal triggers an invalid-escape warning. The backslash itself
        # must stay in the label text.
        plot = plots['m']
        plot.xlabel(r'Motion coherence (\%)')
        plot.ylabel(r'Choice to right (\%)')

        plot = plots['c']
        plot.xlabel(r'Color coherence (\%)')
        plot.ylabel(r'Choice to green (\%)')

        #---------------------------------------------------------------------------------
        # Plot
        #---------------------------------------------------------------------------------

        trialsfile = get_trialsfile(p)
        psychometric_function(trialsfile, plots)

        # Legend
        prop = {'prop': {'size': 7}, 'handlelength': 1.2,
                'handletextpad': 1.1, 'labelspacing': 0.5}
        plots['m'].legend(bbox_to_anchor=(0.41, 1), **prop)

        #---------------------------------------------------------------------------------

        fig.save(path=p['figspath'], name=p['name']+'_'+action)
        fig.close()

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        sort_trials(trialsfile, sortedfile)

    #-------------------------------------------------------------------------------------
    # Plot single-unit activity
    #-------------------------------------------------------------------------------------

    elif action == 'units':
        from glob import glob

        # Remove existing files so stale unit figures do not linger.
        print("[ {}.do ]".format(THIS))
        filenames = glob('{}_unit*'.format(join(p['figspath'], p['name'])))
        for filename in filenames:
            os.remove(filename)
            print("  Removed {}".format(filename))

        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)

        # Axes geometry is the same for every unit: four panels stacked
        # downwards from y0, each as [left, bottom, width, height].
        w  = 0.55
        x0 = 0.3

        h  = 0.17
        dy = h + 0.06
        y0 = 0.77
        y1 = y0 - dy
        y2 = y1 - dy
        y3 = y2 - dy

        units = get_active_units(trialsfile)
        for unit in units:
            #-----------------------------------------------------------------------------
            # Figure setup
            #-----------------------------------------------------------------------------

            fig = Figure(w=2.5, h=6, **fig_style)

            plots = {
                'choice':         fig.add([x0, y0, w, h]),
                'motion_choice':  fig.add([x0, y1, w, h]),
                'colour_choice':  fig.add([x0, y2, w, h]),
                'context_choice': fig.add([x0, y3, w, h])
                }

            #-----------------------------------------------------------------------------
            # Plot
            #-----------------------------------------------------------------------------

            plot_unit(unit, sortedfile, plots, sortby_fontsize=7)
            plots['context_choice'].xlabel('Time (ms)')

            #-----------------------------------------------------------------------------

            fig.save(path=p['figspath'], name=p['name']+'_unit{:03d}'.format(unit))
            fig.close()
        print("[ {}.do ] {} units processed.".format(THIS, len(units)))

    #-------------------------------------------------------------------------------------
    # Regress
    #-------------------------------------------------------------------------------------

    elif action == 'regress':
        trialsfile = get_trialsfile(p)
        sortedfile = get_sortedfile(p)
        betafile   = get_betafile(p)
        regress(trialsfile, sortedfile, betafile)

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))