示例#1
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatch on `action`:

      * any action containing 'trials' -- build `trials_per_condition` trials
        for each condition of the model spec (conditions enumerated over
        (spec.mods, spec.freqs) via tasktools.unravel_index) and run them with
        runtools.run, saving to config['trialspath']. The per-condition count
        is read from args[0] and defaults to 500 when it is missing or not an
        integer.
      * 'psychometric' -- plot the psychometric curve from the behavior file,
        draw a vertical line at the model spec's boundary, and save the figure
        as 'psychometric' under config['figspath'].
      * 'sort' -- plot sorted per-unit activity; uses the value network when
        'value' is in `args`, otherwise the policy network.

    Any other action only prints the diagnostic header below.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #=====================================================================================

    if 'trials' in action:
        # Fall back to the default only for a missing or unparsable count; a
        # bare `except:` would also swallow KeyboardInterrupt and SystemExit.
        try:
            trials_per_condition = int(args[0])
        except (LookupError, TypeError, ValueError):
            trials_per_condition = 500

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        # Conditions
        spec = model.spec
        mods = spec.mods
        freqs = spec.freqs
        n_conditions = spec.n_conditions
        n_trials = n_conditions * trials_per_condition

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Map the flat trial index onto a (mod, freq) index pair
            k = tasktools.unravel_index(n, (len(mods), len(freqs)))
            context = {'mod': mods[k.pop(0)], 'freq': freqs[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'psychometric':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        psychometric(trialsfile, plot)

        # Mark the model's decision boundary
        plot.vline(config['model'].spec.boundary)

        fig.save(path=config['figspath'], name='psychometric')
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)
示例#2
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatch on `action`:

      * any action containing 'trials' -- build `trials_per_condition` trials
        for each condition of the model spec (conditions enumerated over
        (spec.mods, spec.freqs) via tasktools.unravel_index) and run them with
        runtools.run, saving to config['trialspath']. The per-condition count
        is read from args[0] and defaults to 500 when it is missing or not an
        integer.
      * 'psychometric' -- plot the psychometric curve from the behavior file,
        draw a vertical line at the model spec's boundary, and save the figure
        as 'psychometric' under config['figspath'].
      * 'sort' -- plot sorted per-unit activity; uses the value network when
        'value' is in `args`, otherwise the policy network.

    Any other action only prints the diagnostic header below.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #=====================================================================================

    if 'trials' in action:
        # Fall back to the default only for a missing or unparsable count; a
        # bare `except:` would also swallow KeyboardInterrupt and SystemExit.
        try:
            trials_per_condition = int(args[0])
        except (LookupError, TypeError, ValueError):
            trials_per_condition = 500

        model = config['model']
        pg    = model.get_pg(config['savefile'], config['seed'], config['dt'])

        # Conditions
        spec         = model.spec
        mods         = spec.mods
        freqs        = spec.freqs
        n_conditions = spec.n_conditions
        n_trials     = n_conditions * trials_per_condition

        print("{} trials".format(n_trials))
        task   = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Map the flat trial index onto a (mod, freq) index pair
            k       = tasktools.unravel_index(n, (len(mods), len(freqs)))
            context = {'mod': mods[k.pop(0)], 'freq': freqs[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'psychometric':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig  = Figure()
        plot = fig.add()

        psychometric(trialsfile, plot)

        # Mark the model's decision boundary
        plot.vline(config['model'].spec.boundary)

        fig.save(path=config['figspath'], name='psychometric')
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)
示例#3
0
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by (modality, choice) and plot trial-averaged firing rates.

    Only trials with a decision that were answered correctly and that have a
    single modality ('v' or 'a') are used; 'va' trials are skipped. Each trial
    is aligned to stimulus onset before averaging.

    Parameters
    ----------
    trialsfile : str
        File read with `utils.load`; must unpack into
        (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    plots : list or (str, str)
        If `units` is given, the plots to draw into (paired with `units`).
        Otherwise a (figspath, name) pair: one figure per unit is saved under
        figspath as `<name>_<network><unit:03d>`.
    units : list of int, optional
        Units to plot into `plots`.
    network : str
        'p' uses the policy-network rates r_p; anything else uses r_v.
    **kwargs
        lw : line width (default 1.5).
        dashes : dash pattern for 'H'-choice curves (default [3, 2]).

    """
    # Load trials
    data = utils.load(trialsfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    if network == 'p':
        r = r_p
    else:
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored copy of the time axis; index Ntime-1 holds time[0],
    # which is where each trial's align point is placed.
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Aligned to stimulus onset
    #=====================================================================================

    r_by_cond_stimulus = {}
    n_r_by_cond_stimulus = {}
    for n, trial in enumerate(trials):
        if not perf.decisions[n]:
            continue

        # Only single-modality trials are sorted
        if trial['mod'] == 'va':
            continue
        assert trial['mod'] == 'v' or trial['mod'] == 'a'

        if not perf.corrects[n]:
            continue

        # Condition
        mod = trial['mod']
        choice = perf.choices[n]
        cond = (mod, choice)

        # Storage: summed rates and number of contributing (masked) samples
        r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))
        n_r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))

        # Firing rates, multiplied by this trial's mask M[:, n]
        Mn = np.tile(M[:, n], (N, 1)).T
        Rn = r[:, n] * Mn

        # Align point
        t0 = trial['epochs']['stimulus'][0] - 1

        # Before: samples [0, t0) end just before index Ntime-1
        n_b = Rn[:t0].shape[0]
        r_by_cond_stimulus[cond][Ntime - 1 - n_b:Ntime - 1] += Rn[:t0]
        n_r_by_cond_stimulus[cond][Ntime - 1 - n_b:Ntime - 1] += Mn[:t0]

        # After: samples [t0, end) start at index Ntime-1
        n_a = Rn[t0:].shape[0]
        r_by_cond_stimulus[cond][Ntime - 1:Ntime - 1 + n_a] += Rn[t0:]
        n_r_by_cond_stimulus[cond][Ntime - 1:Ntime - 1 + n_a] += Mn[t0:]

    # Average: summed rates divided by sample counts (via utils.div)
    for cond in r_by_cond_stimulus:
        r_by_cond_stimulus[cond] = utils.div(r_by_cond_stimulus[cond],
                                             n_r_by_cond_stimulus[cond])

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    lw = kwargs.get('lw', 1.5)
    dashes = kwargs.get('dashes', [3, 2])

    # NOTE(review): vline_props is never passed to plot.vline below, so
    # 'lw_vline' / 'dashes_vline' currently have no effect; also, when
    # 'dashes_vline' is given, the curve `dashes` are used rather than its value.
    vline_props = {'lw': kwargs.get('lw_vline', 0.5)}
    if 'dashes_vline' in kwargs:
        vline_props['linestyle'] = '--'
        vline_props['dashes'] = dashes

    colors_by_mod = {'v': Figure.colors('blue'), 'a': Figure.colors('green')}
    linestyle_by_choice = {'L': '-', 'H': '--'}
    label_by_mod = {'v': 'Vis, ', 'a': 'Aud, '}
    label_by_choice = {'H': 'high', 'L': 'low'}

    def plot_sorted(plot, unit, w, r_sorted):
        """Plot one curve per (mod, choice) condition over time_a[w]; return (t, y-values)."""
        t = time_a[w]
        yall = [[1]]
        for cond in [('v', 'H'), ('v', 'L'), ('a', 'H'), ('a', 'L')]:
            # A condition may have no correct trials; skip it instead of
            # failing with a KeyError.
            if cond not in r_sorted:
                continue
            mod, choice = cond
            label = label_by_mod[mod] + label_by_choice[choice]

            linestyle = linestyle_by_choice[choice]
            if linestyle == '-':
                lineprops = dict(linestyle=linestyle, lw=lw)
            else:
                lineprops = dict(linestyle=linestyle, lw=lw, dashes=dashes)

            y = r_sorted[cond][w, unit]
            plot.plot(t, y, color=colors_by_mod[mod], label=label, **lineprops)
            yall.append(y)

        return t, yall

    def on_stimulus(plot, unit):
        """Plot `unit` from -300 to 1000 (time_a units) around stimulus onset."""
        w, = np.where((time_a >= -300) & (time_a <= 1000))
        t, yall = plot_sorted(plot, unit, w, r_by_cond_stimulus)

        plot.xlim(t[0], t[-1])

        return yall

    if units is not None:
        for plot, unit in zip(plots, units):
            on_stimulus(plot, unit)
    else:
        figspath, name = plots
        for unit in xrange(N):
            fig = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------

            yall = on_stimulus(plot, unit)

            plot.lim('y', yall, lower=0)
            plot.vline(0)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            #-----------------------------------------------------------------------------

            fig.save(path=figspath,
                     name=name + '_{}{:03d}'.format(network, unit))
            fig.close()
示例#4
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatch on `action`:

      * any action containing 'trials' -- build `trials_per_condition` trials
        for each (wager, left_right, coh) condition of the model spec and run
        them with runtools.run, saving to config['trialspath']. The
        per-condition count is read from args[0] and defaults to 100 when it
        is missing or not an integer.
      * 'sure_stimulus_duration', 'correct_stimulus_duration',
        'value_stimulus_duration' -- plot the corresponding quantity against
        stimulus duration and save the figure under config['figspath'] with
        the action as its name.
      * 'sort' -- plot sorted per-unit activity; uses the value network when
        'value' is in `args`, otherwise the policy network.

    Any other action only prints the diagnostic header below.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    def plot_vs_stimulus_duration(trialsfile, plotter, ylabel):
        """Draw `plotter(trialsfile, plot)` vs. stimulus duration, save as `action`, close."""
        fig  = Figure()
        plot = fig.add()

        plotter(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel(ylabel)

        fig.save(path=config['figspath'], name=action)
        # Close after saving so repeated invocations do not leave figures open
        fig.close()

    #=====================================================================================

    if 'trials' in action:
        # Fall back to the default only for a missing or unparsable count; a
        # bare `except:` would also swallow KeyboardInterrupt and SystemExit.
        try:
            trials_per_condition = int(args[0])
        except (LookupError, TypeError, ValueError):
            trials_per_condition = 100

        model = config['model']
        pg    = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec         = model.spec
        wagers       = spec.wagers
        left_rights  = spec.left_rights
        cohs         = spec.cohs
        n_conditions = spec.n_conditions
        n_trials     = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task   = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Map the flat trial index onto a (wager, left_right, coh) index triple
            k = tasktools.unravel_index(n, (len(wagers), len(left_rights), len(cohs)))
            context = {
                'wager':      wagers[k.pop(0)],
                'left_right': left_rights[k.pop(0)],
                'coh':        cohs[k.pop(0)]
                }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'sure_stimulus_duration':
        plot_vs_stimulus_duration(runtools.behaviorfile(config['trialspath']),
                                  sure_stimulus_duration,
                                  'Probability sure target')

    #=====================================================================================

    elif action == 'correct_stimulus_duration':
        plot_vs_stimulus_duration(runtools.behaviorfile(config['trialspath']),
                                  correct_stimulus_duration,
                                  'Probability correct')

    #=====================================================================================

    elif action == 'value_stimulus_duration':
        # Expected reward needs network activity, not just behavior
        plot_vs_stimulus_duration(runtools.activityfile(config['trialspath']),
                                  value_stimulus_duration,
                                  'Expected reward')

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, os.path.join(config['figspath'], 'sorted'), network=network)
示例#5
0
def sort(trialsfile, plots, unit=None, network='p', **kwargs):
    """
    Sort trials and plot trial-averaged firing rates for wager / no-wager trials.

    Trials with zero coherence or without a choice are excluded. Activity is
    grouped by trial['left_right'] and averaged after aligning each trial to an
    event:

      * no-wager trials: stimulus onset and choice time;
      * wager trials: stimulus onset, 'sure' epoch onset and choice time, with
        trials whose choice was 'S' averaged separately (drawn dashed).

    Parameters
    ----------
    trialsfile : str
        File read with `utils.load`; unpacks into 9 entries
        (trials, U, Z, A, P, M, perf, r_p, r_v) or 10 (with Z_b after Z).
    plots : dict or str
        If `unit` is given, a mapping from panel name ('noTs-stimulus',
        'noTs-choice', 'Ts-stimulus', 'Ts-sure', 'Ts-choice') to the plot to
        draw into. Otherwise a file-name prefix: one five-panel figure per unit
        is saved as `<plots>_<network><unit:03d>`.
    unit : int, optional
        Unit to plot into `plots`.
    network : str
        'p' uses the policy-network rates r_p; anything else uses r_v.
    **kwargs
        lw (default 1.25), dashes (default [3, 1.5]) and, when `unit` is given,
        '<panel>-tmin' / '<panel>-tmax' window overrides.

    Returns
    -------
    list or None
        When `unit` is given, the y-values of every plotted curve (useful for
        shared y-limits); otherwise None.

    """
    # Load trials
    data = utils.load(trialsfile)
    if len(data) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = data
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("Sorting policy network activity.")
        r = r_p
    else:
        print("Sorting value network activity.")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Time
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored copy of the time axis; index Ntime-1 holds time[0],
    # which is where each trial's align point is placed.
    time_a  = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Preferred targets
    #=====================================================================================

    preferred_targets = get_preferred_targets(trials, perf, r)

    #=====================================================================================
    # Aligned-averaging helpers
    #=====================================================================================

    def masked_rates(n):
        """Return (rates, mask) for trial n; rates are multiplied by the mask M[:, n]."""
        m_n = np.tile(M[:,n], (N, 1)).T
        return r[:,n]*m_n, m_n

    def accumulate(store, cond, r_n, m_n, t0):
        """Add one trial to store[cond], placing sample t0 at aligned index Ntime-1."""
        entry = store.setdefault(cond, {'r': np.zeros((Ntime_a, N)),
                                        'n': np.zeros((Ntime_a, N))})

        # Before the align point
        n_b = r_n[:t0].shape[0]
        entry['r'][Ntime-1-n_b:Ntime-1] += r_n[:t0]
        entry['n'][Ntime-1-n_b:Ntime-1] += m_n[:t0]

        # From the align point on
        n_a = r_n[t0:].shape[0]
        entry['r'][Ntime-1:Ntime-1+n_a] += r_n[t0:]
        entry['n'][Ntime-1:Ntime-1+n_a] += m_n[t0:]

    def average(store):
        """Replace each condition's sums by summed rates / sample counts (utils.div)."""
        for cond in store:
            store[cond] = utils.div(store[cond]['r'], store[cond]['n'])
        return store

    #=====================================================================================
    # No-wager trials
    #=====================================================================================

    def get_no_wager(func_t0):
        """Average no-wager trials by left_right, aligned to func_t0(epochs, t_choice)."""
        trials_by_cond = {}
        for n, trial in enumerate(trials):
            if trial['wager'] or trial['coh'] == 0 or perf.choices[n] is None:
                continue

            r_n, m_n = masked_rates(n)
            t0 = func_t0(trial['epochs'], perf.t_choices[n])
            accumulate(trials_by_cond, trial['left_right'], r_n, m_n, t0)

        return average(trials_by_cond)

    noTs_stimulus = get_no_wager(lambda epochs, t_choice: epochs['stimulus'][0] - 1)
    noTs_choice   = get_no_wager(lambda epochs, t_choice: t_choice)

    #=====================================================================================
    # Wager trials ('S' choices kept separate)
    #=====================================================================================

    def get_wager(func_t0):
        """Average wager trials by left_right; return (non-'S' choices, 'S' choices)."""
        trials_by_cond      = {}
        trials_by_cond_sure = {}
        for n, trial in enumerate(trials):
            if not trial['wager'] or perf.choices[n] is None or trial['coh'] == 0:
                continue

            r_n, m_n = masked_rates(n)
            t0 = func_t0(trial['epochs'], perf.t_choices[n])

            if perf.choices[n] == 'S':
                store = trials_by_cond_sure
            else:
                store = trials_by_cond
            accumulate(store, trial['left_right'], r_n, m_n, t0)

        return average(trials_by_cond), average(trials_by_cond_sure)

    Ts_stimulus, Ts_stimulus_sure = get_wager(lambda epochs, t_choice: epochs['stimulus'][0] - 1)
    Ts_sure, Ts_sure_sure         = get_wager(lambda epochs, t_choice: epochs['sure'][0] - 1)
    Ts_choice, Ts_choice_sure     = get_wager(lambda epochs, t_choice: t_choice)

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw     = kwargs.get('lw', 1.25)
    dashes = kwargs.get('dashes', [3, 1.5])

    # Keyed by left_right * preferred_targets[unit]
    in_opp_colors = {-1: '0.6', +1: 'k'}

    def plot_panel(Ts, Ts_sure, plot, unit, tmin, tmax):
        """Plot Ts solid and Ts_sure dashed over [tmin, tmax]; return all y-values."""
        w,   = np.where((tmin <= time_a) & (time_a <= tmax))
        t    = time_a[w]
        yall = [[1]]

        for lr in Ts:
            color = in_opp_colors[lr*preferred_targets[unit]]
            y = Ts[lr][w,unit]
            plot.plot(t, y, color=color, lw=lw)
            yall.append(y)
        for lr in Ts_sure:
            color = in_opp_colors[lr*preferred_targets[unit]]
            y = Ts_sure[lr][w,unit]
            plot.plot(t, y, color=color, lw=lw, linestyle='--', dashes=dashes)
            yall.append(y)

        plot.xlim(tmin, tmax)
        plot.xticks([0, tmax])
        plot.lim('y', yall, lower=0)

        return yall

    # Panels in drawing order: (name, curves, 'S'-choice curves). No-wager
    # panels have no 'S'-choice curves.
    panels = [
        ('noTs-stimulus', noTs_stimulus, {}),
        ('noTs-choice',   noTs_choice,   {}),
        ('Ts-stimulus',   Ts_stimulus,   Ts_stimulus_sure),
        ('Ts-sure',       Ts_sure,       Ts_sure_sure),
        ('Ts-choice',     Ts_choice,     Ts_choice_sure),
    ]

    if unit is not None:
        # Default windows, each overridable via '<panel>-tmin' / '<panel>-tmax'
        default_windows = {
            'noTs-stimulus': (-100, 700),
            'noTs-choice':   (-500,   0),
            'Ts-stimulus':   (-100, 700),
            'Ts-sure':       (-200, 700),
            'Ts-choice':     (-500,   0),
        }
        y = []
        for panel, Ts, Ts_sure in panels:
            tmin_default, tmax_default = default_windows[panel]
            tmin = kwargs.get(panel + '-tmin', tmin_default)
            tmax = kwargs.get(panel + '-tmax', tmax_default)
            y += plot_panel(Ts, Ts_sure, plots[panel], unit, tmin, tmax)

        return y
    else:
        prefix = plots

        # Fixed windows for the per-unit figures; the choice panels extend past
        # the choice time.
        figure_windows = {
            'noTs-stimulus': (-100, 700),
            'noTs-choice':   (-500, 200),
            'Ts-stimulus':   (-100, 700),
            'Ts-sure':       (-200, 700),
            'Ts-choice':     (-500, 200),
        }

        for unit in xrange(N):
            # Distinct names so the activity array `r` is not shadowed
            fig_w = utils.mm_to_inch(174)
            fig_r = 0.35
            fig   = Figure(w=fig_w, r=fig_r)

            # Panel geometry in figure coordinates
            x0 = 0.09
            y0 = 0.15

            w = 0.13
            h = 0.75

            dx = 0.05   # gap between panels within a group
            DX = 0.08   # wider gap between the no-wager and wager groups

            fig.add('noTs-stimulus', [x0, y0, w, h])
            fig.add('noTs-choice',   [fig[-1].right+dx, y0, w, h])
            fig.add('Ts-stimulus',   [fig[-1].right+DX, y0, w, h])
            fig.add('Ts-sure',       [fig[-1].right+dx, y0, w, h])
            fig.add('Ts-choice',     [fig[-1].right+dx, y0, w, h])

            #-----------------------------------------------------------------------------

            y = []
            for panel, Ts, Ts_sure in panels:
                plot = fig[panel]
                tmin, tmax = figure_windows[panel]
                y += plot_panel(Ts, Ts_sure, plot, unit, tmin, tmax)
                plot.vline(0)

            # Common y-limits across all panels
            for plot in fig.plots.values():
                plot.lim('y', y, lower=0)

            #-----------------------------------------------------------------------------

            fig.save(prefix+'_{}{:03d}'.format(network, unit))
            fig.close()
示例#6
0
文件: romo.py 项目: sumwor/pylearning
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by frequency pair and plot trial-averaged firing rates.

    Only trials with a choice that was correct are used. A trial's condition
    is its (f1, f2) pair taken from trial['fpair'], ordered according to
    trial['gt_lt']. Each trial is aligned to the onset of its 'f1' epoch.

    Parameters
    ----------
    trialsfile : str
        File read with `utils.load`; unpacks into 9 entries
        (trials, U, Z, A, P, M, perf, r_p, r_v) or 10 (with Z_b after Z).
    plots : list or (str, str)
        If `units` is given, the plots to draw into (paired with `units`).
        Otherwise a (figspath, name) pair: one figure per unit is saved under
        figspath as `<name>_<network><unit:03d>`.
    units : list of int, optional
        Units to plot into `plots`.
    network : str
        'p' uses the policy-network rates r_p; anything else uses r_v.
    **kwargs
        lw : line width (default 1.5).

    """
    # Load trials
    data = utils.load(trialsfile)
    if len(data) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = data
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    if network == 'p':
        r = r_p
    else:
        r = r_v

    # Data shape
    Ntime = r.shape[0]
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']

    # Aligned time: mirrored copy of the time axis; index Ntime-1 is where each
    # trial's align point is placed.
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    # Sort
    trials_by_cond = {}
    for n, trial in enumerate(trials):
        if perf.choices[n] is None or not perf.corrects[n]:
            continue

        # Condition
        gt_lt = trial['gt_lt']
        fpair = trial['fpair']
        if gt_lt == '>':
            f1, f2 = fpair
        else:
            f2, f1 = fpair
        cond = (f1, f2)

        # Firing rates, multiplied by this trial's mask M[:, n]
        Mn = np.tile(M[:, n], (N, 1)).T
        Rn = r[:, n] * Mn

        # Align point
        t0 = trial['epochs']['f1'][0] - 1

        # Storage: summed rates and number of contributing (masked) samples
        trials_by_cond.setdefault(cond, {
            'r': np.zeros((Ntime_a, N)),
            'n': np.zeros((Ntime_a, N))
        })

        # Before: samples [0, t0) end just before index Ntime-1
        n_b = Rn[:t0].shape[0]
        trials_by_cond[cond]['r'][Ntime - 1 - n_b:Ntime - 1] += Rn[:t0]
        trials_by_cond[cond]['n'][Ntime - 1 - n_b:Ntime - 1] += Mn[:t0]

        # After: samples [t0, end) start at index Ntime-1
        n_a = Rn[t0:].shape[0]
        trials_by_cond[cond]['r'][Ntime - 1:Ntime - 1 + n_a] += Rn[t0:]
        trials_by_cond[cond]['n'][Ntime - 1:Ntime - 1 + n_a] += Mn[t0:]

    # Average: summed rates divided by sample counts (via utils.div)
    for cond in trials_by_cond:
        trials_by_cond[cond] = utils.div(trials_by_cond[cond]['r'],
                                         trials_by_cond[cond]['n'])

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)

    # Plot window in time_a units (presumably ms; plotted in s via the 1e-3 below)
    w, = np.where((time_a >= -500) & (time_a <= 4000))

    def plot_sorted(plot, unit):
        """Plot one curve per (f1, f2) condition, colored by f1; return (t, y-values)."""
        t = 1e-3 * time_a[w]
        yall = [[1]]
        # Sorted condition order makes curve layering reproducible (dict order
        # is arbitrary on Python 2).
        for cond in sorted(trials_by_cond):
            f1 = cond[0]
            rates = trials_by_cond[cond][w, unit]
            plot.plot(t, rates, color=smap.to_rgba(f1), lw=lw)
            yall.append(rates)

        return t, yall

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_sorted(plot, unit)
    else:
        figspath, name = plots
        for unit in xrange(N):
            fig = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------

            t, yall = plot_sorted(plot, unit)

            plot.xlim(t[0], t[-1])
            plot.lim('y', yall, lower=0)

            # NOTE(review): presumably the f1 and f2 stimulus periods (in s) — confirm
            plot.highlight(0, 0.5)
            plot.highlight(3.5, 4)

            #-----------------------------------------------------------------------------

            fig.save(path=figspath,
                     name=name + '_{}{:03d}'.format(network, unit))
            fig.close()
示例#7
0
def sort(trialsfile, plots, unit=None, network='p', **kwargs):
    """
    Sort trials and plot trial-averaged firing rates for wager / no-wager trials.

    Trials with zero coherence or without a choice are excluded. Activity is
    grouped by trial['left_right'] and averaged after aligning each trial to an
    event:

      * no-wager trials: stimulus onset and choice time;
      * wager trials: stimulus onset, 'sure' epoch onset and choice time, with
        trials whose choice was 'S' averaged separately (drawn dashed).

    Parameters
    ----------
    trialsfile : str
        File read with `utils.load`; unpacks into 9 entries
        (trials, U, Z, A, P, M, perf, r_p, r_v) or 10 (with Z_b after Z).
    plots : dict or str
        If `unit` is given, a mapping from panel name ('noTs-stimulus',
        'noTs-choice', 'Ts-stimulus', 'Ts-sure', 'Ts-choice') to the plot to
        draw into. Otherwise a file-name prefix: one five-panel figure per unit
        is saved as `<plots>_<network><unit:03d>`.
    unit : int, optional
        Unit to plot into `plots`.
    network : str
        'p' uses the policy-network rates r_p; anything else uses r_v.
    **kwargs
        lw (default 1.25), dashes (default [3, 1.5]) and, when `unit` is given,
        '<panel>-tmin' / '<panel>-tmax' window overrides.

    Returns
    -------
    list or None
        When `unit` is given, the y-values of every plotted curve (useful for
        shared y-limits); otherwise None.

    """
    # Load trials
    data = utils.load(trialsfile)
    if len(data) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = data
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("Sorting policy network activity.")
        r = r_p
    else:
        print("Sorting value network activity.")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Time
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored copy of the time axis; index Ntime-1 holds time[0],
    # which is where each trial's align point is placed.
    time_a  = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Preferred targets
    #=====================================================================================

    preferred_targets = get_preferred_targets(trials, perf, r)

    #=====================================================================================
    # Aligned-averaging helpers
    #=====================================================================================

    def masked_rates(n):
        """Return (rates, mask) for trial n; rates are multiplied by the mask M[:, n]."""
        m_n = np.tile(M[:,n], (N, 1)).T
        return r[:,n]*m_n, m_n

    def accumulate(store, cond, r_n, m_n, t0):
        """Add one trial to store[cond], placing sample t0 at aligned index Ntime-1."""
        entry = store.setdefault(cond, {'r': np.zeros((Ntime_a, N)),
                                        'n': np.zeros((Ntime_a, N))})

        # Before the align point
        n_b = r_n[:t0].shape[0]
        entry['r'][Ntime-1-n_b:Ntime-1] += r_n[:t0]
        entry['n'][Ntime-1-n_b:Ntime-1] += m_n[:t0]

        # From the align point on
        n_a = r_n[t0:].shape[0]
        entry['r'][Ntime-1:Ntime-1+n_a] += r_n[t0:]
        entry['n'][Ntime-1:Ntime-1+n_a] += m_n[t0:]

    def average(store):
        """Replace each condition's sums by summed rates / sample counts (utils.div)."""
        for cond in store:
            store[cond] = utils.div(store[cond]['r'], store[cond]['n'])
        return store

    #=====================================================================================
    # No-wager trials
    #=====================================================================================

    def get_no_wager(func_t0):
        """Average no-wager trials by left_right, aligned to func_t0(epochs, t_choice)."""
        trials_by_cond = {}
        for n, trial in enumerate(trials):
            if trial['wager'] or trial['coh'] == 0 or perf.choices[n] is None:
                continue

            r_n, m_n = masked_rates(n)
            t0 = func_t0(trial['epochs'], perf.t_choices[n])
            accumulate(trials_by_cond, trial['left_right'], r_n, m_n, t0)

        return average(trials_by_cond)

    noTs_stimulus = get_no_wager(lambda epochs, t_choice: epochs['stimulus'][0] - 1)
    noTs_choice   = get_no_wager(lambda epochs, t_choice: t_choice)

    #=====================================================================================
    # Wager trials ('S' choices kept separate)
    #=====================================================================================

    def get_wager(func_t0):
        """Average wager trials by left_right; return (non-'S' choices, 'S' choices)."""
        trials_by_cond      = {}
        trials_by_cond_sure = {}
        for n, trial in enumerate(trials):
            if not trial['wager'] or perf.choices[n] is None or trial['coh'] == 0:
                continue

            r_n, m_n = masked_rates(n)
            t0 = func_t0(trial['epochs'], perf.t_choices[n])

            if perf.choices[n] == 'S':
                store = trials_by_cond_sure
            else:
                store = trials_by_cond
            accumulate(store, trial['left_right'], r_n, m_n, t0)

        return average(trials_by_cond), average(trials_by_cond_sure)

    Ts_stimulus, Ts_stimulus_sure = get_wager(lambda epochs, t_choice: epochs['stimulus'][0] - 1)
    Ts_sure, Ts_sure_sure         = get_wager(lambda epochs, t_choice: epochs['sure'][0] - 1)
    Ts_choice, Ts_choice_sure     = get_wager(lambda epochs, t_choice: t_choice)

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw     = kwargs.get('lw', 1.25)
    dashes = kwargs.get('dashes', [3, 1.5])

    # Keyed by left_right * preferred_targets[unit]
    in_opp_colors = {-1: '0.6', +1: 'k'}

    def plot_panel(Ts, Ts_sure, plot, unit, tmin, tmax):
        """Plot Ts solid and Ts_sure dashed over [tmin, tmax]; return all y-values."""
        w,   = np.where((tmin <= time_a) & (time_a <= tmax))
        t    = time_a[w]
        yall = [[1]]

        for lr in Ts:
            color = in_opp_colors[lr*preferred_targets[unit]]
            y = Ts[lr][w,unit]
            plot.plot(t, y, color=color, lw=lw)
            yall.append(y)
        for lr in Ts_sure:
            color = in_opp_colors[lr*preferred_targets[unit]]
            y = Ts_sure[lr][w,unit]
            plot.plot(t, y, color=color, lw=lw, linestyle='--', dashes=dashes)
            yall.append(y)

        plot.xlim(tmin, tmax)
        plot.xticks([0, tmax])
        plot.lim('y', yall, lower=0)

        return yall

    # Panels in drawing order: (name, curves, 'S'-choice curves). No-wager
    # panels have no 'S'-choice curves.
    panels = [
        ('noTs-stimulus', noTs_stimulus, {}),
        ('noTs-choice',   noTs_choice,   {}),
        ('Ts-stimulus',   Ts_stimulus,   Ts_stimulus_sure),
        ('Ts-sure',       Ts_sure,       Ts_sure_sure),
        ('Ts-choice',     Ts_choice,     Ts_choice_sure),
    ]

    if unit is not None:
        # Default windows, each overridable via '<panel>-tmin' / '<panel>-tmax'
        default_windows = {
            'noTs-stimulus': (-100, 700),
            'noTs-choice':   (-500,   0),
            'Ts-stimulus':   (-100, 700),
            'Ts-sure':       (-200, 700),
            'Ts-choice':     (-500,   0),
        }
        y = []
        for panel, Ts, Ts_sure in panels:
            tmin_default, tmax_default = default_windows[panel]
            tmin = kwargs.get(panel + '-tmin', tmin_default)
            tmax = kwargs.get(panel + '-tmax', tmax_default)
            y += plot_panel(Ts, Ts_sure, plots[panel], unit, tmin, tmax)

        return y
    else:
        prefix = plots

        # Fixed windows for the per-unit figures; the choice panels extend past
        # the choice time.
        figure_windows = {
            'noTs-stimulus': (-100, 700),
            'noTs-choice':   (-500, 200),
            'Ts-stimulus':   (-100, 700),
            'Ts-sure':       (-200, 700),
            'Ts-choice':     (-500, 200),
        }

        for unit in xrange(N):
            # Distinct names so the activity array `r` is not shadowed
            fig_w = utils.mm_to_inch(174)
            fig_r = 0.35
            fig   = Figure(w=fig_w, r=fig_r)

            # Panel geometry in figure coordinates
            x0 = 0.09
            y0 = 0.15

            w = 0.13
            h = 0.75

            dx = 0.05   # gap between panels within a group
            DX = 0.08   # wider gap between the no-wager and wager groups

            fig.add('noTs-stimulus', [x0, y0, w, h])
            fig.add('noTs-choice',   [fig[-1].right+dx, y0, w, h])
            fig.add('Ts-stimulus',   [fig[-1].right+DX, y0, w, h])
            fig.add('Ts-sure',       [fig[-1].right+dx, y0, w, h])
            fig.add('Ts-choice',     [fig[-1].right+dx, y0, w, h])

            #-----------------------------------------------------------------------------

            y = []
            for panel, Ts, Ts_sure in panels:
                plot = fig[panel]
                tmin, tmax = figure_windows[panel]
                y += plot_panel(Ts, Ts_sure, plot, unit, tmin, tmax)
                plot.vline(0)

            # Common y-limits across all panels
            for plot in fig.plots.values():
                plot.lim('y', y, lower=0)

            #-----------------------------------------------------------------------------

            fig.save(prefix+'_{}{:03d}'.format(network, unit))
            fig.close()
示例#8
0
# Options for rdm_analysis.sort_return. The x-limits below reuse the
# 'on-*-tmin' / 'on-*-tmax' entries so the axes cannot drift from them.
kwargs = {'on-stimulus-tmin': -200, 'on-stimulus-tmax': 600,
          'on-choice-tmin': -400, 'on-choice-tmax': 0,
          'colors': 'kiani', 'dashes': [3.5, 2]}
rdm_analysis.sort_return(rdm_fixed_activity, fig.plots, **kwargs)

# Expected-reward axis range shared by both panels
reward_ylim = (0.5, 1.2)

plot = fig['on-stimulus']
plot.xlim(kwargs['on-stimulus-tmin'], kwargs['on-stimulus-tmax'])
plot.xticks([-200, 0, 200, 400, 600])
plot.ylim(*reward_ylim)

plot.xlabel('Time from stimulus (ms)')
plot.ylabel('Expected reward')

# Legend
props = {'prop': {'size': 8}, 'handlelength': 1.2,
         'handletextpad': 1.1, 'labelspacing': 0.7}
plot.legend(bbox_to_anchor=(0.33, 1), **props)

plot = fig['on-choice']
plot.xlim(kwargs['on-choice-tmin'], kwargs['on-choice-tmax'])
plot.xticks([-400, -200, 0])
plot.ylim(*reward_ylim)

plot.xlabel('Time from decision (ms)')

#=========================================================================================

fig.save()
示例#9
0
def sort_epoch(behaviorfile,
               activityfile,
               epoch,
               offers,
               plots,
               units=None,
               network='p',
               separate_by_choice=False,
               **kwargs):
    """
    Sort trials by condition and plot epoch-averaged activity per unit.

    Decision trials are grouped by offer, or by (offer, choice) when
    `separate_by_choice` is True. Firing rates are aligned to offer onset
    and to the choice time, averaged over trials, and then averaged over
    fixed windows [tmin, tmax) in the time units of trials[0]['time']
    (presumably ms):

        preoffer   [-500, 0)    relative to offer onset
        postoffer  [0, 500)     relative to offer onset
        latedelay  [500, 1000)  relative to offer onset
        prechoice  [-500, 0)    relative to the choice

    The pre-offer window is always drawn (grey, smaller markers) together
    with the requested `epoch`.

    Parameters
    ----------
    behaviorfile : str
        Passed, with `offers`, to `indifference_point`.
    activityfile : str
        Loaded with `utils.load`. It must unpack into
        (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    epoch : str
        Epoch to highlight: 'postoffer', 'latedelay' or 'prechoice'.
    offers : sequence
        Offers in x-axis order. Each is labelled with '{}B:{}A'.format(*offer).
    plots : sequence or str
        If `units` is given, plots paired one-to-one with `units`.
        Otherwise a filename prefix. One figure per unit is then saved as
        '<plots>_<epoch>[_sbc]_<network><unit:03d>'.
    units : sequence of int, optional
        Units to draw into `plots`.
    network : str
        'p' selects the policy-network rates r_p; anything else selects r_v.
    separate_by_choice : bool
        Split conditions by choice. (offer, choice) conditions with fewer
        than 20 trials are not drawn.
    **kwargs
        Styling: lw, ms, mew, rotation.

    Raises
    ------
    ValueError
        If `epoch` is not one of the highlightable epochs (raised when a
        unit is plotted).

    """
    # Load trials
    data = utils.load(activityfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("POLICY NETWORK")
        r = r_p
    else:
        print("VALUE NETWORK")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time axis: the trial axis mirrored about its first sample, so
    # index Ntime - 1 is the alignment point with Ntime - 1 samples each side.
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    # Events that trials are aligned to
    events = ['offer', 'choice']

    # Per event and condition, accumulate summed rates ('r') and the summed
    # mask ('n') used as the per-sample count when averaging.
    events_by_cond = {e: {} for e in events}
    n_by_cond = {}
    n_nondecision = 0
    for n, trial in enumerate(trials):
        # Trials without a decision cannot be aligned to a choice
        if perf.choices[n] is None:
            n_nondecision += 1
            continue

        # Condition
        offer = trial['offer']
        choice = perf.choices[n]

        if separate_by_choice:
            cond = (offer, choice)
        else:
            cond = offer

        n_by_cond[cond] = n_by_cond.get(cond, 0) + 1

        # Storage
        for e in events_by_cond:
            events_by_cond[e].setdefault(cond, {
                'r': np.zeros((Ntime_a, N)),
                'n': np.zeros((Ntime_a, N))
            })

        # Firing rates weighted by this trial's mask M[:, n], repeated per unit
        m_n = np.tile(M[:, n], (N, 1)).T
        r_n = r[:, n] * m_n

        for e in events_by_cond:
            # Align point
            if e == 'offer':
                t0 = trial['epochs']['offer-on'][0]
            elif e == 'choice':
                t0 = perf.t_choices[n]
            else:
                raise ValueError(e)

            # Before: samples [0, t0) end just before the alignment index
            n_b = r_n[:t0].shape[0]
            events_by_cond[e][cond]['r'][Ntime - 1 - n_b:Ntime - 1] += r_n[:t0]
            events_by_cond[e][cond]['n'][Ntime - 1 - n_b:Ntime - 1] += m_n[:t0]

            # After: samples [t0, end) start at the alignment index
            n_a = r_n[t0:].shape[0]
            events_by_cond[e][cond]['r'][Ntime - 1:Ntime - 1 + n_a] += r_n[t0:]
            events_by_cond[e][cond]['n'][Ntime - 1:Ntime - 1 + n_a] += m_n[t0:]
    print("Non-decision trials: {}/{}".format(n_nondecision, len(trials)))

    # Average trials: summed rates / summed mask (via utils.div)
    for e in events_by_cond:
        for cond in events_by_cond[e]:
            events_by_cond[e][cond] = utils.div(events_by_cond[e][cond]['r'],
                                                events_by_cond[e][cond]['n'])

    # Averaging windows: epoch -> (alignment event, tmin, tmax), half-open
    windows = {
        'preoffer': ('offer', -500, 0),
        'postoffer': ('offer', 0, 500),
        'latedelay': ('offer', 500, 1000),
        'prechoice': ('choice', -500, 0),
    }

    # Average epochs: one value per unit for each condition
    epochs_by_cond = {}
    for e, (ev, tmin, tmax) in windows.items():
        w, = np.where((tmin <= time_a) & (time_a < tmax))
        epochs_by_cond[e] = {
            cond: np.mean(rates[w], axis=0)
            for cond, rates in events_by_cond[ev].items()
        }

    #=====================================================================================
    # Classify units
    #=====================================================================================

    idpt = indifference_point(behaviorfile, offers)
    unit_types = classify_units(trials, perf, r, idpt)

    # Count units per type
    numbers = {}
    for unit_type in unit_types.values():
        numbers[unit_type] = numbers.get(unit_type, 0) + 1

    # Builtin sum: np.sum() on a Python 3 dict view wraps it in a 0-d object
    # array instead of adding the counts.
    n_tot = sum(numbers.values())
    for k, v in numbers.items():
        # 100.0 avoids floor division under Python 2
        print("{}: {}/{} = {}%".format(k, v, n_tot, 100.0 * v / n_tot))

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)
    ms = kwargs.get('ms', 6)
    mew = kwargs.get('mew', 0.5)
    rotation = kwargs.get('rotation', 60)

    # Minimum trials for an (offer, choice) condition to be drawn
    min_trials = 20

    # Epochs that can be highlighted ('preoffer' is always the baseline)
    plottable_epochs = ('postoffer', 'latedelay', 'prechoice')

    def plot_activity(plot, unit):
        """Draw the pre-offer baseline and the selected epoch for `unit`."""
        if epoch not in plottable_epochs:
            raise ValueError(epoch)

        # Seed the y-limit data with 1 (presumably so the axis reaches >= 1)
        yall = [1]
        last = len(offers) - 1

        def plot_series(epoch_by_cond, choice, marker, color, scale):
            # One marker per available offer. The connecting line skips the
            # first and last offers. `choice` is None when not splitting.
            x = []
            y = []
            for i, offer in enumerate(offers):
                if choice is None:
                    cond = offer
                    # Offers with no decision trials have no average
                    if cond not in epoch_by_cond:
                        continue
                else:
                    cond = (offer, choice)
                    if n_by_cond.get(cond, 0) < min_trials:
                        continue
                y_i = epoch_by_cond[cond][unit]
                plot.plot(i,
                          y_i,
                          marker,
                          mfc=color,
                          mec=color,
                          ms=scale * ms,
                          mew=scale * mew,
                          zorder=10)
                yall.append(y_i)
                if 0 < i < last:
                    x.append(i)
                    y.append(y_i)
            plot.plot(x, y, '-', color=color, lw=scale * lw, zorder=5)

        # (choice, marker, epoch colour) for each series
        if separate_by_choice:
            series = [('A', 'd', Figure.colors('red')),
                      ('B', 'o', Figure.colors('blue'))]
        else:
            series = [(None, 'o', Figure.colors('darkblue'))]

        # Pre-offer baseline: grey, 80% size
        for choice, marker, _ in series:
            plot_series(epochs_by_cond['preoffer'], choice, marker, '0.7', 0.8)

        # Selected epoch
        for choice, marker, color in series:
            plot_series(epochs_by_cond[epoch], choice, marker, color, 1)

        plot.xticks(range(len(offers)))
        plot.xticklabels(['{}B:{}A'.format(*offer) for offer in offers],
                         rotation=rotation)

        plot.xlim(0, last)
        plot.lim('y', yall, lower=0)

        return yall

    #-------------------------------------------------------------------------------------

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_activity(plot, unit)
    else:
        # In this mode `plots` is a filename prefix
        name = plots
        suffix = '_sbc' if separate_by_choice else ''
        for unit in xrange(N):
            fig = Figure()
            try:
                plot = fig.add()

                plot_activity(plot, unit)

                if unit in unit_types:
                    plot.text_upper_right(unit_types[unit], fontsize=9)

                fig.save(name +
                         '_{}{}_{}{:03d}'.format(epoch, suffix, network, unit))
            finally:
                fig.close()
示例#10
0
# Full recurrent matrix: `plot` and `W` are bound earlier in the script.
rho = matrixtools.spectral_radius(W)
plot_weights(plot, W)
plot.text_upper_left(r'$W_\text{rec}$', fontsize=fontsize, dy=dy)
plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)
plot.xlabel('$W$')

# Gate recurrent weights. The first N columns go to the lambda panel and the
# rest to the gamma panel.
# NOTE(review): if params['Wrec_gates'] is a NumPy array, `*=` masks it in
# place, i.e. params itself is modified -- confirm this is intended.
Wrec_gates = params['Wrec_gates']
if 'Wrec_gates' in masks:
    Wrec_gates *= masks['Wrec_gates']

for panel, columns, label in (
        ('Wrec_lambda', slice(None, N), r'$W_\text{rec}^\lambda$'),
        ('Wrec_gamma', slice(N, None), r'$W_\text{rec}^\gamma$')):
    plot = fig[panel]
    W = Wrec_gates[:, columns]
    rho = matrixtools.spectral_radius(W)
    plot_weights(plot, W)
    plot.text_upper_left(label, fontsize=fontsize, dy=dy)
    plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)

#=========================================================================================

fig.save(path=figspath, name='fig_weights_'+modelname)
示例#11
0
plot.ylabel('Percent correct\n(decision trials)')

target_color = Figure.colors('red')

# Target percent correct per model, drawn as a horizontal reference line.
# Every model whose name starts with 'rdm_fixed' uses a target of 80.
target_by_model = {
    'rdm_rt':            80,
    'mante':             85,
    'multisensory':      82,
    'romo':              97,
    'postdecisionwager': 79,
    'padoaschioppa2006': 95,
}

plot = fig['correct']
if modelname.startswith('rdm_fixed'):
    target = 80
    plot.hline(target, color=target_color, zorder=1)
elif modelname in target_by_model:
    target = target_by_model[modelname]
    plot.hline(target, color=target_color, zorder=1)

#=========================================================================================

fig.save(path=figspath, name='fig_learning_' + modelname)
示例#12
0
plot.ylabel('Percent correct\n(decision trials)')

target_color = Figure.colors('red')

# Target percent correct per model, shown as a horizontal line on the
# 'correct' panel. Names starting with 'rdm_fixed' share a target of 80.
target_by_model = {'rdm_rt': 80, 'mante': 85, 'multisensory': 82,
                   'romo': 97, 'postdecisionwager': 79,
                   'padoaschioppa2006': 95}

plot = fig['correct']
if modelname.startswith('rdm_fixed'):
    target = 80
    plot.hline(target, color=target_color, zorder=1)
elif modelname in target_by_model:
    target = target_by_model[modelname]
    plot.hline(target, color=target_color, zorder=1)

#=========================================================================================

fig.save(path=figspath, name='fig_learning_'+modelname)
示例#13
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatches on `action`:
      - any action containing 'trials': generate trials for every
        (juice, offer) condition and run them. args[0], if given, is the
        number of trials per condition (default 100).
      - 'choice_pattern': plot percent choice B against the offer.
      - 'indifference_point': plot percent choice B against
        (n_B - n_A)/(n_B + n_A).
      - 'sort_epoch': plot epoch-sorted activity. args[0] is the epoch,
        'value' in args selects the value network, and 'separate-by-choice'
        in args splits conditions by choice.
    Other actions only print the action and its arguments.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #=====================================================================================

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except IndexError:
            trials_per_condition = 100

        model = config['model']
        pg    = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec         = model.spec
        juices       = spec.juices
        offers       = spec.offers
        n_conditions = spec.n_conditions
        n_trials     = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task   = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Cycle through every (juice, offer) combination
            k = tasktools.unravel_index(n, (len(juices), len(offers)))
            context = {
                'juice': juices[k.pop(0)],
                'offer': offers[k.pop(0)]
                }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'choice_pattern':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig  = Figure()
        plot = fig.add()

        spec = config['model'].spec

        choice_pattern(trialsfile, spec.offers, plot)

        # Raw string: '\#' is not a valid escape in a normal literal
        # (DeprecationWarning; SyntaxWarning on Python 3.12+). Same text.
        plot.xlabel(r'Offer (\#B : \#A)')
        plot.ylabel('Percent choice B')

        plot.text_upper_left('1A = {}B'.format(spec.A_to_B), fontsize=10)

        fig.save(path=config['figspath'], name=action)
        fig.close()

    elif action == 'indifference_point':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig  = Figure()
        plot = fig.add()

        spec = config['model'].spec

        indifference_point(trialsfile, spec.offers, plot)

        plot.xlabel('$(n_B - n_A)/(n_B + n_A)$')
        plot.ylabel('Percent choice B')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'sort_epoch':
        behaviorfile = runtools.behaviorfile(config['trialspath'])
        activityfile = runtools.activityfile(config['trialspath'])

        epoch = args[0]

        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        separate_by_choice = ('separate-by-choice' in args)

        sort_epoch(behaviorfile, activityfile, epoch, config['model'].spec.offers,
                   os.path.join(config['figspath'], 'sorted'),
                   network=network, separate_by_choice=separate_by_choice)
示例#14
0
def sort_epoch(behaviorfile, activityfile, epoch, offers, plots, units=None, network='p',
               separate_by_choice=False, **kwargs):
    """
    Sort trials by condition and plot epoch-averaged activity per unit.

    Decision trials are grouped by offer, or by (offer, choice) when
    `separate_by_choice` is True. Rates are aligned to offer onset and to the
    choice time, averaged over trials, then averaged over half-open windows
    in the time units of trials[0]['time'] (presumably ms):

        preoffer   [-500, 0)    relative to offer onset
        postoffer  [0, 500)     relative to offer onset
        latedelay  [500, 1000)  relative to offer onset
        prechoice  [-500, 0)    relative to the choice

    The pre-offer window is always drawn (grey, smaller markers) together
    with the requested `epoch`.

    Parameters
    ----------
    behaviorfile : str
        Passed, with `offers`, to `indifference_point`.
    activityfile : str
        Loaded with `utils.load`. It must unpack into
        (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    epoch : str
        Epoch to highlight: 'postoffer', 'latedelay' or 'prechoice'.
    offers : sequence
        Offers in x-axis order, labelled with '{}B:{}A'.format(*offer).
    plots : sequence or str
        Plots paired with `units` if `units` is given. Otherwise a filename
        prefix, and one figure per unit is saved as
        '<plots>_<epoch>[_sbc]_<network><unit:03d>'.
    units : sequence of int, optional
        Units to draw into `plots`.
    network : str
        'p' selects r_p (policy network); anything else selects r_v.
    separate_by_choice : bool
        Split conditions by choice. (offer, choice) conditions with fewer
        than 20 trials are not drawn.
    **kwargs
        Styling: lw, ms, mew, rotation.

    Raises
    ------
    ValueError
        If `epoch` cannot be highlighted (raised when a unit is plotted).

    """
    # Load trials
    data = utils.load(activityfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("POLICY NETWORK")
        r = r_p
    else:
        print("VALUE NETWORK")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time  = trials[0]['time']
    Ntime = len(time)

    # Aligned time: trial axis mirrored about its first sample; index Ntime-1
    # is the alignment point
    time_a  = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    # Events that trials are aligned to
    events = ['offer', 'choice']

    # Summed rates ('r') and summed mask ('n') per event and condition
    events_by_cond = {e: {} for e in events}
    n_by_cond      = {}
    n_nondecision  = 0
    for n, trial in enumerate(trials):
        # Trials without a decision cannot be aligned to a choice
        if perf.choices[n] is None:
            n_nondecision += 1
            continue

        # Condition
        offer  = trial['offer']
        choice = perf.choices[n]

        if separate_by_choice:
            cond = (offer, choice)
        else:
            cond = offer

        n_by_cond[cond] = n_by_cond.get(cond, 0) + 1

        # Storage
        for e in events_by_cond:
            events_by_cond[e].setdefault(cond, {'r': np.zeros((Ntime_a, N)),
                                                'n': np.zeros((Ntime_a, N))})

        # Firing rates weighted by this trial's mask M[:,n], repeated per unit
        m_n = np.tile(M[:,n], (N,1)).T
        r_n = r[:,n]*m_n

        for e in events_by_cond:
            # Align point
            if e == 'offer':
                t0 = trial['epochs']['offer-on'][0]
            elif e == 'choice':
                t0 = perf.t_choices[n]
            else:
                raise ValueError(e)

            # Before: samples [0, t0) end just before the alignment index
            n_b = r_n[:t0].shape[0]
            events_by_cond[e][cond]['r'][Ntime-1-n_b:Ntime-1] += r_n[:t0]
            events_by_cond[e][cond]['n'][Ntime-1-n_b:Ntime-1] += m_n[:t0]

            # After: samples [t0, end) start at the alignment index
            n_a = r_n[t0:].shape[0]
            events_by_cond[e][cond]['r'][Ntime-1:Ntime-1+n_a] += r_n[t0:]
            events_by_cond[e][cond]['n'][Ntime-1:Ntime-1+n_a] += m_n[t0:]
    print("Non-decision trials: {}/{}".format(n_nondecision, len(trials)))

    # Average trials: summed rates / summed mask (via utils.div)
    for e in events_by_cond:
        for cond in events_by_cond[e]:
            events_by_cond[e][cond] = utils.div(events_by_cond[e][cond]['r'],
                                                events_by_cond[e][cond]['n'])

    # Averaging windows: epoch -> (alignment event, tmin, tmax), half-open
    windows = {'preoffer':  ('offer',  -500,    0),
               'postoffer': ('offer',     0,  500),
               'latedelay': ('offer',   500, 1000),
               'prechoice': ('choice', -500,    0)}

    # Average epochs: one value per unit for each condition
    epochs_by_cond = {}
    for e, (ev, tmin, tmax) in windows.items():
        w, = np.where((tmin <= time_a) & (time_a < tmax))
        epochs_by_cond[e] = {cond: np.mean(rates[w], axis=0)
                             for cond, rates in events_by_cond[ev].items()}

    #=====================================================================================
    # Classify units
    #=====================================================================================

    idpt = indifference_point(behaviorfile, offers)
    unit_types = classify_units(trials, perf, r, idpt)

    # Count units per type
    numbers = {}
    for unit_type in unit_types.values():
        numbers[unit_type] = numbers.get(unit_type, 0) + 1

    # Builtin sum: np.sum() on a Python 3 dict view does not add the counts
    n_tot = sum(numbers.values())
    for k, v in numbers.items():
        # 100.0 avoids floor division under Python 2
        print("{}: {}/{} = {}%".format(k, v, n_tot, 100.0*v/n_tot))

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw  = kwargs.get('lw',  1.5)
    ms  = kwargs.get('ms',  6)
    mew = kwargs.get('mew', 0.5)
    rotation = kwargs.get('rotation', 60)

    # Minimum trials for an (offer, choice) condition to be drawn
    min_trials = 20

    # Epochs that can be highlighted ('preoffer' is always the baseline)
    plottable_epochs = ('postoffer', 'latedelay', 'prechoice')

    def plot_activity(plot, unit):
        """Draw the pre-offer baseline and the selected epoch for `unit`."""
        if epoch not in plottable_epochs:
            raise ValueError(epoch)

        # Seed the y-limit data with 1 (presumably so the axis reaches >= 1)
        yall = [1]
        last = len(offers)-1

        def plot_series(epoch_by_cond, choice, marker, color, scale):
            # One marker per available offer; the line skips the end offers.
            # `choice` is None when conditions are not split by choice.
            x = []
            y = []
            for i, offer in enumerate(offers):
                if choice is None:
                    cond = offer
                    # Offers with no decision trials have no average
                    if cond not in epoch_by_cond:
                        continue
                else:
                    cond = (offer, choice)
                    if n_by_cond.get(cond, 0) < min_trials:
                        continue
                y_i = epoch_by_cond[cond][unit]
                plot.plot(i, y_i, marker, mfc=color, mec=color, ms=scale*ms,
                          mew=scale*mew, zorder=10)
                yall.append(y_i)
                if 0 < i < last:
                    x.append(i)
                    y.append(y_i)
            plot.plot(x, y, '-', color=color, lw=scale*lw, zorder=5)

        # (choice, marker, epoch colour) for each series
        if separate_by_choice:
            series = [('A', 'd', Figure.colors('red')),
                      ('B', 'o', Figure.colors('blue'))]
        else:
            series = [(None, 'o', Figure.colors('darkblue'))]

        # Pre-offer baseline: grey, 80% size
        for choice, marker, _ in series:
            plot_series(epochs_by_cond['preoffer'], choice, marker, '0.7', 0.8)

        # Selected epoch
        for choice, marker, color in series:
            plot_series(epochs_by_cond[epoch], choice, marker, color, 1)

        plot.xticks(range(len(offers)))
        plot.xticklabels(['{}B:{}A'.format(*offer) for offer in offers],
                         rotation=rotation)

        plot.xlim(0, last)
        plot.lim('y', yall, lower=0)

        return yall

    #-------------------------------------------------------------------------------------

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_activity(plot, unit)
    else:
        # In this mode `plots` is a filename prefix
        name   = plots
        suffix = '_sbc' if separate_by_choice else ''
        for unit in xrange(N):
            fig = Figure()
            try:
                plot = fig.add()

                plot_activity(plot, unit)

                if unit in unit_types:
                    plot.text_upper_right(unit_types[unit], fontsize=9)

                fig.save(name+'_{}{}_{}{:03d}'.format(epoch, suffix, network, unit))
            finally:
                fig.close()
示例#15
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatches on `action`:
      - any action containing 'trials': generate trials for every
        (juice, offer) condition and run them. args[0], if given, is the
        number of trials per condition (default 100).
      - 'choice_pattern': plot percent choice B against the offer.
      - 'indifference_point': plot percent choice B against
        (n_B - n_A)/(n_B + n_A).
      - 'sort_epoch': plot epoch-sorted activity. args[0] is the epoch,
        'value' in args selects the value network, and 'separate-by-choice'
        in args splits conditions by choice.
    Other actions only print the action and its arguments.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #=====================================================================================

    if 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except IndexError:
            trials_per_condition = 100

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Cycle through every (juice, offer) combination
            k = tasktools.unravel_index(n, (len(juices), len(offers)))
            context = {'juice': juices[k.pop(0)], 'offer': offers[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'choice_pattern':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        spec = config['model'].spec

        choice_pattern(trialsfile, spec.offers, plot)

        # Raw string: '\#' is not a valid escape in a normal literal
        # (DeprecationWarning; SyntaxWarning on Python 3.12+). Same text.
        plot.xlabel(r'Offer (\#B : \#A)')
        plot.ylabel('Percent choice B')

        plot.text_upper_left('1A = {}B'.format(spec.A_to_B), fontsize=10)

        fig.save(path=config['figspath'], name=action)
        fig.close()

    elif action == 'indifference_point':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        spec = config['model'].spec

        indifference_point(trialsfile, spec.offers, plot)

        plot.xlabel('$(n_B - n_A)/(n_B + n_A)$')
        plot.ylabel('Percent choice B')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'sort_epoch':
        behaviorfile = runtools.behaviorfile(config['trialspath'])
        activityfile = runtools.activityfile(config['trialspath'])

        epoch = args[0]

        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        separate_by_choice = ('separate-by-choice' in args)

        sort_epoch(behaviorfile,
                   activityfile,
                   epoch,
                   config['model'].spec.offers,
                   os.path.join(config['figspath'], 'sorted'),
                   network=network,
                   separate_by_choice=separate_by_choice)
示例#16
0
# Full recurrent matrix: `plot` and `W` are bound earlier in the script.
rho = matrixtools.spectral_radius(W)
plot_weights(plot, W)
plot.text_upper_left(r'$W_\text{rec}$', fontsize=fontsize, dy=dy)
plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)
plot.xlabel('$W$')

# Gate recurrent weights: columns [0, N) -> lambda panel, [N, end) -> gamma.
# NOTE(review): if params['Wrec_gates'] is a NumPy array, `*=` applies the
# mask in place and so modifies params as well -- confirm this is intended.
Wrec_gates = params['Wrec_gates']
if 'Wrec_gates' in masks:
    Wrec_gates *= masks['Wrec_gates']

gate_panels = (('Wrec_lambda', slice(None, N), r'$W_\text{rec}^\lambda$'),
               ('Wrec_gamma',  slice(N, None), r'$W_\text{rec}^\gamma$'))
for panel, columns, label in gate_panels:
    plot = fig[panel]
    W = Wrec_gates[:, columns]
    rho = matrixtools.spectral_radius(W)
    plot_weights(plot, W)
    plot.text_upper_left(label, fontsize=fontsize, dy=dy)
    plot.text_upper_right(r'$\rho={:.3f}$'.format(rho), fontsize=fontsize, dy=dy)

#=========================================================================================

fig.save(path=figspath, name='fig_weights_' + modelname)
示例#17
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatches on `action`:
      - 'plot_trial': calls plot_trial() and saves a 'performance' figure
        built from config['savefile']. args[0], if it is an integer, sets
        the trials per condition used in the printed count (default 1000).
      - any other action containing 'trials': generate trials for every
        (juice, offer) condition and run them. args[0], if given, is the
        number of trials per condition (default 100).
      - 'choice_pattern': passes the behavior file, the offers,
        config['figspath'] and the action name to choice_pattern.
      - 'sort': sorted activity plots ('value' in args selects the value
        network).
      - 'statespace': state-space plots.
    Other actions only print the action and its arguments.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #=====================================================================================

    if action == 'plot_trial':
        # Missing or non-integer argument -> default. Catch only the errors
        # int(args[0]) can raise; a bare `except:` would also swallow
        # KeyboardInterrupt and SystemExit.
        try:
            trials_per_condition = int(args[0])
        except (IndexError, TypeError, ValueError):
            trials_per_condition = 1000
        model = config['model']
        # NOTE(review): pg, task, juices and offers are not used below. The
        # calls are kept in case get_pg()/Task() have needed side effects.
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()

        fig = Figure(axislabelsize=10, ticklabelsize=9)
        plot = fig.add()

        plot_trial()
        performance(config['savefile'], plot)

        fig.save(path=config['figspath'], name='performance')
        fig.close()

    #=====================================================================================

    elif 'trials' in action:
        try:
            trials_per_condition = int(args[0])
        except IndexError:
            trials_per_condition = 100

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        juices = spec.juices
        offers = spec.offers
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Cycle through every (juice, offer) combination
            k = tasktools.unravel_index(n, (len(juices), len(offers)))
            context = {'juice': juices[k.pop(0)], 'offer': offers[k.pop(0)]}
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'choice_pattern':
        trialsfile = runtools.behaviorfile(config['trialspath'])
        savefile = config['figspath']
        spec = config['model'].spec

        # This choice_pattern variant receives the figure directory and the
        # action name rather than a plot object.
        choice_pattern(trialsfile, spec.offers, savefile, action)

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)

    #=====================================================================================

    elif action == 'statespace':
        trialsfile = runtools.activityfile(config['trialspath'])
        statespace(trialsfile, (config['figspath'], 'statespace'))
示例#18
0
# Stimulus-aligned panel: `plot` is bound earlier in the script (presumably
# fig['on-stimulus']).
plot.xlim(-200, 600)
plot.xticks(list(range(-200, 601, 200)))
plot.ylim(0.5, 1.2)
plot.xlabel('Time from stimulus (ms)')
plot.ylabel('Expected reward')

# Legend: small font, compact handles
props = dict(prop=dict(size=8),
             handlelength=1.2,
             handletextpad=1.1,
             labelspacing=0.7)
plot.legend(bbox_to_anchor=(0.33, 1), **props)

# Decision-aligned panel
plot = fig['on-choice']
plot.xlim(-400, 0)
plot.xticks(list(range(-400, 1, 200)))
plot.ylim(0.5, 1.2)
plot.xlabel('Time from decision (ms)')

#=========================================================================================

fig.save()
示例#19
0
文件: romo.py 项目: sumwor/pylearning
def do(action, args, config):
    """
    Manage tasks.

    Dispatches on ``action``:

    * any action containing ``'trials'`` -- build ``trials_per_condition``
      trials for every (gt_lt, fpair) condition of ``config['model'].spec``
      and pass them to ``runtools.run``.
    * ``'performance'`` -- plot performance from the behavior file and save
      the figure under ``config['figspath']``.
    * ``'sort'`` -- plot sorted activity; uses network ``'v'`` when the string
      ``'value'`` is in ``args``, otherwise network ``'p'``.

    Any other action only prints its arguments.

    Parameters
    ----------
    action : str
        Name of the action to run.
    args : sequence
        Extra arguments. For trial actions, ``args[0]`` is the number of trials
        per condition; 100 is used when it is missing or not an integer.
    config : dict
        Run configuration. Depending on the action it must provide 'model',
        'savefile', 'seed', 'dt', 'trialspath' and/or 'figspath'.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    if 'trials' in action:
        # Fall back to the default only when the count is missing or not an
        # integer; a bare ``except`` would also hide KeyboardInterrupt,
        # SystemExit and genuine bugs.
        try:
            trials_per_condition = int(args[0])
        except (IndexError, TypeError, ValueError):
            trials_per_condition = 100

        model = config['model']
        pg = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec = model.spec
        gt_lts = spec.gt_lts
        fpairs = spec.fpairs
        n_conditions = spec.n_conditions
        n_trials = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Indices are popped in axis order: first gt_lt, then fpair.
            k = tasktools.unravel_index(n, (len(gt_lts), len(fpairs)))
            context = {
                'delay': 3000,
                'gt_lt': gt_lts[k.pop(0)],
                'fpair': fpairs[k.pop(0)]
            }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'performance':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig = Figure()
        plot = fig.add()

        performance(trialsfile, plot)

        plot.xlabel('$f_1$ (Hz)')
        plot.ylabel('$f_2$ (Hz)')

        fig.save(os.path.join(config['figspath'], action))
        # Release the figure once it has been written to disk.
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        sort(trialsfile, (config['figspath'], 'sorted'), network=network)
示例#20
0
def do(action, args, config):
    """
    Manage tasks.

    Dispatches on ``action``:

    * any action containing ``'trials'`` -- build ``trials_per_condition``
      trials for every (wager, left_right, coh) condition of
      ``config['model'].spec`` and pass them to ``runtools.run``.
    * ``'sure_stimulus_duration'``, ``'correct_stimulus_duration'``,
      ``'value_stimulus_duration'`` -- plot the corresponding analysis against
      stimulus duration and save it under ``config['figspath']`` named after
      the action.
    * ``'sort'`` -- plot sorted activity; uses network ``'v'`` when the string
      ``'value'`` is in ``args``, otherwise network ``'p'``.

    Any other action only prints its arguments.

    Parameters
    ----------
    action : str
        Name of the action to run.
    args : sequence
        Extra arguments. For trial actions, ``args[0]`` is the number of trials
        per condition; 100 is used when it is missing or not an integer.
    config : dict
        Run configuration. Depending on the action it must provide 'model',
        'savefile', 'seed', 'dt', 'trialspath' and/or 'figspath'.

    """
    print("ACTION*:   " + str(action))
    print("ARGS*:     " + str(args))

    #=====================================================================================

    if 'trials' in action:
        # Fall back to the default only when the count is missing or not an
        # integer; a bare ``except`` would also hide KeyboardInterrupt,
        # SystemExit and genuine bugs.
        try:
            trials_per_condition = int(args[0])
        except (IndexError, TypeError, ValueError):
            trials_per_condition = 100

        model = config['model']
        pg    = model.get_pg(config['savefile'], config['seed'], config['dt'])

        spec         = model.spec
        wagers       = spec.wagers
        left_rights  = spec.left_rights
        cohs         = spec.cohs
        n_conditions = spec.n_conditions
        n_trials     = trials_per_condition * n_conditions

        print("{} trials".format(n_trials))
        task   = model.Task()
        trials = []
        for n in xrange(n_trials):
            # Indices are popped in axis order: wager, left_right, coh.
            k = tasktools.unravel_index(n, (len(wagers), len(left_rights), len(cohs)))
            context = {
                'wager':      wagers[k.pop(0)],
                'left_right': left_rights[k.pop(0)],
                'coh':        cohs[k.pop(0)]
                }
            trials.append(task.get_condition(pg.rng, pg.dt, context))
        runtools.run(action, trials, pg, config['trialspath'])

    #=====================================================================================

    elif action == 'sure_stimulus_duration':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig  = Figure()
        plot = fig.add()

        sure_stimulus_duration(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel('Probability sure target')

        fig.save(path=config['figspath'], name=action)
        # Release the figure once it has been written to disk.
        fig.close()

    #=====================================================================================

    elif action == 'correct_stimulus_duration':
        trialsfile = runtools.behaviorfile(config['trialspath'])

        fig  = Figure()
        plot = fig.add()

        correct_stimulus_duration(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel('Probability correct')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'value_stimulus_duration':
        # Unlike the two actions above, this analysis reads the activity file.
        trialsfile = runtools.activityfile(config['trialspath'])

        fig  = Figure()
        plot = fig.add()

        value_stimulus_duration(trialsfile, plot)

        plot.xlabel('Stimulus duration (ms)')
        plot.ylabel('Expected reward')

        fig.save(path=config['figspath'], name=action)
        fig.close()

    #=====================================================================================

    elif action == 'sort':
        if 'value' in args:
            network = 'v'
        else:
            network = 'p'

        trialsfile = runtools.activityfile(config['trialspath'])
        # NOTE(review): the other task scripts pass ``(figspath, 'sorted')`` as a
        # tuple here; confirm this module's ``sort`` really expects one joined path.
        sort(trialsfile, os.path.join(config['figspath'], 'sorted'), network=network)
示例#21
0
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by (modality, choice) and plot condition-averaged firing rates.

    Only trials that have a decision, are correct and use a single modality
    ('v' or 'a'; 'va' trials are skipped) contribute. Every trial is aligned to
    the start of its 'stimulus' epoch before averaging.

    Parameters
    ----------
    trialsfile : str
        File read with ``utils.load``; its contents are unpacked as
        ``trials, U, Z, Z_b, A, P, M, perf, r_p, r_v``.
    plots : sequence or (figspath, name) tuple
        When ``units`` is given, the plots to draw into (paired with ``units``).
        Otherwise a ``(figspath, name)`` pair: one figure per unit is saved with
        name ``name + '_{network}{unit:03d}'``.
    units : sequence of int, optional
        Units to plot. If None, every unit is plotted into its own figure.
    network : str
        ``'p'`` selects ``r_p``; any other value selects ``r_v``.
    **kwargs
        ``lw`` (line width, default 1.5), ``dashes`` (dash pattern for the
        high-choice curves, default [3, 2]), ``lw_vline``, ``dashes_vline``.

    """
    # Load trials
    data = utils.load(trialsfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    if network == 'p':
        r = r_p
    else:
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time  = trials[0]['time']
    Ntime = len(time)

    # Aligned time: the time axis mirrored around index Ntime-1, which is where
    # the alignment point of every trial is placed below.
    time_a  = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)

    #=====================================================================================
    # Aligned to stimulus onset
    #=====================================================================================

    r_by_cond_stimulus   = {}
    n_r_by_cond_stimulus = {}
    for n, trial in enumerate(trials):
        if not perf.decisions[n]:
            continue

        if trial['mod'] == 'va':
            continue
        assert trial['mod'] == 'v' or trial['mod'] == 'a'

        if not perf.corrects[n]:
            continue

        # Condition
        mod    = trial['mod']
        choice = perf.choices[n]
        cond   = (mod, choice)

        # Storage: running sums of rates and of mask counts, per (time, unit)
        r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))
        n_r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))

        # Firing rates. M is applied as a per-time-step weight to both the
        # rates and the counts, so masked samples drop out of the average.
        Mn = np.tile(M[:,n], (N,1)).T
        Rn = r[:,n]*Mn

        # Align point
        t0 = trial['epochs']['stimulus'][0] - 1

        # Before: samples preceding t0 end just before index Ntime-1
        n_b = Rn[:t0].shape[0]
        r_by_cond_stimulus[cond][Ntime-1-n_b:Ntime-1]   += Rn[:t0]
        n_r_by_cond_stimulus[cond][Ntime-1-n_b:Ntime-1] += Mn[:t0]

        # After: samples from t0 onwards start at index Ntime-1
        n_a = Rn[t0:].shape[0]
        r_by_cond_stimulus[cond][Ntime-1:Ntime-1+n_a]   += Rn[t0:]
        n_r_by_cond_stimulus[cond][Ntime-1:Ntime-1+n_a] += Mn[t0:]

    # Average. NOTE(review): presumably utils.div guards against zero counts
    # where no trial covers a time point -- confirm.
    for cond in r_by_cond_stimulus:
        r_by_cond_stimulus[cond] = utils.div(r_by_cond_stimulus[cond],
                                             n_r_by_cond_stimulus[cond])

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    lw     = kwargs.get('lw', 1.5)
    dashes = kwargs.get('dashes', [3, 2])

    # NOTE(review): vline_props is built from lw_vline/dashes_vline but is never
    # passed to plot.vline below; confirm whether plot.vline(0, **vline_props)
    # was intended.
    vline_props = {'lw': kwargs.get('lw_vline', 0.5)}
    if 'dashes_vline' in kwargs:
        vline_props['linestyle'] = '--'
        vline_props['dashes']    = dashes

    colors_by_mod = {
        'v': Figure.colors('blue'),
        'a': Figure.colors('green')
        }
    linestyle_by_choice = {
        'L': '-',
        'H': '--'
        }

    def plot_sorted(plot, unit, w, r_sorted):
        """Plot each (mod, choice) curve of `unit` over window `w`; return (t, yall)."""
        t = time_a[w]
        # Seeded with [1], presumably so the y-limits computed from yall always
        # include 1.
        yall = [[1]]
        for cond in [('v', 'H'), ('v', 'L'), ('a', 'H'), ('a', 'L')]:
            # A condition with no qualifying trials has no entry; skip it rather
            # than failing the whole figure with a KeyError.
            if cond not in r_sorted:
                continue

            mod, choice = cond

            if mod == 'v':
                label = 'Vis, '
            elif mod == 'a':
                label = 'Aud, '
            else:
                raise ValueError(mod)

            if choice == 'H':
                label += 'high'
            elif choice == 'L':
                label += 'low'
            else:
                raise ValueError(choice)

            linestyle = linestyle_by_choice[choice]
            if linestyle == '-':
                lineprops = dict(linestyle=linestyle, lw=lw)
            else:
                lineprops = dict(linestyle=linestyle, lw=lw, dashes=dashes)
            plot.plot(t, r_sorted[cond][w,unit],
                      color=colors_by_mod[mod],
                      label=label,
                      **lineprops)
            yall.append(r_sorted[cond][w,unit])

        return t, yall

    def on_stimulus(plot, unit):
        """Plot `unit` from 300 ms before to 1000 ms after stimulus onset."""
        w, = np.where((time_a >= -300) & (time_a <= 1000))
        t, yall = plot_sorted(plot, unit, w, r_by_cond_stimulus)

        plot.xlim(t[0], t[-1])

        return yall

    if units is not None:
        for plot, unit in zip(plots, units):
            on_stimulus(plot, unit)
    else:
        figspath, name = plots
        for unit in xrange(N):
            fig  = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------

            yall = []
            yall += on_stimulus(plot, unit)

            plot.lim('y', yall, lower=0)
            plot.vline(0)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            #-----------------------------------------------------------------------------

            fig.save(path=figspath, name=name+'_{}{:03d}'.format(network, unit))
            fig.close()