Esempio n. 1
0
def plot_stimulus(plot, rng, coh=+12.8):
    """
    Plot a noisy two-input stimulus into `plot`.

    Inside the stimulus epoch (stimulus[0] < x <= stimulus[1]) the "high"
    trace is (1 + coh/100)/2 and the "low" trace is (1 - coh/100)/2, each with
    additive Gaussian noise of scale 0.1 drawn from `rng`; outside the epoch
    both traces are 0. The high trace is then shifted up by `offset`. Every
    `n_display`-th sample is drawn.

    Relies on module-level globals: dt, tmax, stimulus, offset, n_display.

    Parameters
    ----------
    plot : project plot object
        Target axes wrapper (axis_off/plot/xlim/ylim/yticks are called).
    rng : random generator
        Must provide ``normal(scale=...)`` (e.g. numpy RandomState).
    coh : number
        Coherence in percent. May be an int or a float.
    """
    plot.axis_off('bottom')

    x = np.linspace(dt, tmax, int(tmax / dt))
    y_high = np.zeros_like(x)
    y_low = np.zeros_like(x)

    # Float division: under Python 2 an integer coherence (e.g. coh=51)
    # would otherwise be truncated to 0 by `coh / 100`.
    frac = coh / 100.0
    for i, xi in enumerate(x):
        if stimulus[0] < xi <= stimulus[1]:
            # High sample is drawn before low sample at each step, so the
            # random stream is consumed in the same order as before.
            y_high[i] = (1 + frac) / 2 + rng.normal(scale=0.1)
            y_low[i] = (1 - frac) / 2 + rng.normal(scale=0.1)
    y_high += offset

    plot.plot(x[::n_display],
              y_high[::n_display],
              color=Figure.colors('blue'),
              lw=1.5,
              zorder=10)
    plot.plot(x[::n_display],
              y_low[::n_display],
              color=Figure.colors('red'),
              lw=1.5,
              zorder=9)

    plot.xlim(0, tmax)
    plot.ylim(0, 1)
    plot.yticks([0, 1])
Esempio n. 2
0
def plot_inputs(trial, mod, all):
    """
    Draw the visual and auditory inputs of `trial` for modality `mod`.

    The visual input trace goes to ``plots[mod + '_v']`` and the auditory one
    to ``plots[mod + '_a']``; each trace is appended to `all` (visual first).
    When `mod` contains 'v' (resp. 'a'), a horizontal line at `boundary_v`
    (resp. `boundary_a`) is added to the matching panel. Finally prints the
    standard deviation of each input over 500 < t <= 1500.

    Relies on module-level globals: m, w, t, plots, multisensory,
    boundary_v, boundary_a.
    """
    inputs = trial['u']

    # (panel suffix, colour key, input row) for each sensory channel
    channels = (('_v', 'v', m.VISUAL_P),
                ('_a', 'a', m.AUDITORY_P))
    for suffix, key, row in channels:
        trace = inputs[row][w]
        panel = plots[mod + suffix]
        panel.plot(t, trace, color=multisensory.colors[key], lw=0.8, zorder=5)
        all.append(trace)

    # Decision boundaries, only for the modalities present in `mod`
    if 'v' in mod:
        level = boundary_v * np.ones_like(t)
        plots[mod + '_v'].plot(t, level, color=Figure.colors('darkblue'),
                               lw=0.75, zorder=10)
    if 'a' in mod:
        level = boundary_a * np.ones_like(t)
        plots[mod + '_a'].plot(t, level, color=Figure.colors('darkgreen'),
                               lw=0.75, zorder=10)

    times = trial['t']
    window = np.where((500 < times) & (times <= 1500))[0]
    for row in (m.VISUAL_P, m.AUDITORY_P):
        print(np.std(inputs[row][window]))
Esempio n. 3
0
def plot_trial(trial_info, trial, figspath, name):
    """
    Plot one trial as four stacked panels and save the figure.

    Top to bottom: observables (U), action probabilities (Z), chosen action
    (argmax of each A row) and reward (R). Only the first ``t = int(sum(M))``
    steps are drawn. The observables panel is annotated with (f1, f2), which
    is taken from ``trial['fpair']`` in order when ``trial['gt_lt'] == '>'``
    and reversed otherwise.

    Parameters
    ----------
    trial_info : tuple of 8
        (U, Z, A, R, M, init, states_0, perf); only U, Z, A, R, M are used,
        each indexed at position 0 along their second axis.
    trial : dict
        Needs 'time', 'gt_lt' and 'fpair'.
    figspath, name : str
        Where and under which name the figure is saved.
    """
    U, Z, A, R, M, _, _, _ = trial_info
    U, Z, A = U[:, 0, :], Z[:, 0, :], A[:, 0, :]
    R, M = R[:, 0], M[:, 0]
    t = int(np.sum(M))

    # Panel geometry (figure fractions); panels stacked bottom-up.
    left, width, height = 0.17, 0.65, 0.18
    step = height + 0.05
    y_rewards = 0.08
    y_actions = y_rewards + step
    y_policy = y_actions + step
    y_obs = y_policy + step

    fig = Figure(h=6)
    plots = {}
    for key, bottom in (('observables', y_obs),
                        ('policy', y_policy),
                        ('actions', y_actions),
                        ('rewards', y_rewards)):
        plots[key] = fig.add([left, bottom, width, height])

    time = trial['time']
    dt = time[1] - time[0]
    act_time = time[:t]
    obs_time = time[:t-1] + dt
    reward_time = act_time + dt
    xlim = (0, max(time))

    def draw(panel, xs, ys, color_name, **line_kwargs):
        # Markers first, then a connecting line in the same colour.
        clr = Figure.colors(color_name)
        panel.plot(xs, ys, 'o', ms=5, mew=0, mfc=clr)
        panel.plot(xs, ys, lw=1.25, color=clr, **line_kwargs)

    palette = ('blue', 'orange', 'purple')
    legend_props = dict(prop={'size': 7}, handlelength=1.2,
                        handletextpad=1.2, labelspacing=0.8)

    # Observables
    panel = plots['observables']
    obs_labels = ('Keydown', r'$f_\text{pos}$', r'$f_\text{neg}$')
    for col, (color_name, label) in enumerate(zip(palette, obs_labels)):
        draw(panel, obs_time, U[:t-1, col], color_name, label=label)
    panel.xlim(*xlim)
    panel.ylim(0, 1)
    panel.ylabel('Observables')

    swap = trial['gt_lt'] != '>'
    f1, f2 = trial['fpair']
    if swap:
        f1, f2 = f2, f1
    panel.text_upper_right(str((f1, f2)))
    panel.legend(bbox_to_anchor=(1.2, 0.8), **legend_props)

    # Policy
    panel = plots['policy']
    policy_labels = ('Keydown', '$f_1 > f_2$', '$f_1 < f_2$')
    for col, (color_name, label) in enumerate(zip(palette, policy_labels)):
        draw(panel, act_time, Z[:t, col], color_name, label=label)
    panel.xlim(*xlim)
    panel.ylim(0, 1)
    panel.ylabel('Action probabilities')
    panel.legend(bbox_to_anchor=(1.27, 0.8), **legend_props)

    # Actions
    panel = plots['actions']
    actions = list(map(np.argmax, A[:t]))
    draw(panel, act_time, actions, 'red')
    panel.xlim(*xlim)
    panel.ylim(0, 2)
    panel.yticks([0, 1, 2])
    panel.yticklabels(['Keydown', '$f_1 > f_2$', '$f_1 < f_2$'])
    panel.ylabel('Action')

    # Rewards
    panel = plots['rewards']
    draw(panel, reward_time, R[:t], 'red')
    panel.xlim(*xlim)
    panel.ylim(R_TERMINATE, R_CORRECT)
    panel.xlabel('Time (ms)')
    panel.ylabel('Reward')

    fig.save(path=figspath, name=name)
    fig.close()
Esempio n. 4
0
def _parse_arg(args, i, convert, default):
    """
    Return ``convert(args[i])``, or `default` if that argument is missing or
    cannot be converted.

    Only the errors a missing or ill-formed argument can raise are caught, so
    KeyboardInterrupt, SystemExit and unrelated bugs still propagate (the
    bare ``except:`` clauses this replaces swallowed them).
    """
    try:
        return convert(args[i])
    except (LookupError, TypeError, ValueError):
        return default


def do(action, args, p):
    """
    Manage tasks.

    Dispatches on `action`:

    * ``'trials'`` -- ``run_trials(p, args)``.
    * ``'sort_stim_onset'`` -- sort the trials file into the stimulus-onset
      sorted file.
    * ``'activatestate'`` -- run one test trial and plot its input and its
      outputs (or one unit's ``rnn.r`` trace). Optional ``args[0]``: intensity
      (float, default 1); optional ``args[1]``: unit index (int; -1, missing
      or non-integer plots the outputs). If ``'init'`` is in `args`, the
      ``<savefile>_init<ext>`` network is used.
    * ``'units_stim_onset'`` -- one figure per unit via `plot_unit`; optional
      ``args[0]`` / ``args[1]`` are passed as `tmin` / `tmax`.
    * ``'selectivity'`` -- compute d' with `get_choice_selectivity`, keep the
      top 25% of units by |d'| (per EXC/INH group when the model defines
      them), and save d' and the selected unit indices.

    Any other action prints an "Unrecognized action." message.

    Parameters
    ----------
    action : str
    args : sequence
        Extra command-line arguments.
    p : dict
        Run parameters; keys read here include 'model', 'savefile', 'dt',
        'seed', 'name' and 'figspath'.
    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #-------------------------------------------------------------------------------------
    # Trials
    #-------------------------------------------------------------------------------------

    if action == 'trials':
        run_trials(p, args)

    #-------------------------------------------------------------------------------------
    # Sort
    #-------------------------------------------------------------------------------------

    elif action == 'sort_stim_onset':

        sort_trials(get_trialsfile(p), get_sortedfile_stim_onset(p))

    #-------------------------------------------------------------------------------------
    # activate state
    #-------------------------------------------------------------------------------------

    # TODO plot multiple units in the same figure
    # TODO replace units name with real neurons

    elif action == 'activatestate':

        # Intensity (args[0]); 1 if missing or not a number
        intensity = _parse_arg(args, 0, float, 1)

        # Unit to plot (args[1]); -1, missing or non-integer -> plot outputs
        unit = _parse_arg(args, 1, int, None)
        if unit == -1:
            unit = None

        # Create RNN
        if 'init' in args:
            print("* Initial network.")
            base, ext = os.path.splitext(p['savefile'])
            savefile_init = base + '_init' + ext
            rnn = RNN(savefile_init, {'dt': p['dt']}, verbose=True)
        else:
            rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=True)

        trial_func = p['model'].generate_trial
        trial_args = {
            'name': 'test',
            'catch': False,
            'intensity': intensity,
        }
        info = rnn.run(inputs=(trial_func, trial_args), seed=p['seed'])

        # Summary
        mean = np.mean(rnn.z)
        std = np.std(rnn.z)
        print("Intensity: {:.6f}".format(intensity))
        print("Mean output: {:.6f}".format(mean))
        print("Std. output: {:.6f}".format(std))

        # Figure setup
        x = 0.12
        y = 0.12
        w = 0.80
        h = 0.80
        dashes = [3.5, 1.5]

        # Epoch boundaries converted with a factor 1e-3 to match the
        # 'Time (sec)' axis below.
        t_forward = 1e-3 * np.array(info['epochs']['forward'])
        t_stimulus = 1e-3 * np.array(info['epochs']['stimulus'])
        t_reversal = 1e-3 * np.array(info['epochs']['reversal'])

        fig = Figure(w=4,
                     h=3,
                     axislabelsize=7,
                     labelpadx=5,
                     labelpady=5,
                     thickness=0.6,
                     ticksize=3,
                     ticklabelsize=6,
                     ticklabelpad=2)
        plots = {
            'in': fig.add([x, y + 0.72 * h, w, 0.3 * h]),
            'out': fig.add([x, y, w, 0.65 * h]),
        }

        plot = plots['in']
        plot.ylabel('Input', labelpad=7, fontsize=6.5)

        plot = plots['out']
        plot.xlabel('Time (sec)', labelpad=6.5)
        plot.ylabel('Output', labelpad=7, fontsize=6.5)

        # -----------------------------------------------------------------------------------------
        # Input
        # -----------------------------------------------------------------------------------------

        plot = plots['in']
        plot.axis_off('bottom')

        plot.plot(1e-3 * rnn.t, rnn.u[0], color=Figure.colors('red'), lw=0.5)
        plot.lim('y', rnn.u[0])
        plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])

        # -----------------------------------------------------------------------------------------
        # Output
        # -----------------------------------------------------------------------------------------

        plot = plots['out']

        # Outputs
        colors = [Figure.colors('orange'), Figure.colors('blue')]
        if unit is None:
            plot.plot(1e-3 * rnn.t,
                      rnn.z[0],
                      color=colors[0],
                      label='Forward module')
            plot.plot(1e-3 * rnn.t,
                      rnn.z[1],
                      color=colors[1],
                      label='Reversal module')
            plot.lim('y', np.ravel(rnn.z), lower=0)
        else:
            plot.plot(1e-3 * rnn.t,
                      rnn.r[unit],
                      color=colors[1],
                      label='unit ' + str(unit))
            plot.lim('y', np.ravel(rnn.r[unit]))

        plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])

        # Legend
        props = {'prop': {'size': 7}}
        plot.legend(bbox_to_anchor=(1.1, 1.6), **props)

        plot.vline(t_forward[-1],
                   color='0.2',
                   linestyle='--',
                   lw=1,
                   dashes=dashes)
        plot.vline(t_reversal[0],
                   color='0.2',
                   linestyle='--',
                   lw=1,
                   dashes=dashes)

        # Epochs
        plot.text(np.mean(t_forward),
                  plot.get_ylim()[1],
                  'forward',
                  ha='center',
                  va='center',
                  fontsize=7)
        plot.text(np.mean(t_stimulus),
                  plot.get_ylim()[1],
                  'stimulus',
                  ha='center',
                  va='center',
                  fontsize=7)
        plot.text(np.mean(t_reversal),
                  plot.get_ylim()[1],
                  'reversal',
                  ha='center',
                  va='center',
                  fontsize=7)

        if 'init' in args:
            savename = p['name'] + '_' + action + '_init'
        else:
            savename = p['name'] + '_' + action

        if unit is not None:
            savename += '_unit_' + str(unit)

        fig.save(path=p['figspath'], name=savename)
        fig.close()

    # -------------------------------------------------------------------------------------
    # Plot single-unit activity aligned to stimulus onset
    # -------------------------------------------------------------------------------------

    elif action == 'units_stim_onset':
        from glob import glob

        # Optional time bounds; None when missing or not a number.
        # NOTE(review): None is passed explicitly to plot_unit as tmin/tmax,
        # overriding its defaults -- confirm plot_unit accepts None.
        lower_bon = _parse_arg(args, 0, float, None)
        higher_bon = _parse_arg(args, 1, float, None)

        # Remove existing files
        unitpath = join(p['figspath'], 'units')
        filenames = glob(join(unitpath, p['name'] + '_stim_onset_unit*'))
        for filename in filenames:
            os.remove(filename)
            print("Removed {}".format(filename))

        # Load sorted trials. Pickles must be read in binary mode (text mode
        # fails on Python 3 and corrupts data on Windows). Only load trusted
        # files: unpickling can execute arbitrary code. The contents are not
        # used here (plot_unit re-reads the file); this load fails early if
        # the file is missing or malformed.
        sortedfile = get_sortedfile_stim_onset(p)
        with open(sortedfile, 'rb') as f:
            t, sorted_trials = pickle.load(f)

        rnn = RNN(p['savefile'], {'dt': p['dt']}, verbose=True)
        trial_func = p['model'].generate_trial
        trial_args = {
            'name': 'test',
            'catch': False,
        }
        info = rnn.run(inputs=(trial_func, trial_args), seed=p['seed'])

        t_stimulus = np.array(info['epochs']['stimulus'])
        stimulus_d = t_stimulus[1] - t_stimulus[0]

        dashes = [3.5, 1.5]

        for i in xrange(p['model'].N):
            # Check if the unit does anything
            # active = False
            # for r in sorted_trials.values():
            #     if is_active(r[i]):
            #         active = True
            #         break
            # if not active:
            #     continue

            fig = Figure()
            plot = fig.add()

            # -----------------------------------------------------------------------------
            # Plot
            # -----------------------------------------------------------------------------

            plot_unit(i, sortedfile, plot, tmin=lower_bon, tmax=higher_bon)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            props = {
                'prop': {
                    'size': 8
                },
                'handletextpad': 1.02,
                'labelspacing': 0.6
            }
            plot.legend(bbox_to_anchor=(0.18, 1), **props)

            # Stimulus onset (t = 0) and offset.
            plot.vline(0, color='0.2', linestyle='--', lw=1, dashes=dashes)
            plot.vline(stimulus_d,
                       color='0.2',
                       linestyle='--',
                       lw=1,
                       dashes=dashes)

            # Epochs
            plot.text(-np.mean((0, stimulus_d)),
                      plot.get_ylim()[1],
                      'forward',
                      ha='center',
                      va='center',
                      fontsize=7)
            plot.text(np.mean((0, stimulus_d)),
                      plot.get_ylim()[1],
                      'stimulus',
                      ha='center',
                      va='center',
                      fontsize=7)
            plot.text(3 * np.mean((0, stimulus_d)),
                      plot.get_ylim()[1],
                      'reversal',
                      ha='center',
                      va='center',
                      fontsize=7)

            # -----------------------------------------------------------------------------

            fig.save(path=unitpath,
                     name=p['name'] + '_stim_onset_unit{:03d}'.format(i))
            fig.close()

    #-------------------------------------------------------------------------------------
    # Selectivity
    #-------------------------------------------------------------------------------------

    elif action == 'selectivity':

        lower = _parse_arg(args, 0, float, None)
        higher = _parse_arg(args, 1, float, None)

        # Model
        m = p['model']

        trialsfile = get_trialsfile(p)
        dprime = get_choice_selectivity(trialsfile,
                                        lower_bon=lower,
                                        higher_bon=higher)

        def get_first(x, p):
            # Leading fraction `p` of x (truncated).
            return x[:int(p * len(x))]

        psig = 0.25
        units = np.arange(len(dprime))
        try:
            # Top |d'| units, selected separately within EXC and INH.
            idx = np.argsort(abs(dprime[m.EXC]))[::-1]
            exc = get_first(units[m.EXC][idx], psig)

            idx = np.argsort(abs(dprime[m.INH]))[::-1]
            inh = get_first(units[m.INH][idx], psig)

            # Within each group, order by signed d' (descending).
            idx = np.argsort(dprime[exc])[::-1]
            units_exc = list(exc[idx])

            idx = np.argsort(dprime[inh])[::-1]
            units_inh = list(inh[idx])

            units = units_exc + units_inh
            dprime = dprime[units]
        except AttributeError:
            # Model has no EXC/INH attributes: rank all units together.
            idx = np.argsort(abs(dprime))[::-1]
            top = get_first(units[idx], psig)

            idx = np.argsort(dprime[top])[::-1]
            units = list(units[top][idx])
            dprime = dprime[units]

        # Save d'
        filename = get_dprimefile(p)
        np.savetxt(filename, dprime)
        print("[ {}.do ] d\' saved to {}".format(THIS, filename))

        # Save selectivity
        filename = get_selectivityfile(p)
        np.savetxt(filename, units, fmt='%d')
        print("[ {}.do ] Choice selectivity saved to {}".format(
            THIS, filename))

    #-------------------------------------------------------------------------------------

    else:
        print("[ {}.do ] Unrecognized action.".format(THIS))
Esempio n. 5
0
    return join(p['trialspath'], p['name'] + '_sorted.pkl')


def get_choice(trial):
    """Return the index of the largest output at the final time step.

    ``trial['z']`` is indexed as ``z[:, -1]``, so each row is an output
    channel and the last column is the final time step.
    """
    final_outputs = trial['z'][:, -1]
    return np.argmax(final_outputs)


def is_active(r):
    """Return True when the trace `r` varies enough to count as active.

    A unit is "active" if the standard deviation of its trace exceeds 0.05.
    """
    spread = np.std(r)
    return spread > 0.05


# Colors keyed by condition: 'v', 'a' and the combined 'va'
# (presumably visual, auditory, and both -- confirm).
colors = dict(
    v=Figure.colors('blue'),
    a=Figure.colors('green'),
    va=Figure.colors('orange'),
)

#=========================================================================================


def run_trials(p, args):
    """
    Run trials.

    """
    # Model
    m = p['model']
Esempio n. 6
0
plot.xlabel('Time from decision (ms)')

#-----------------------------------------------------------------------------------------
# Variable stimulus duration
#-----------------------------------------------------------------------------------------

# Draw constant blue/red line segments over the fixation and decision epochs
# in panel 'A'.
plot = plots['A']

# Epoch boundaries (ms).
t_fixation = np.array([0, 200])
t_stimulus = np.array([200, 800])
t_decision = np.array([800, 1000])

# Line levels. `shift` is defined earlier in the script (not visible here);
# the blue and red traces are drawn at lo + shift and lo - shift respectively,
# presumably to keep them visually separate.
hi = 1
lo = 0.1
# Fixation epoch: blue at lo + shift, red at lo - shift.
plot.plot(t_fixation, (lo + shift) * np.ones_like(t_fixation),
          color=Figure.colors('blue'),
          lw=2)
plot.plot(t_fixation, (lo - shift) * np.ones_like(t_fixation),
          color=Figure.colors('red'),
          lw=2)
# Decision epoch: blue steps up to `hi`, red stays at lo - shift.
plot.plot(t_decision,
          hi * np.ones_like(t_decision),
          color=Figure.colors('blue'),
          lw=2)
plot.plot(t_decision, (lo - shift) * np.ones_like(t_decision),
          color=Figure.colors('red'),
          lw=2)

# NOTE(review): t_stimulus is defined but not drawn in this fragment;
# presumably used later -- confirm.
plot.xlim(0, t_decision[-1])
# NOTE(review): xticks() is called with no arguments; its effect depends on
# the project's plot wrapper -- confirm it is intended.
plot.xticks()
Esempio n. 7
0
    'catch':  False,
    'coh':    16,
    'in_out': 1
    }
# Run one trial with a fixed seed; trial_func and trial_args are set before
# this point in the script.
info = rnn.run(inputs=(trial_func, trial_args), seed=10)

# Colour for channel i is colors[i] (two channels).
colors = ['orange', 'purple']

# Plot every DT-th sample to thin the traces.
DT = 15

# Inputs: one small figure per input channel, saved as fig1a_input1, fig1a_input2.
for i, clr in enumerate(colors):
    fig  = Figure(w=2, h=1)
    plot = fig.add(None, 'none')

    plot.plot(rnn.t[::DT], rnn.u[i][::DT], color=Figure.colors(clr), lw=1)
    plot.ylim(0, 1.5)

    fig.save(path=figspath, name='fig1a_input{}'.format(i+1))
    fig.close()

# Outputs: same layout per output channel, saved as fig1a_output1, fig1a_output2.
for i, clr in enumerate(colors):
    fig  = Figure(w=2, h=1)
    plot = fig.add(None, 'none')

    plot.plot(rnn.t[::DT], rnn.z[i][::DT], color=Figure.colors(clr), lw=1)
    plot.ylim(0, 1.5)

    fig.save(path=figspath, name='fig1a_output{}'.format(i+1))
    fig.close()
Esempio n. 8
0
def _choice_linestyle(choice):
    """Solid line for choice +1, dashed for any other choice."""
    if choice == +1:
        return '-'
    return '--'


def _coh_shade(basecolor, rank):
    """Return `basecolor` with alpha 0.4 / 0.7 / 1 for |coh| rank 0 / 1 / higher."""
    if rank == 0:
        return apply_alpha(basecolor, 0.4)
    if rank == 1:
        return apply_alpha(basecolor, 0.7)
    return apply_alpha(basecolor, 1)


def _set_time_axis(plot, t, w):
    """Limit x to the plotted window and tick its two ends and t = 0."""
    plot.xlim(t[w][0], t[w][-1])
    plot.xticks([t[w][0], 0, t[w][-1]])


def _plot_coh_choice(plot, condition_averaged, context, basecolor, unit, t, w,
                     lw, yall):
    """Plot `unit` for every (choice, coh, context) condition in `context`,
    shaded by |coh| rank and styled by choice; append each trace to `yall`."""
    # |coh| ranks are computed over all contexts in the dict.
    abscohs = sorted(set(abs(coh) for (_, coh, _) in condition_averaged))

    for (choice, coh, ctx), r in condition_averaged.items():
        if ctx != context:
            continue
        color = _coh_shade(basecolor, abscohs.index(abs(coh)))
        plot.plot(t[w], r[unit, w], _choice_linestyle(choice), color=color,
                  lw=lw)
        yall.append(r[unit, w])
    _set_time_axis(plot, t, w)


def plot_unit(unit,
              sortedfile,
              plots,
              t0=0,
              tmin=-np.inf,
              tmax=np.inf,
              **kwargs):
    """
    Plot one unit's condition-averaged activity in four panels.

    `sortedfile` is a pickle holding ``(t, sorted_trials)`` where
    ``sorted_trials`` maps 'choice', 'motion_choice', 'colour_choice' and
    'context_choice' to dicts of condition tuple -> array indexed
    ``[unit, time]``. Only the file from a trusted source should be passed:
    unpickling can execute arbitrary code.

    Parameters
    ----------
    unit : int
    sortedfile : str
    plots : dict
        Panels keyed 'choice', 'motion_choice', 'colour_choice',
        'context_choice'.
    t0 : number
        Times are shifted by -t0 after selecting the window.
    tmin, tmax : number
        Window (in unshifted time) to plot.
    kwargs : 'unit_fontsize' (default 7), 'sortby_fontsize' (if given, panel
        y-labels are added at this size), 'lw' (default 1).

    Returns
    -------
    list of arrays
        Every plotted trace, e.g. for setting shared y-limits.
    """
    # Load sorted trials
    with open(sortedfile, 'rb') as f:
        t, sorted_trials = pickle.load(f)

    #-------------------------------------------------------------------------------------
    # Labels
    #-------------------------------------------------------------------------------------

    # Unit no.
    fontsize = kwargs.get('unit_fontsize', 7)
    plots['choice'].text_upper_center('Unit ' + str(unit),
                                      dy=0.07,
                                      fontsize=fontsize)

    # Sort-by labels, drawn at the requested size (previously the size was
    # read but never applied). Raw strings keep the LaTeX '\&' without an
    # invalid-escape warning.
    if kwargs.get('sortby_fontsize') is not None:
        fontsize = kwargs['sortby_fontsize']
        labels = {
            'choice': 'choice',
            'motion_choice': r'motion \& choice',
            'colour_choice': r'color \& choice',
            'context_choice': r'context \& choice'
        }
        for k, label in labels.items():
            plots[k].ylabel(label, fontsize=fontsize)

    #-------------------------------------------------------------------------------------
    # Setup
    #-------------------------------------------------------------------------------------

    # Duration to plot (selected before the t0 shift)
    w, = np.where((tmin <= t) & (t <= tmax))
    t = t - t0

    lw = kwargs.get('lw', 1)

    # For setting axis limits
    yall = []

    #-------------------------------------------------------------------------------------
    # Choice
    #-------------------------------------------------------------------------------------

    plot = plots['choice']
    for (choice, ), r in sorted_trials['choice'].items():
        plot.plot(t[w], r[unit, w], _choice_linestyle(choice),
                  color=Figure.colors('red'), lw=lw)
        yall.append(r[unit, w])
    _set_time_axis(plot, t, w)

    #-------------------------------------------------------------------------------------
    # Motion & choice / Colour & choice
    #-------------------------------------------------------------------------------------

    _plot_coh_choice(plots['motion_choice'], sorted_trials['motion_choice'],
                     'm', 'k', unit, t, w, lw, yall)
    _plot_coh_choice(plots['colour_choice'], sorted_trials['colour_choice'],
                     'c', Figure.colors('darkblue'), unit, t, w, lw, yall)

    #-------------------------------------------------------------------------------------
    # Context & choice
    #-------------------------------------------------------------------------------------

    plot = plots['context_choice']
    for (choice, context), r in sorted_trials['context_choice'].items():
        if context == 'm':
            color = 'k'
        else:
            color = Figure.colors('darkblue')
        plot.plot(t[w], r[unit, w], _choice_linestyle(choice), color=color,
                  lw=lw)
        yall.append(r[unit, w])
    _set_time_axis(plot, t, w)

    return yall
Esempio n. 9
0
        savefile_init = base + '_init' + ext
        rnn = RNN(savefile_init, {'dt': dt}, verbose=True)
    else:
        rnn = RNN(savefile, {'dt': dt}, verbose=True)
    rnn.run(3e3, seed=seed)

    # Summary
    mean = np.mean(rnn.z)
    std = np.std(rnn.z)
    print("Mean output: {:.6f}".format(mean))
    print("Std. output: {:.6f}".format(std))

    fig = Figure()
    plot = fig.add()

    colors = [Figure.colors('blue'), Figure.colors('orange')]
    for i in xrange(rnn.z.shape[0]):
        plot.plot(1e-3 * rnn.t, rnn.z[i], color=colors[i % len(colors)])
        mean = np.mean(rnn.z[i]) * np.ones_like(rnn.t)
        plot.plot(1e-3 * rnn.t, mean, color=colors[i % len(colors)])
    plot.xlim(1e-3 * rnn.t[0], 1e-3 * rnn.t[-1])
    plot.lim('y', np.ravel(rnn.z), lower=0)

    plot.xlabel('Time (sec)')
    plot.ylabel('Outputs')

    fig.save(path=figspath, name=name + '_' + action)
    fig.close()

#=========================================================================================
# Plot network structure
Esempio n. 10
0
# Horizontal positions of panel labels A, B, C, spaced dx2 apart
# (dx2 is defined earlier in the script).
x0 = 0.02
x1 = x0 + dx2
x2 = x1 + dx2
# NOTE(review): x_mask is not used in this fragment; presumably used later.
x_mask = 0.525

# Shared vertical position of the panel labels.
y = 0.93
plotlabels = {'A': (x0, y), 'B': (x1, y), 'C': (x2, y)}
fig.plotlabels(plotlabels, fontsize=paper.plotlabelsize)

#=========================================================================================
# Create color maps for weights
#=========================================================================================

# Colors
white = 'w'
blue = Figure.colors('strongblue')
red = Figure.colors('strongred')


def generate_cmap(Ws):
    exc = []
    inh = []
    for W in Ws:
        exc.append(np.ravel(W[np.where(W > 0)]))
        inh.append(-np.ravel(W[np.where(W < 0)]))

    exc = np.sort(np.concatenate(exc))
    inh = np.sort(np.concatenate(inh))

    exc_ignore = int(0.1 * len(exc))
    inh_ignore = int(0.1 * len(inh))
Esempio n. 11
0
                       dy=0.06, fontsize=7)
plot.xaxis.set_label_position('top')
plot.xlabel('From', labelpad=4)
plot.ylabel('To', labelpad=4)

plot = plots['Brec']

# NOTE(review): `units` is not used in this fragment and `label` is
# overwritten by the loop below; possibly leftovers -- confirm.
units  = m.EXC_SENSORY
label  = '"Sensory" area (exc)'

# Annotate each unit group with a coloured bar and label above panel 'Brec'.
# colors is doubled so the exc and inh groups of the same area share a colour
# (sensory -> green, motor -> orange).
groups  = [m.EXC_SENSORY, m.EXC_MOTOR, m.INH_SENSORY, m.INH_MOTOR]
labels  = [r"``Sensory'' area ($\text{E}_\text{S}$)",
           r"``Motor'' area ($\text{E}_\text{M}$)",
           r"($\text{I}_\text{S}$)",
           r"($\text{I}_\text{M}$)"]
colors  = [Figure.colors('green'), Figure.colors('orange')]
colors += colors
for group, label, color in zip(groups, labels, colors):
    # Span of the group as a fraction of m.N units (first/last index + 0.5);
    # assumes units are laid out in index order along the axis -- confirm.
    extent = (np.array([group[0], group[-1]]) + 0.5)/m.N
    plot.plot(extent, 1.03*np.ones(2), lw=2, color=color, transform=plot.transAxes)
    plot.text(np.mean(extent), 1.05, label, ha='center', va='bottom',
              fontsize=6, color=color, transform=plot.transAxes)

    # Vertical bar (y flipped via 1 - extent) drawn on panel 'Bin'.
    # NOTE(review): uses plot.transAxes (the 'Brec' axes) as the transform,
    # so the bar is positioned in Brec's axes coordinates -- confirm intended.
    plots['Bin'].plot(1.2*np.ones(2), 1-extent, lw=2, color=color,
                      transform=plot.transAxes)

# Centre unit index of the excitatory sensory and motor groups
# (presumably used by code after this fragment).
E_S = m.EXC_SENSORY
E_M = m.EXC_MOTOR
x_S = np.mean([E_S[0], E_S[-1]])
x_M = np.mean([E_M[0], E_M[-1]])
Esempio n. 12
0
plot.ylabel('Min. percent correct')

# The last model's panel gets its own y-label.
plot = plots[models[-1][0]]
plot.ylabel('Error in eye position')

# Title each panel with its description, falling back to the model name.
for k in xrange(len(models)):
    model, desc = models[k]
    if desc is None:
        desc = model
    plots[model].text_upper_center(desc, dy=0.05, fontsize=6.5)

#=========================================================================================
# Plot performance
#=========================================================================================

clr_target = Figure.colors('red')
clr_actual = '0.2'  # dark grey
clr_seeds = '0.8'   # light grey

for model, _ in models:
    plot = plots[model]

    # Skip models whose saved network cannot be loaded (loading exits with
    # SystemExit).
    try:
        rnn = RNN(get_savefile(model), verbose=True)
    except SystemExit:
        continue

    xall = []

    # First field of each costs_history entry, as int (presumably the number
    # of trials at that point -- confirm).
    ntrials = [int(costs[0]) for costs in rnn.costs_history]
    # NOTE(review): under Python 2 this is integer division (truncated units
    # of 1e4); under Python 3 it yields floats -- confirm which is intended.
    ntrials = np.asarray(ntrials, dtype=int) / int(1e4)
Esempio n. 13
0
# Stimulus durations
# Epoch boundaries from `info`, returned by an earlier rnn.run call.
epochs = info['epochs']
f1_start, f1_end = epochs['f1']
f2_start, f2_end = epochs['f2']
# Times are aligned to f1 onset.
t0   = f1_start
# NOTE(review): tmin and tmax are not used in this fragment; presumably used later.
tmin = 0
tmax = f2_end

# Time relative to f1 onset, scaled by 1e-3 (presumably ms -> s).
t     = 1e-3*(rnn.t-t0)
# Delay period [end of f1, start of f2] on the same scale as t.
delay = [1e-3*(f1_end-t0), 1e-3*(f2_start-t0)]
# Input traces collected here (presumably for shared axis limits later).
yall  = []

# f1 > f2
plot = plots['>']
plot.plot(t, rnn.u[0], color=Figure.colors('orange'), lw=0.5)
yall.append(rnn.u[0])
plot.xticklabels()

# Second trial with gt_lt='<' for the f1 < f2 panel.
trial_args = {
    'name':  'test',
    'catch': False,
    'fpair': (34, 26),
    'gt_lt': '<'
    }
info = rnn.run(inputs=(trial_func, trial_args), rng=rng)

# f1 < f2
# NOTE(review): reuses `t` computed from the first trial; assumes both runs
# share the same time grid -- confirm.
plot = plots['<']
plot.plot(t, rnn.u[0], color=Figure.colors('purple'), lw=0.5)
yall.append(rnn.u[0])
Esempio n. 14
0
def tuning_corr(trialsfile,
                sortedfile,
                plot_sig,
                plot_corr=None,
                plot_stim=None,
                plot_delay=None,
                t0=0,
                **kwargs):
    """
    Plot correlation of a1 between different times.

    Per-unit a1 values and p-values come from ``tuning(sortedfile)`` and are
    indexed as ``[unit, time]``. Epoch times come from the first trial in
    `trialsfile`. Plotted times are ``1e-3 * (t - t0)``.

    Parameters
    ----------
    plot_sig : plot or None
        If given, the fraction of units with p < 0.05 at each time from the
        second sample up to the end of f2.
    plot_corr : plot or None
        If given, the Pearson correlation across units between a1 at each time
        (middle of f1 to end of delay) and a1 at the middle of f1 (blue) or at
        the middle of the delay (green).
    plot_stim, plot_delay : plot or None
        If given, one point per unit: a1 at the end of the delay against a1 at
        the middle of f1 (plot_stim) or of the delay (plot_delay).
    t0 : number
        Time offset, in the same units as the trial times.
    kwargs : 'lw' (line width, default 1), 'ms' (marker size, default 2).
    """
    units, a1s, pvals = tuning(sortedfile)

    # Get first trial
    trials, _ = load_trials(trialsfile)
    trial = trials[0]
    info = trial['info']
    t = trial['t']
    epochs = info['epochs']

    _, delay_end = epochs['delay']
    delay_end = 1e-3 * (delay_end - t0)

    # First sample at or after the middle of f1 and of the delay.
    idx_stim = np.where(t >= np.mean(epochs['f1']))[0][0]
    idx_delay = np.where(t >= np.mean(epochs['delay']))[0][0]

    # First sample strictly after the end of the delay and of f2.
    idx_delay_end = np.where(t > epochs['delay'][1])[0][0]
    idx_f2_end = np.where(t > epochs['f2'][1])[0][0]

    t_all = 1e-3 * (t[idx_stim:idx_delay_end] - t0)
    idx_all = np.arange(len(t))[idx_stim:idx_delay_end]

    lw = kwargs.get('lw', 1)

    def corr_with(idx_ref):
        # Pearson r across units between a1 at idx_ref and a1 at each idx_all.
        return [stats.pearsonr(a1s[:, idx_ref], a1s[:, k])[0]
                for k in idx_all]

    # Plot correlation across time
    if plot_corr is not None:
        plot = plot_corr

        # With the middle of the stimulus (f1) period
        plot.plot(t_all,
                  corr_with(idx_stim),
                  color=Figure.colors('blue'),
                  lw=lw)

        # With the middle of the delay period
        plot.plot(t_all,
                  corr_with(idx_delay),
                  color=Figure.colors('green'),
                  lw=lw)

        plot.xlim(-1e-3 * t0, delay_end)
        plot.ylim(-1, 1)

    # Plot fraction of significantly tuned units.
    if plot_sig is not None:
        plot = plot_sig

        # float(): under Python 2 an int-array / int division truncates the
        # fraction to 0 (or 1).
        psig = np.sum(pvals < 0.05, axis=0) / float(len(units))
        plot.plot(1e-3 * (t[1:idx_f2_end] - t0),
                  psig[1:idx_f2_end],
                  color='0.2',
                  lw=lw)

        plot.xlim(1e-3 * (t[0] - t0), 1e-3 * (t[idx_f2_end - 1] - t0))
        plot.ylim(0, 1)

    # Shared plot properties
    prop = {'mfc': '0.2', 'mec': 'none', 'ms': kwargs.get('ms', 2)}

    # Plot a1, end of delay vs. stimulus
    if plot_stim is not None:
        plot = plot_stim
        plot.equal()

        for i in xrange(len(units)):
            plot.plot(a1s[i, idx_stim], a1s[i, idx_delay_end], 'o', **prop)

    # Plot a1, end of delay vs. middle of delay
    if plot_delay is not None:
        plot = plot_delay
        plot.equal()

        for i in xrange(len(units)):
            plot.plot(a1s[i, idx_delay], a1s[i, idx_delay_end], 'o', **prop)
Esempio n. 15
0
    'labelspacing': 0.5
}
# `prop` is the legend-properties dict defined just before this point.
plots['psy_m'].legend(bbox_to_anchor=(0.58, 1.12), **prop)

#=========================================================================================
# State space
#=========================================================================================

# Panels handed to mante.plot_statespace: 'm1'-'m3' (titled "Motion context")
# and 'c1'-'c3' (titled "Color context").
mante_plots = {s: plots[s] for s in ['m1', 'm2', 'm3', 'c1', 'c2', 'c3']}
mante.plot_statespace(trialsfile, sortedfile, betafile, mante_plots)

plots['m2'].text_upper_center('Motion context', dy=0.08, fontsize=7, color='k')
plots['c2'].text_upper_center('Color context',
                              dy=0.08,
                              fontsize=7,
                              color=Figure.colors('darkblue'))

#-----------------------------------------------------------------------------------------
# Legend
#-----------------------------------------------------------------------------------------

# Marker sizes and edge widths for filled vs. empty legend markers.
ms_filled = 2.5
ms_empty = 2.5

mew_filled = 0.5
mew_empty = 0.5

# Legend layout constants (presumably positions/spacings used by the legend
# code that follows this fragment).
y = 1.2
dx = 0.08
dy = 0.06
Esempio n. 16
0
def psychometric_function(trialsfile, plots=None, **kwargs):
    """
    Psychometric function.

    Load the trials in `trialsfile` and tabulate each trial's choice within
    its context, once by signed motion coherence and once by signed colour
    coherence. Optionally plot the resulting curves with psychometric fits.

    Parameters
    ----------
    trialsfile : str
        Trials file, read with `load_trials`.
    plots : dict, optional
        If given, must map 'm' and 'c' to plot objects. 'm' receives the
        curves against motion coherence and 'c' those against colour
        coherence. Both contexts are drawn on each plot.
    **kwargs
        ms : marker size of the data points (default 5).

    Returns
    -------
    results : dict
        Maps 'mm', 'mc', 'cm', 'cc' to ``(scaled_cohs, p0)``.
        The first letter of the key is the context: 'm' when
        ``info['context'] == 'm'``, otherwise 'c'. The second letter is the
        stimulus dimension the coherence refers to.
        ``scaled_cohs`` is ``SCALE`` times the sorted signed coherences.
        ``p0[i]`` is ``1 - mean(choices)`` over the trials at that coherence.
        That is the fraction of choice 0, assuming choices are coded 0/1 --
        TODO confirm against `get_choice`.

    """
    # Load trials
    trials, ntrials = load_trials(trialsfile)

    #-------------------------------------------------------------------------------------
    # Compute psychometric function
    #-------------------------------------------------------------------------------------

    # choices[cond][signed coherence] -> list of choices made at that coherence
    choices = {cond: {} for cond in ['mm', 'mc', 'cm', 'cc']}
    ncorrect = 0
    for trial in trials:
        info = trial['info']
        coh_m = info['left_right_m'] * info['coh_m']
        coh_c = info['left_right_c'] * info['coh_c']
        choice = get_choice(trial)

        if choice == info['choice']:
            ncorrect += 1

        # Each trial is tabulated under both stimulus dimensions of its context.
        context = 'm' if info['context'] == 'm' else 'c'
        choices[context + 'm'].setdefault(coh_m, []).append(choice)
        choices[context + 'c'].setdefault(coh_c, []).append(choice)

    # 100.0 keeps the percentage from being truncated by Python 2 integer
    # division. An empty trials file reports nan instead of raising
    # ZeroDivisionError.
    pct_correct = 100.0 * ncorrect / ntrials if ntrials else float('nan')
    print("[ {}.psychometric_function ] {:.2f}% correct.".format(
        THIS, pct_correct))

    results = {}
    for cond, choice_by_coh in choices.items():
        # sorted() works on dict keys under both Python 2 and 3. In contrast,
        # np.array(d.keys()) is a 0-d object array on Python 3.
        cohs = sorted(choice_by_coh)
        # np.mean always yields a float, so p0 is never truncated to 0/1 by
        # integer division.
        p0 = np.array([1 - np.mean(choice_by_coh[coh]) for coh in cohs],
                      dtype=float)
        results[cond] = (SCALE * np.array(cohs), p0)

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    if plots is not None:
        ms = kwargs.get('ms', 5)
        color_m = '0.2'
        color_c = Figure.colors('darkblue')

        for cond, (scaled_cohs, p0) in results.items():
            # A context with no trials has nothing to draw. Skipping it also
            # avoids min()/max() raising on an empty array below.
            if len(scaled_cohs) == 0:
                continue

            # Context
            if cond[0] == 'm':
                color, label = color_m, 'Motion context'
            else:
                color, label = color_c, 'Color context'

            # Stimulus dimension the coherence axis refers to
            plot = plots['m'] if cond[1] == 'm' else plots['c']

            # Data points
            plot.plot(scaled_cohs,
                      100 * p0,
                      'o',
                      ms=ms,
                      mew=0,
                      mfc=color,
                      zorder=10)

            # Fit
            try:
                popt, func = fittools.fit_psychometric(scaled_cohs, p0)

                fit_cohs = np.linspace(min(scaled_cohs), max(scaled_cohs), 201)
                fit_p0 = func(fit_cohs, **popt)
                plot.plot(fit_cohs,
                          100 * fit_p0,
                          color=color,
                          lw=1,
                          zorder=5,
                          label=label)
            except RuntimeError:
                print("[ {}.psychometric_function ]".format(THIS) +
                      " Unable to fit, drawing a line through the points.")
                plot.plot(scaled_cohs,
                          100 * p0,
                          color=color,
                          lw=1,
                          zorder=5,
                          label=label)

            plot.lim('x', scaled_cohs)
            plot.ylim(0, 100)

    #-------------------------------------------------------------------------------------

    return results
Esempio n. 17
0
# Run the trained Romo network on a single test trial (frequency pair (34, 26),
# gt_lt '>') and plot rnn.u[0:2] and rnn.z[0:2] against time measured from f1
# onset. The factor 1e-3 presumably converts ms to s.
rng = np.random.RandomState(1066)
rnn = RNN('romo_savefile.pkl', {'dt': 2})

trial_args = {
    'name': 'test',
    'catch': False,
    'fpair': (34, 26),
    'gt_lt': '>',
}
info = rnn.run(inputs=(m.generate_trial, trial_args), rng=rng)

fig = Figure()
plot = fig.add()

# Epoch boundaries of this trial; time zero is the start of f1.
epochs = info['epochs']
(f1_start, f1_end), (f2_start, f2_end) = epochs['f1'], epochs['f2']
t0 = f1_start
tmin, tmax = 0, f2_end

t = 1e-3 * (rnn.t - t0)
delay = [1e-3 * (edge - t0) for edge in (f1_end, f2_start)]
yall = []

# Two rows of rnn.u followed by two rows of rnn.z, drawn in this order.
for series, colour in ((rnn.u[0], 'orange'),
                       (rnn.u[1], 'blue'),
                       (rnn.z[0], 'red'),
                       (rnn.z[1], 'green')):
    plot.plot(t, series, color=Figure.colors(colour), lw=0.5)

plot.xlabel('time')
plot.ylabel('output')

fig.save(path='.', name='Romo_Figure')
Esempio n. 18
0
def plot_statespace(trialsfile, sortedfile, betafile, plots):
    """
    Plot condition-averaged population responses projected onto task axes.

    Draws a 2 x 3 grid of panels. The first row ('m1'-'m3') uses conditions
    from the motion context and the second row ('c1'-'c3') conditions from
    the colour context. All panels then share aspect ratio and x/y limits.

    Parameters
    ----------
    trialsfile : str
        Trials file. It supplies the stimulus epoch of the first trial and,
        via `get_active_units`, the active units.
    sortedfile : str
        Pickle containing ``(t, sorted_trials)``. ``sorted_trials`` has
        'motion_choice' and 'colour_choice' entries. Each maps condition
        tuples, whose third element is the context ('m' or 'c'), to
        responses indexed as ``r[unit, time]``.
    betafile : str
        Pickle whose transpose is the projection matrix ``M``.
    plots : dict
        Must contain 'm1', 'm2', 'm3', 'c1', 'c2', 'c3'. Every value in the
        dict receives the shared aspect ratio and limits.

    """
    # Load trials
    trials, _ = load_trials(trialsfile)

    # Load sorted trials and task axes. Pickles are binary data, so open the
    # files in 'rb'; text mode corrupts them on Windows and fails on Python 3.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(sortedfile, 'rb') as f:
        t, sorted_trials = pickle.load(f)
    with open(betafile, 'rb') as f:
        M = pickle.load(f).T

    # Active units
    units = get_active_units(trialsfile)

    # Epoch to plot: the stimulus period of the first trial
    start, end = trials[0]['info']['epochs']['stimulus']
    w, = np.where((start <= t) & (t <= end))

    # Down-sample to about one point per 50 units of t. The step is clamped to
    # at least 1, because a step of 0 would make the slice raise ValueError.
    dt = t[1] - t[0]
    step = max(1, int(50 / dt))
    w = w[::step]

    # Colors
    color_m = 'k'
    color_c = Figure.colors('darkblue')

    # Panel layout: (panel, condition sorting, context, task axis, colour)
    #   m1: motion context, motion axis, sorted by motion
    #   m2: motion context, colour axis, sorted by motion
    #   m3: motion context, colour axis, sorted by colour
    #   c1: colour context, motion axis, sorted by motion
    #   c2: colour context, motion axis, sorted by colour
    #   c3: colour context, colour axis, sorted by colour
    panels = [
        ('m1', 'motion_choice', 'm', MOTION, color_m),
        ('m2', 'motion_choice', 'm', COLOUR, color_m),
        ('m3', 'colour_choice', 'm', COLOUR, color_c),
        ('c1', 'motion_choice', 'c', MOTION, color_m),
        ('c2', 'colour_choice', 'c', MOTION, color_c),
        ('c3', 'colour_choice', 'c', COLOUR, color_c),
    ]

    #-------------------------------------------------------------------------------------
    # Labels
    #-------------------------------------------------------------------------------------

    plots['c1'].xlabel('Choice')

    #-------------------------------------------------------------------------------------
    # Project each panel's conditions onto the task axes and plot them
    #-------------------------------------------------------------------------------------

    xall = []
    yall = []
    for name, sorting, context, axis, color in panels:
        p_vc = {cond: M.dot(r[units, :][:, w])
                for cond, r in sorted_trials[sorting].items()
                if cond[2] == context}
        x, y = plot_taskaxes(plots[name], axis, p_vc, color)
        xall.append(x)
        yall.append(y)

    plots['m1'].ylabel('Motion')

    #-------------------------------------------------------------------------------------
    # Shared axes
    #-------------------------------------------------------------------------------------

    xall = np.concatenate(xall)
    yall = np.concatenate(yall)

    for plot in plots.values():
        plot.aspect(1.5)
        plot.lim('x', xall)
        plot.lim('y', yall)
Esempio n. 19
0
                  mode=mode,
                  n_validation=n_validation,
                  min_error=min_error)
    model.train('savefile.pkl', seed=100, recover=False)

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    from pycog import RNN
    from pycog.figtools import Figure

    rnn = RNN('savefile.pkl', {'dt': 0.5, 'var_rec': 0.01**2})
    info = rnn.run(T=2 * period)

    fig = Figure()
    plot = fig.add()

    plot.plot(rnn.t / tau, rnn.z[0], color=Figure.colors('blue'))
    plot.xlim(rnn.t[0] / tau, rnn.t[-1] / tau)
    plot.ylim(0, 2)

    print rnn.t[0]
    print rnn.t[-1]
    plot.plot((rnn.t / tau)[:], (0.9 * np.power(rnn.t / (2 * period), 2))[:],
              color=Figure.colors('orange'))

    plot.xlabel(r'$t/\tau$')
    plot.ylabel('$\sin t$')

    fig.save(path='.', name='xSquared')
Esempio n. 20
0
plot = plots['SL-e']
plot_epochs(plot)

plot = plots['SL-o']
plot.axis_off('bottom')

# Target outputs on a grid of step dt from 0 to tmax.
# 'Left' (y_high) is 1 inside the decision epoch (decision[0], decision[1]]
# and 0 elsewhere, then shifted by `offset`.
# 'Right' (y_low) is 0 throughout.
x = np.linspace(0, tmax, int(tmax / dt) + 1)
y_low = np.zeros_like(x)
y_high = np.zeros_like(x)
in_decision = (x > decision[0]) & (x <= decision[1])
y_high[in_decision] = 1
y_high += offset

plot.plot(x, y_high, color=Figure.colors('blue'), lw=2, zorder=10,
          label='Left')
plot.plot(x,
          y_low,
          color=Figure.colors('red'),
          lw=2,
          zorder=9,
          label='Right')

plot.xlim(0, tmax)
plot.ylim(0, 1)
plot.yticks([0, 1])
plot.ylabel('Target outputs')

# Legend
props = {
    'prop': {
        'size': 7
    },
Esempio n. 21
0
if __name__ == '__main__':
    # Run the trained multi_sequence3mod network on one trial, plot its input,
    # outputs and mean unit rates, and save the figure as 'multi_sequence3mod'.
    from pycog          import RNN
    from pycog.figtools import Figure

    rnn  = RNN('work/data/multi_sequence3mod/multi_sequence3mod.pkl',
               {'dt': 0.5, 'var_rec': 0.01**2,
                'var_in': np.array([0.003**2])})
    trial_args = {}
    info = rnn.run(inputs=(generate_trial, trial_args), seed=7449)

    fig  = Figure()
    plot = fig.add()

    colors = ['red', 'green', 'yellow', 'orange', 'purple', 'cyan', 'magenta', 'pink']

    plot.plot(rnn.t/tau, rnn.u[0], color=Figure.colors('blue'))
    for i in range(Nout):
        # Outputs are coloured in groups of NoutSplit. Wrap around the palette
        # instead of raising IndexError when there are more groups than colours.
        group = int(i / NoutSplit)
        plot.plot(rnn.t/tau, rnn.z[i],
                  color=Figure.colors(colors[group % len(colors)]))
    plot.xlim(rnn.t[0]/tau, rnn.t[-1]/tau)
    plot.ylim(0, 15)

    # Mean rate of the first 80% of units and of the remaining units
    # (presumably the excitatory / inhibitory populations -- confirm against
    # the model definition).
    Nexc = int(N * 0.8)
    plot.plot(rnn.t/tau, rnn.r[:Nexc].mean(axis=0), color=Figure.colors("pink"))
    plot.plot(rnn.t/tau, rnn.r[Nexc:].mean(axis=0), color=Figure.colors("magenta"))

    plot.xlabel(r'$t/\tau$')
    # The y axis carries the input, the outputs and the mean rates. It
    # previously repeated the x-axis label r'$t/\tau$' by copy-paste.
    plot.ylabel('Activity')

    fig.save(path='.', name='multi_sequence3mod')

    fig  = Figure()
Esempio n. 22
0
    rnn = RNN('work/data/multi_sequence8mod/multi_sequence8mod.pkl', {
        'dt': 0.5,
        'var_rec': 0.01**2,
        'var_in': np.array([0.003**2])
    })
    trial_args = {}
    info = rnn.run(inputs=(generate_trial, trial_args), seed=7423)

    fig = Figure()
    plot = fig.add()

    colors = [
        'red', 'green', 'yellow', 'orange', 'purple', 'cyan', 'magenta', 'pink'
    ]

    plot.plot(rnn.t / tau, rnn.u[0], color=Figure.colors('blue'))
    for i in range(Nout):
        plot.plot(rnn.t / tau,
                  rnn.z[i],
                  color=Figure.colors(colors[int(i / NoutSplit)]))

    Nexc = int(N * 0.8)
    plot.plot(rnn.t / tau,
              rnn.r[:Nexc].mean(axis=0),
              color=Figure.colors("pink"))
    plot.plot(rnn.t / tau,
              rnn.r[Nexc:].mean(axis=0),
              color=Figure.colors("magenta"))

    plot.xlim(rnn.t[0] / tau, rnn.t[-1] / tau)
    plot.ylim(0, 15)
Esempio n. 23
0
if __name__ == '__main__':
    from pycog          import RNN
    from pycog.figtools import Figure

    rnn  = RNN('work/data/multi_sequence4mod/multi_sequence4mod.pkl', {'dt': 0.5, 'var_rec': 0.01**2,
        'var_in':  np.array([0.003**2])})
    trial_args = {}
    info = rnn.run(inputs=(generate_trial, trial_args), seed=72)

    fig  = Figure()
    plot = fig.add()

    colors = ['red', 'green', 'yellow', 'orange', 'purple', 'cyan', 'magenta', 'pink']

    plot.plot(rnn.t/tau, rnn.u[0], color=Figure.colors('blue'), label='$input$')
    for i in range(Nout):
        if i % NoutSplit == 0:
            k = {'label': '$output%d$'%int(i / NoutSplit)}
        else:
            k = {}
        plot.plot(rnn.t/tau, rnn.z[i], color=Figure.colors(colors[int(i / NoutSplit)]), **k)

    Nexc = int(N * 0.8)
    np.savetxt("rnnt.txt", rnn.t/tau)
    np.savetxt("r4.txt", np.divide(rnn.r[:Nexc].mean(axis=0), rnn.r[Nexc:].mean(axis=0)))

    plot.xlim(rnn.t[0]/tau, rnn.t[-1]/tau)
    plot.ylim(0, 15)

    prop = {'prop': {'size': 20}, 'handlelength': 1.2,