Beispiel #1
0
    def get_wager(func_t0):
        """
        Average aligned firing rates over wager trials, split by choice.

        Trials are skipped when ``trial['wager']`` is falsy, when no choice was
        recorded, or when ``trial['coh'] == 0``. Every remaining trial is
        aligned at the index returned by
        ``func_t0(trial['epochs'], perf.t_choices[n])`` and accumulated under
        its ``trial['left_right']`` condition.

        Returns ``(trials_by_cond, trials_by_cond_sure)``: the second dict
        collects trials whose choice is ``'S'``, the first all other trials.
        Each maps a condition to ``utils.div(summed rates, summed mask)``.

        """
        by_cond      = {}
        by_cond_sure = {}

        # Row of the aligned arrays at which the align point of every trial lands
        center = Ntime - 1

        for idx, trial in enumerate(trials):
            if (not trial['wager']
                    or perf.choices[idx] is None
                    or trial['coh'] == 0):
                continue

            cond = trial['left_right']

            # Mask column repeated for every unit, and the masked firing rates
            mask  = np.tile(M[:,idx], (N, 1)).T
            rates = r[:,idx]*mask

            t0 = func_t0(trial['epochs'], perf.t_choices[idx])

            # Pick the table this trial contributes to
            table = by_cond_sure if perf.choices[idx] == 'S' else by_cond
            if cond not in table:
                table[cond] = {'r': np.zeros((Ntime_a, N)),
                               'n': np.zeros((Ntime_a, N))}
            sums = table[cond]

            # Samples before the align point end just below the center row
            n_before = rates[:t0].shape[0]
            sums['r'][center-n_before:center] += rates[:t0]
            sums['n'][center-n_before:center] += mask[:t0]

            # Samples from the align point onwards start at the center row
            n_after = rates[t0:].shape[0]
            sums['r'][center:center+n_after] += rates[t0:]
            sums['n'][center:center+n_after] += mask[t0:]

        # Replace the running sums by their ratio, condition by condition
        for table in (by_cond, by_cond_sure):
            for cond in table:
                table[cond] = utils.div(table[cond]['r'], table[cond]['n'])

        return by_cond, by_cond_sure
Beispiel #2
0
    def get_wager(func_t0):
        """
        Average aligned firing rates over wager trials, split by choice.

        Trials are skipped when ``trial['wager']`` is falsy, when no choice was
        recorded, or when ``trial['coh'] == 0``. Every remaining trial is
        aligned at the index returned by
        ``func_t0(trial['epochs'], perf.t_choices[n])`` and accumulated under
        its ``trial['left_right']`` condition.

        Returns ``(trials_by_cond, trials_by_cond_sure)``: the second dict
        collects trials whose choice is ``'S'``, the first all other trials.
        Each maps a condition to ``utils.div(summed rates, summed mask)``.

        """
        by_cond      = {}
        by_cond_sure = {}

        # Row of the aligned arrays at which the align point of every trial lands
        center = Ntime - 1

        for idx, trial in enumerate(trials):
            if (not trial['wager']
                    or perf.choices[idx] is None
                    or trial['coh'] == 0):
                continue

            cond = trial['left_right']

            # Mask column repeated for every unit, and the masked firing rates
            mask  = np.tile(M[:,idx], (N, 1)).T
            rates = r[:,idx]*mask

            t0 = func_t0(trial['epochs'], perf.t_choices[idx])

            # Pick the table this trial contributes to
            table = by_cond_sure if perf.choices[idx] == 'S' else by_cond
            if cond not in table:
                table[cond] = {'r': np.zeros((Ntime_a, N)),
                               'n': np.zeros((Ntime_a, N))}
            sums = table[cond]

            # Samples before the align point end just below the center row
            n_before = rates[:t0].shape[0]
            sums['r'][center-n_before:center] += rates[:t0]
            sums['n'][center-n_before:center] += mask[:t0]

            # Samples from the align point onwards start at the center row
            n_after = rates[t0:].shape[0]
            sums['r'][center:center+n_after] += rates[t0:]
            sums['n'][center:center+n_after] += mask[t0:]

        # Replace the running sums by their ratio, condition by condition
        for table in (by_cond, by_cond_sure):
            for cond in table:
                table[cond] = utils.div(table[cond]['r'], table[cond]['n'])

        return by_cond, by_cond_sure
Beispiel #3
0
def compute_dprime(trials, perf, r):
    """
    Compute d' for choice.

    For every trial that has a choice, the rows of ``r`` selected by
    ``trial['epochs']['stimulus']`` are pooled into a "left" group when
    ``trial['left_right'] < 0`` and into a "right" group otherwise. Per-unit
    means and variances are derived from the pooled sums and sums of squares.

    Returns ``-utils.div(mean_L - mean_R, sqrt((var_L + var_R)/2))``, one
    value per unit (last axis of ``r``).

    """
    n_units = r.shape[-1]

    # Running totals per side: sum, sum of squares, number of pooled time points
    totals  = {'L': np.zeros(n_units), 'R': np.zeros(n_units)}
    squares = {'L': np.zeros(n_units), 'R': np.zeros(n_units)}
    counts  = {'L': 0, 'R': 0}

    for idx, trial in enumerate(trials):
        if perf.choices[idx] is None:
            continue

        rates = r[trial['epochs']['stimulus'], idx]
        side  = 'L' if trial['left_right'] < 0 else 'R'

        totals[side]  += np.sum(rates,    axis=0)
        squares[side] += np.sum(rates**2, axis=0)
        counts[side]  += rates.shape[0]

    def moments(side):
        # Variance via E[x^2] - E[x]^2 from the pooled sums
        mean = totals[side]/counts[side]
        var  = squares[side]/counts[side] - mean**2
        return mean, var

    mean_L, var_L = moments('L')
    mean_R, var_R = moments('R')

    return -utils.div(mean_L - mean_R, np.sqrt((var_L + var_R)/2))
Beispiel #4
0
def compute_dprime(trials, perf, r):
    """
    Compute d' for choice.

    For every trial that has a choice, the rows of ``r`` selected by
    ``trial['epochs']['stimulus']`` are pooled into a "left" group when
    ``trial['left_right'] < 0`` and into a "right" group otherwise. Per-unit
    means and variances are derived from the pooled sums and sums of squares.

    Returns ``-utils.div(mean_L - mean_R, sqrt((var_L + var_R)/2))``, one
    value per unit (last axis of ``r``).

    """
    n_units = r.shape[-1]

    # Running totals per side: sum, sum of squares, number of pooled time points
    totals  = {'L': np.zeros(n_units), 'R': np.zeros(n_units)}
    squares = {'L': np.zeros(n_units), 'R': np.zeros(n_units)}
    counts  = {'L': 0, 'R': 0}

    for idx, trial in enumerate(trials):
        if perf.choices[idx] is None:
            continue

        rates = r[trial['epochs']['stimulus'], idx]
        side  = 'L' if trial['left_right'] < 0 else 'R'

        totals[side]  += np.sum(rates,    axis=0)
        squares[side] += np.sum(rates**2, axis=0)
        counts[side]  += rates.shape[0]

    def moments(side):
        # Variance via E[x^2] - E[x]^2 from the pooled sums
        mean = totals[side]/counts[side]
        var  = squares[side]/counts[side] - mean**2
        return mean, var

    mean_L, var_L = moments('L')
    mean_R, var_R = moments('R')

    return -utils.div(mean_L - mean_R, np.sqrt((var_L + var_R)/2))
Beispiel #5
0
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort trials by (modality, choice) and plot stimulus-aligned firing rates.

    Only trials with a decision, a modality of 'v' or 'a' (trials with 'va'
    are skipped), and a correct response contribute. Each trial is aligned one
    step before the first index of its stimulus epoch, and the masked firing
    rates are averaged per condition with `utils.div`.

    Parameters
    ----------
    trialsfile
        Passed to `utils.load`; the loaded object must unpack into
        (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    plots
        If `units` is given, an iterable of plot objects paired with `units`.
        Otherwise a `(figspath, name)` pair; one figure per unit is saved.
    units
        Optional iterable of unit indices to draw into `plots`.
    network
        'p' selects `r_p`; any other value selects `r_v`.
    **kwargs
        'lw' (line width, default 1.5) and 'dashes' (dash pattern used for the
        dashed traces, default [3, 2]). Other keys are accepted and ignored.

    """
    # Load trials
    data = utils.load(trialsfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    r = r_p if network == 'p' else r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored negative times followed by the original axis, so
    # that index Ntime - 1 corresponds to the align point
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)
    center = Ntime - 1

    #=====================================================================================
    # Aligned to stimulus onset
    #=====================================================================================

    r_by_cond_stimulus = {}
    n_r_by_cond_stimulus = {}
    for n, trial in enumerate(trials):
        if not perf.decisions[n]:
            continue

        if trial['mod'] == 'va':
            continue
        assert trial['mod'] == 'v' or trial['mod'] == 'a'

        if not perf.corrects[n]:
            continue

        # Condition
        cond = (trial['mod'], perf.choices[n])

        # Storage
        r_sum = r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))
        n_sum = n_r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))

        # Firing rates, masked per time step
        Mn = np.tile(M[:, n], (N, 1)).T
        Rn = r[:, n] * Mn

        # Align point
        t0 = trial['epochs']['stimulus'][0] - 1

        # Before
        n_b = Rn[:t0].shape[0]
        r_sum[center - n_b:center] += Rn[:t0]
        n_sum[center - n_b:center] += Mn[:t0]

        # After
        n_a = Rn[t0:].shape[0]
        r_sum[center:center + n_a] += Rn[t0:]
        n_sum[center:center + n_a] += Mn[t0:]

    for cond in r_by_cond_stimulus:
        r_by_cond_stimulus[cond] = utils.div(r_by_cond_stimulus[cond],
                                             n_r_by_cond_stimulus[cond])

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    lw = kwargs.get('lw', 1.5)
    dashes = kwargs.get('dashes', [3, 2])

    colors_by_mod = {'v': Figure.colors('blue'), 'a': Figure.colors('green')}
    linestyle_by_choice = {'L': '-', 'H': '--'}
    label_by_mod = {'v': 'Vis, ', 'a': 'Aud, '}
    label_by_choice = {'H': 'high', 'L': 'low'}

    def plot_sorted(plot, unit, w, r_sorted):
        # Draw one trace per condition; returns the time axis and all y-values
        t = time_a[w]
        yall = [[1]]
        for cond in [('v', 'H'), ('v', 'L'), ('a', 'H'), ('a', 'L')]:
            mod, choice = cond

            linestyle = linestyle_by_choice[choice]
            lineprops = dict(linestyle=linestyle, lw=lw)
            if linestyle != '-':
                # A dash pattern is only passed for non-solid lines
                lineprops['dashes'] = dashes

            trace = r_sorted[cond][w, unit]
            plot.plot(t,
                      trace,
                      color=colors_by_mod[mod],
                      label=label_by_mod[mod] + label_by_choice[choice],
                      **lineprops)
            yall.append(trace)

        return t, yall

    def on_stimulus(plot, unit):
        w, = np.where((time_a >= -300) & (time_a <= 1000))
        t, yall = plot_sorted(plot, unit, w, r_by_cond_stimulus)

        plot.xlim(t[0], t[-1])

        return yall

    if units is not None:
        for plot, unit in zip(plots, units):
            on_stimulus(plot, unit)
        return

    figspath, name = plots
    # range() rather than xrange(): identical here and also valid on Python 3
    for unit in range(N):
        fig = Figure()
        plot = fig.add()

        #---------------------------------------------------------------------------------

        yall = on_stimulus(plot, unit)

        plot.lim('y', yall, lower=0)
        plot.vline(0)

        plot.xlabel('Time (ms)')
        plot.ylabel('Firing rate (a.u.)')

        #---------------------------------------------------------------------------------

        fig.save(path=figspath,
                 name=name + '_{}{:03d}'.format(network, unit))
        fig.close()
Beispiel #6
0
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Sort correct trials by (f1, f2) condition and plot aligned firing rates.

    Trials without a choice or with an incorrect response are skipped. The
    condition is built from `trial['fpair']`: taken in order when
    `trial['gt_lt']` is '>', swapped otherwise. Each trial is aligned one step
    before the first index of its 'f1' epoch, and the masked firing rates are
    averaged per condition with `utils.div`.

    Parameters
    ----------
    trialsfile
        Passed to `utils.load`; the loaded object must unpack into either
        9 items (trials, U, Z, A, P, M, perf, r_p, r_v) or
        10 items (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    plots
        If `units` is given, an iterable of plot objects paired with `units`.
        Otherwise a `(figspath, name)` pair; one figure per unit is saved.
    units
        Optional iterable of unit indices to draw into `plots`.
    network
        'p' selects `r_p`; any other value selects `r_v`.
    **kwargs
        'lw' (line width, default 1.5).

    """
    # Load trials
    data = utils.load(trialsfile)
    if len(data) == 9:
        trials, U, Z, A, P, M, perf, r_p, r_v = data
    else:
        trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    r = r_p if network == 'p' else r_v

    # Data shape
    Ntime = r.shape[0]
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']

    # Aligned time: mirrored negative times followed by the original axis
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)
    center = Ntime - 1

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    trials_by_cond = {}
    for n, trial in enumerate(trials):
        if perf.choices[n] is None or not perf.corrects[n]:
            continue

        # Condition
        if trial['gt_lt'] == '>':
            f1, f2 = trial['fpair']
        else:
            f2, f1 = trial['fpair']
        cond = (f1, f2)

        # Firing rates, masked per time step
        Mn = np.tile(M[:, n], (N, 1)).T
        Rn = r[:, n] * Mn

        # Align point
        t0 = trial['epochs']['f1'][0] - 1

        # Storage
        sums = trials_by_cond.setdefault(cond, {
            'r': np.zeros((Ntime_a, N)),
            'n': np.zeros((Ntime_a, N))
        })

        # Before
        n_b = Rn[:t0].shape[0]
        sums['r'][center - n_b:center] += Rn[:t0]
        sums['n'][center - n_b:center] += Mn[:t0]

        # After
        n_a = Rn[t0:].shape[0]
        sums['r'][center:center + n_a] += Rn[t0:]
        sums['n'][center:center + n_a] += Mn[t0:]

    # Average
    for cond in trials_by_cond:
        trials_by_cond[cond] = utils.div(trials_by_cond[cond]['r'],
                                         trials_by_cond[cond]['n'])

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)

    w, = np.where((time_a >= -500) & (time_a <= 4000))

    def plot_sorted(plot, unit):
        # Time axis is rescaled by 1e-3 for display
        t = 1e-3 * time_a[w]
        yall = [[1]]
        # `rates` (not `r`) so the outer firing-rate array name is not shadowed
        for (f1, _), rates in trials_by_cond.items():
            trace = rates[w, unit]
            plot.plot(t, trace, color=smap.to_rgba(f1), lw=lw)
            yall.append(trace)

        return t, yall

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_sorted(plot, unit)
        return

    figspath, name = plots
    # range() rather than xrange(): identical here and also valid on Python 3
    for unit in range(N):
        fig = Figure()
        plot = fig.add()

        #---------------------------------------------------------------------------------

        t, yall = plot_sorted(plot, unit)

        plot.xlim(t[0], t[-1])
        plot.lim('y', yall, lower=0)

        plot.highlight(0, 0.5)
        plot.highlight(3.5, 4)

        #---------------------------------------------------------------------------------

        fig.save(path=figspath,
                 name=name + '_{}{:03d}'.format(network, unit))
        fig.close()
Beispiel #7
0
def sort_epoch(behaviorfile,
               activityfile,
               epoch,
               offers,
               plots,
               units=None,
               network='p',
               separate_by_choice=False,
               **kwargs):
    """
    Sort trials by offer and plot each unit's epoch-averaged activity.

    Trials with a choice are grouped by `trial['offer']` (or by
    `(offer, choice)` when `separate_by_choice` is true), aligned both to the
    first index of the 'offer-on' epoch and to `perf.t_choices[n]`, and
    averaged with `utils.div`. The aligned averages are then reduced to one
    value per unit in four windows of the aligned time axis:
    'preoffer' [-500, 0) and 'postoffer' [0, 500) and 'latedelay' [500, 1000)
    around the offer, and 'prechoice' [-500, 0) around the choice.

    Parameters
    ----------
    behaviorfile
        Passed with `offers` to `indifference_point`.
    activityfile
        Passed to `utils.load`; the loaded object must unpack into
        (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    epoch
        Window drawn on top of the pre-offer baseline; one of 'postoffer',
        'latedelay', 'prechoice'. 'preoffer' raises ValueError and any other
        value raises KeyError, both from inside the plotting step.
    offers
        Sequence of offers defining the x-axis; each is formatted with
        '{}B:{}A'.format(*offer).
    plots
        If `units` is given, an iterable of plot objects paired with `units`.
        Otherwise a file-name prefix; one figure per unit is saved.
    units
        Optional iterable of unit indices to draw into `plots`.
    network
        'p' selects `r_p`; any other value selects `r_v`.
    separate_by_choice
        Draw separate series for choices 'A' and 'B', using only conditions
        with at least 20 trials.
    **kwargs
        'lw' (default 1.5), 'ms' (default 6), 'mew' (default 0.5) and
        'rotation' (x tick label rotation, default 60).

    """
    # Load trials
    data = utils.load(activityfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("POLICY NETWORK")
        r = r_p
    else:
        print("VALUE NETWORK")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored negative times followed by the original axis, so
    # that index Ntime - 1 corresponds to the align point
    time_a = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)
    center = Ntime - 1

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    # Alignment events
    events = ['offer', 'choice']

    events_by_cond = {e: {} for e in events}
    n_by_cond = {}
    n_nondecision = 0
    for n, trial in enumerate(trials):
        if perf.choices[n] is None:
            n_nondecision += 1
            continue

        # Condition
        offer = trial['offer']
        choice = perf.choices[n]
        cond = (offer, choice) if separate_by_choice else offer

        n_by_cond[cond] = n_by_cond.get(cond, 0) + 1

        # Firing rates, masked per time step
        m_n = np.tile(M[:, n], (N, 1)).T
        r_n = r[:, n] * m_n

        # Align point for each event
        align_points = {
            'offer': trial['epochs']['offer-on'][0],
            'choice': perf.t_choices[n],
        }

        for e in events:
            t0 = align_points[e]

            # Storage
            sums = events_by_cond[e].setdefault(cond, {
                'r': np.zeros((Ntime_a, N)),
                'n': np.zeros((Ntime_a, N))
            })

            # Before
            n_b = r_n[:t0].shape[0]
            sums['r'][center - n_b:center] += r_n[:t0]
            sums['n'][center - n_b:center] += m_n[:t0]

            # After
            n_a = r_n[t0:].shape[0]
            sums['r'][center:center + n_a] += r_n[t0:]
            sums['n'][center:center + n_a] += m_n[t0:]
    print("Non-decision trials: {}/{}".format(n_nondecision, len(trials)))

    # Average trials
    for by_cond in events_by_cond.values():
        for cond in by_cond:
            by_cond[cond] = utils.div(by_cond[cond]['r'], by_cond[cond]['n'])

    # Average epochs: (name, alignment event, window start, window end)
    epoch_windows = [
        ('preoffer',  'offer',  -500,    0),
        ('postoffer', 'offer',     0,  500),
        ('latedelay', 'offer',   500, 1000),
        ('prechoice', 'choice', -500,    0),
    ]
    epochs_by_cond = {}
    for e, ev, t_start, t_stop in epoch_windows:
        w, = np.where((t_start <= time_a) & (time_a < t_stop))
        epochs_by_cond[e] = {}
        for cond in events_by_cond[ev]:
            epochs_by_cond[e][cond] = np.mean(events_by_cond[ev][cond][w],
                                              axis=0)

    #=====================================================================================
    # Classify units
    #=====================================================================================

    idpt = indifference_point(behaviorfile, offers)
    unit_types = classify_units(trials, perf, r, idpt)

    # Number of units per type
    numbers = {}
    for v in unit_types.values():
        numbers[v] = numbers.get(v, 0) + 1

    # Built-in sum(): np.sum() does not add up a dict view on Python 3
    n_tot = sum(numbers.values())
    for k, v in numbers.items():
        print("{}: {}/{} = {}%".format(k, v, n_tot, 100 * v / n_tot))

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw = kwargs.get('lw', 1.5)
    ms = kwargs.get('ms', 6)
    mew = kwargs.get('mew', 0.5)
    rotation = kwargs.get('rotation', 60)

    def plot_activity(plot, unit):
        yall = [1]

        min_trials = 20
        last = len(offers) - 1

        def draw_series(epoch_by_cond, choice, marker, color, ms_, mew_, lw_):
            # One marker per offer plus a line through the interior offers.
            # `choice=None` means conditions are keyed by offer alone.
            x = []
            y = []
            for i, offer in enumerate(offers):
                if choice is None:
                    cond = offer
                else:
                    cond = (offer, choice)
                    # Skip sparsely sampled (offer, choice) pairs
                    if n_by_cond.get(cond, 0) < min_trials:
                        continue

                y_i = epoch_by_cond[cond][unit]
                plot.plot(i,
                          y_i,
                          marker,
                          mfc=color,
                          mec=color,
                          ms=ms_,
                          mew=mew_,
                          zorder=10)
                yall.append(y_i)
                # The first and last offers are left out of the connecting line
                if i != 0 and i != last:
                    x.append(i)
                    y.append(y_i)
            plot.plot(x, y, '-', color=color, lw=lw_, zorder=5)

        # Pre-offer baseline, drawn smaller and in gray
        gray = '0.7'
        if separate_by_choice:
            baseline_series = [('A', 'd', gray), ('B', 'o', gray)]
        else:
            baseline_series = [(None, 'o', gray)]
        for choice, marker, color in baseline_series:
            draw_series(epochs_by_cond['preoffer'], choice, marker, color,
                        0.8 * ms, 0.8 * mew, 0.8 * lw)

        # Requested epoch
        epoch_by_cond = epochs_by_cond[epoch]
        if epoch not in ('postoffer', 'latedelay', 'prechoice'):
            raise ValueError(epoch)
        default_color = Figure.colors('darkblue')
        if separate_by_choice:
            epoch_series = [('A', 'd', Figure.colors('red')),
                            ('B', 'o', Figure.colors('blue'))]
        else:
            epoch_series = [(None, 'o', default_color)]
        for choice, marker, color in epoch_series:
            draw_series(epoch_by_cond, choice, marker, color, ms, mew, lw)

        plot.xticks(range(len(offers)))
        plot.xticklabels(['{}B:{}A'.format(*offer) for offer in offers],
                         rotation=rotation)

        plot.xlim(0, len(offers) - 1)
        plot.lim('y', yall, lower=0)

        return yall

    #-------------------------------------------------------------------------------------

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_activity(plot, unit)
        return

    name = plots
    suffix = '_sbc' if separate_by_choice else ''
    # range() rather than xrange(): identical here and also valid on Python 3
    for unit in range(N):
        fig = Figure()
        plot = fig.add()

        plot_activity(plot, unit)

        if unit in unit_types:
            plot.text_upper_right(unit_types[unit], fontsize=9)

        fig.save(name +
                 '_{}{}_{}{:03d}'.format(epoch, suffix, network, unit))
        fig.close()
Beispiel #8
0
def sort_epoch(behaviorfile, activityfile, epoch, offers, plots, units=None, network='p',
               separate_by_choice=False, **kwargs):
    """
    Sort trials by offer and plot each unit's epoch-averaged activity.

    Trials with a choice are grouped by `trial['offer']` (or by
    `(offer, choice)` when `separate_by_choice` is true), aligned both to the
    first index of the 'offer-on' epoch and to `perf.t_choices[n]`, and
    averaged with `utils.div`. The aligned averages are then reduced to one
    value per unit in four windows of the aligned time axis:
    'preoffer' [-500, 0) and 'postoffer' [0, 500) and 'latedelay' [500, 1000)
    around the offer, and 'prechoice' [-500, 0) around the choice.

    Parameters
    ----------
    behaviorfile
        Passed with `offers` to `indifference_point`.
    activityfile
        Passed to `utils.load`; the loaded object must unpack into
        (trials, U, Z, Z_b, A, P, M, perf, r_p, r_v).
    epoch
        Window drawn on top of the pre-offer baseline; one of 'postoffer',
        'latedelay', 'prechoice'. 'preoffer' raises ValueError and any other
        value raises KeyError, both from inside the plotting step.
    offers
        Sequence of offers defining the x-axis; each is formatted with
        '{}B:{}A'.format(*offer).
    plots
        If `units` is given, an iterable of plot objects paired with `units`.
        Otherwise a file-name prefix; one figure per unit is saved.
    units
        Optional iterable of unit indices to draw into `plots`.
    network
        'p' selects `r_p`; any other value selects `r_v`.
    separate_by_choice
        Draw separate series for choices 'A' and 'B', using only conditions
        with at least 20 trials.
    **kwargs
        'lw' (default 1.5), 'ms' (default 6), 'mew' (default 0.5) and
        'rotation' (x tick label rotation, default 60).

    """
    # Load trials
    data = utils.load(activityfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    if network == 'p':
        print("POLICY NETWORK")
        r = r_p
    else:
        print("VALUE NETWORK")
        r = r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time  = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored negative times followed by the original axis, so
    # that index Ntime - 1 corresponds to the align point
    time_a  = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)
    center  = Ntime - 1

    #=====================================================================================
    # Sort trials
    #=====================================================================================

    # Alignment events
    events = ['offer', 'choice']

    events_by_cond = {e: {} for e in events}
    n_by_cond      = {}
    n_nondecision  = 0
    for n, trial in enumerate(trials):
        if perf.choices[n] is None:
            n_nondecision += 1
            continue

        # Condition
        offer  = trial['offer']
        choice = perf.choices[n]
        cond   = (offer, choice) if separate_by_choice else offer

        n_by_cond[cond] = n_by_cond.get(cond, 0) + 1

        # Firing rates, masked per time step
        m_n = np.tile(M[:,n], (N,1)).T
        r_n = r[:,n]*m_n

        # Align point for each event
        align_points = {
            'offer':  trial['epochs']['offer-on'][0],
            'choice': perf.t_choices[n],
        }

        for e in events:
            t0 = align_points[e]

            # Storage
            sums = events_by_cond[e].setdefault(cond, {'r': np.zeros((Ntime_a, N)),
                                                       'n': np.zeros((Ntime_a, N))})

            # Before
            n_b = r_n[:t0].shape[0]
            sums['r'][center-n_b:center] += r_n[:t0]
            sums['n'][center-n_b:center] += m_n[:t0]

            # After
            n_a = r_n[t0:].shape[0]
            sums['r'][center:center+n_a] += r_n[t0:]
            sums['n'][center:center+n_a] += m_n[t0:]
    print("Non-decision trials: {}/{}".format(n_nondecision, len(trials)))

    # Average trials
    for by_cond in events_by_cond.values():
        for cond in by_cond:
            by_cond[cond] = utils.div(by_cond[cond]['r'], by_cond[cond]['n'])

    # Average epochs: (name, alignment event, window start, window end)
    epoch_windows = [
        ('preoffer',  'offer',  -500,    0),
        ('postoffer', 'offer',     0,  500),
        ('latedelay', 'offer',   500, 1000),
        ('prechoice', 'choice', -500,    0),
    ]
    epochs_by_cond = {}
    for e, ev, t_start, t_stop in epoch_windows:
        w, = np.where((t_start <= time_a) & (time_a < t_stop))
        epochs_by_cond[e] = {}
        for cond in events_by_cond[ev]:
            epochs_by_cond[e][cond] = np.mean(events_by_cond[ev][cond][w], axis=0)

    #=====================================================================================
    # Classify units
    #=====================================================================================

    idpt = indifference_point(behaviorfile, offers)
    unit_types = classify_units(trials, perf, r, idpt)

    # Number of units per type
    numbers = {}
    for v in unit_types.values():
        numbers[v] = numbers.get(v, 0) + 1

    # Built-in sum(): np.sum() does not add up a dict view on Python 3
    n_tot = sum(numbers.values())
    for k, v in numbers.items():
        print("{}: {}/{} = {}%".format(k, v, n_tot, 100*v/n_tot))

    #=====================================================================================
    # Plot
    #=====================================================================================

    lw  = kwargs.get('lw',  1.5)
    ms  = kwargs.get('ms',  6)
    mew = kwargs.get('mew', 0.5)
    rotation = kwargs.get('rotation', 60)

    def plot_activity(plot, unit):
        yall = [1]

        min_trials = 20
        last = len(offers)-1

        def draw_series(epoch_by_cond, choice, marker, color, ms_, mew_, lw_):
            # One marker per offer plus a line through the interior offers.
            # `choice=None` means conditions are keyed by offer alone.
            x = []
            y = []
            for i, offer in enumerate(offers):
                if choice is None:
                    cond = offer
                else:
                    cond = (offer, choice)
                    # Skip sparsely sampled (offer, choice) pairs
                    if n_by_cond.get(cond, 0) < min_trials:
                        continue

                y_i = epoch_by_cond[cond][unit]
                plot.plot(i, y_i, marker, mfc=color, mec=color, ms=ms_,
                          mew=mew_, zorder=10)
                yall.append(y_i)
                # The first and last offers are left out of the connecting line
                if i != 0 and i != last:
                    x.append(i)
                    y.append(y_i)
            plot.plot(x, y, '-', color=color, lw=lw_, zorder=5)

        # Pre-offer baseline, drawn smaller and in gray
        gray = '0.7'
        if separate_by_choice:
            baseline_series = [('A', 'd', gray), ('B', 'o', gray)]
        else:
            baseline_series = [(None, 'o', gray)]
        for choice, marker, color in baseline_series:
            draw_series(epochs_by_cond['preoffer'], choice, marker, color,
                        0.8*ms, 0.8*mew, 0.8*lw)

        # Requested epoch
        epoch_by_cond = epochs_by_cond[epoch]
        if epoch not in ('postoffer', 'latedelay', 'prechoice'):
            raise ValueError(epoch)
        default_color = Figure.colors('darkblue')
        if separate_by_choice:
            epoch_series = [('A', 'd', Figure.colors('red')),
                            ('B', 'o', Figure.colors('blue'))]
        else:
            epoch_series = [(None, 'o', default_color)]
        for choice, marker, color in epoch_series:
            draw_series(epoch_by_cond, choice, marker, color, ms, mew, lw)

        plot.xticks(range(len(offers)))
        plot.xticklabels(['{}B:{}A'.format(*offer) for offer in offers],
                         rotation=rotation)

        plot.xlim(0, len(offers)-1)
        plot.lim('y', yall, lower=0)

        return yall

    #-------------------------------------------------------------------------------------

    if units is not None:
        for plot, unit in zip(plots, units):
            plot_activity(plot, unit)
        return

    name   = plots
    suffix = '_sbc' if separate_by_choice else ''
    # range() rather than xrange(): identical here and also valid on Python 3
    for unit in range(N):
        fig  = Figure()
        plot = fig.add()

        plot_activity(plot, unit)

        if unit in unit_types:
            plot.text_upper_right(unit_types[unit], fontsize=9)

        fig.save(name+'_{}{}_{}{:03d}'.format(epoch, suffix, network, unit))
        fig.close()
Beispiel #9
0
def sort(trialsfile, plots, units=None, network='p', **kwargs):
    """
    Plot trial-averaged unit activity, aligned to stimulus onset and sorted by
    (modality, choice) condition.

    Only trials that have a decision, a single modality ('v' or 'a'; 'va' trials
    are skipped), and a correct choice contribute to the averages. Conditions for
    which no such trial exists are left out of the plot.

    Parameters
    ----------
    trialsfile
        Passed to `utils.load`; the loaded object must unpack into
        `(trials, U, Z, Z_b, A, P, M, perf, r_p, r_v)`.
    plots
        If `units` is given, an iterable of plot objects paired one-to-one with
        `units`. Otherwise a `(figspath, name)` pair used when saving one figure
        per unit.
    units
        Unit indices to draw into the supplied `plots`, or None to create and
        save a figure for every unit.
    network
        'p' selects `r_p`; any other value selects `r_v`.
    **kwargs
        lw (default 1.5)       : line width of the activity traces.
        dashes (default [3, 2]): dash pattern for the '--' traces.
        lw_vline, dashes_vline : collected into `vline_props` (see the note
                                 where it is built).

    """
    # Load trials
    data = utils.load(trialsfile)
    trials, U, Z, Z_b, A, P, M, perf, r_p, r_v = data

    # Which network?
    r = r_p if network == 'p' else r_v

    # Number of units
    N = r.shape[-1]

    # Same for every trial
    time  = trials[0]['time']
    Ntime = len(time)

    # Aligned time: mirrored negative times followed by the original time axis, so
    # that index Ntime-1 of the aligned axis is where time[0] sits.
    time_a  = np.concatenate((-time[1:][::-1], time))
    Ntime_a = len(time_a)
    center  = Ntime - 1

    #=====================================================================================
    # Aligned to stimulus onset
    #=====================================================================================

    r_by_cond_stimulus   = {}
    n_r_by_cond_stimulus = {}
    for n, trial in enumerate(trials):
        if not perf.decisions[n]:
            continue

        mod = trial['mod']
        if mod == 'va':
            continue
        assert mod == 'v' or mod == 'a'

        if not perf.corrects[n]:
            continue

        # Condition
        cond = (mod, perf.choices[n])

        # Storage: running sum of rates and running count of contributing samples
        r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))
        n_r_by_cond_stimulus.setdefault(cond, np.zeros((Ntime_a, N)))

        # Firing rates, weighted by M[:,n] replicated across the N units
        # (presumably a per-time-step validity mask -- TODO confirm)
        Mn = np.tile(M[:,n], (N,1)).T
        Rn = r[:,n]*Mn

        # Align point: one step before the first stimulus index
        t0 = trial['epochs']['stimulus'][0] - 1

        # Before
        n_b = Rn[:t0].shape[0]
        r_by_cond_stimulus[cond][center-n_b:center]   += Rn[:t0]
        n_r_by_cond_stimulus[cond][center-n_b:center] += Mn[:t0]

        # After
        n_a = Rn[t0:].shape[0]
        r_by_cond_stimulus[cond][center:center+n_a]   += Rn[t0:]
        n_r_by_cond_stimulus[cond][center:center+n_a] += Mn[t0:]

    # Average: summed rates divided by the accumulated weights
    for cond in r_by_cond_stimulus:
        r_by_cond_stimulus[cond] = utils.div(r_by_cond_stimulus[cond],
                                             n_r_by_cond_stimulus[cond])

    #-------------------------------------------------------------------------------------
    # Plot
    #-------------------------------------------------------------------------------------

    lw     = kwargs.get('lw', 1.5)
    dashes = kwargs.get('dashes', [3, 2])

    # NOTE(review): vline_props is assembled here but is not passed to the
    # plot.vline(0) call below, so 'lw_vline' and 'dashes_vline' currently have no
    # visible effect. Also, only the presence of 'dashes_vline' is checked; the dash
    # pattern used is `dashes`. Confirm the intent before wiring this through.
    vline_props = {'lw': kwargs.get('lw_vline', 0.5)}
    if 'dashes_vline' in kwargs:
        vline_props['linestyle'] = '--'
        vline_props['dashes']    = dashes

    colors_by_mod = {
        'v': Figure.colors('blue'),
        'a': Figure.colors('green')
        }
    linestyle_by_choice = {
        'L': '-',
        'H': '--'
        }
    label_by_mod = {
        'v': 'Vis, ',
        'a': 'Aud, '
        }
    label_by_choice = {
        'H': 'high',
        'L': 'low'
        }

    # Fixed drawing order of the conditions
    conds = [('v', 'H'), ('v', 'L'), ('a', 'H'), ('a', 'L')]

    def plot_sorted(plot, unit, w, r_sorted):
        """
        Draw one trace per available condition for `unit` over the aligned-time
        indices `w`; return the plotted times and the list of plotted y-values.

        """
        t = time_a[w]

        # Seeded with [1] so the y-limit computation always sees the value 1
        yall = [[1]]
        for cond in conds:
            # A condition with no qualifying trial was never accumulated; skip it
            # instead of failing with a KeyError.
            if cond not in r_sorted:
                continue
            mod, choice = cond

            linestyle = linestyle_by_choice[choice]
            lineprops = dict(linestyle=linestyle, lw=lw)
            if linestyle != '-':
                lineprops['dashes'] = dashes

            trace = r_sorted[cond][w,unit]
            plot.plot(t, trace,
                      color=colors_by_mod[mod],
                      label=label_by_mod[mod] + label_by_choice[choice],
                      **lineprops)
            yall.append(trace)

        return t, yall

    def on_stimulus(plot, unit):
        """
        Plot the stimulus-aligned traces of `unit` for aligned times in [-300, 1000].

        """
        w, = np.where((time_a >= -300) & (time_a <= 1000))
        t, yall = plot_sorted(plot, unit, w, r_by_cond_stimulus)

        plot.xlim(t[0], t[-1])

        return yall

    if units is not None:
        # Caller supplied the axes; only the traces and x-limits are set here
        for plot, unit in zip(plots, units):
            on_stimulus(plot, unit)
    else:
        figspath, name = plots
        for unit in xrange(N):
            fig  = Figure()
            plot = fig.add()

            #-----------------------------------------------------------------------------

            yall = on_stimulus(plot, unit)

            plot.lim('y', yall, lower=0)
            plot.vline(0)

            plot.xlabel('Time (ms)')
            plot.ylabel('Firing rate (a.u.)')

            #-----------------------------------------------------------------------------

            fig.save(path=figspath, name=name+'_{}{:03d}'.format(network, unit))
            fig.close()