Ejemplo n.º 1
0
def nan_invalid_segments(rec):
    """
    Currently a specialized signal for removing incorrect trials from data
    collected using baphy during behavior.

    Builds the list of "valid" epochs (intersection of HIT_TRIAL with
    REFERENCE, plus REFERENCE with PASSIVE_EXPERIMENT), merges epochs that
    directly abut each other, and passes the merged bounds to
    ``rec.nan_times``.

    Parameters
    ----------
    rec : nems.recording.Recording
        Recording with a ``resp`` signal. NOTE: ``rec['resp']`` is replaced
        in place by its rasterized version.

    Returns
    -------
    newrec : Recording
        Whatever ``rec.nan_times`` returns for the merged epoch bounds.

    TODO: Migrate to nems_db or make a more generic version
    """

    # First, select the appropriate subset of data
    rec['resp'] = rec['resp'].rasterize()
    sig = rec['resp']

    # get list of start and stop times (epoch bounds)
    epoch_indices = np.vstack(
        (ep.epoch_intersection(sig.get_epoch_bounds('HIT_TRIAL'),
                               sig.get_epoch_bounds('REFERENCE')),
         ep.epoch_intersection(sig.get_epoch_bounds('REFERENCE'),
                               sig.get_epoch_bounds('PASSIVE_EXPERIMENT'))))

    # Resolve overlapping epochs (the original author noted this may not be
    # strictly necessary).
    epoch_indices = ep.remove_overlap(epoch_indices)

    # Merge epochs that are directly adjacent (one ends exactly where the
    # next begins). Rows are copied so extending an epoch never writes back
    # into ``epoch_indices``; they are concatenated once at the end.
    merged = [epoch_indices[0:1, :].copy()]
    for i in range(1, epoch_indices.shape[0]):
        if epoch_indices[i, 0] == merged[-1][-1, 1]:
            # extend the previous epoch to the END of the current one
            merged[-1][-1, 1] = epoch_indices[i, 1]
        else:
            merged.append(epoch_indices[i:(i + 1), :].copy())
    epoch_indices2 = np.concatenate(merged, axis=0)

    # add adjusted signals to the recording
    newrec = rec.nan_times(epoch_indices2)

    return newrec
Ejemplo n.º 2
0
def select_times(recording, subset, random_only=True, dual_only=True):
    '''
    Restrict a recording to a subset of epochs, optionally dropping the
    repeating segments and/or keeping only the dual-stream segments.

    Parameters
    ----------
    recording : nems.recording.Recording
        The recording object. The ``repeating`` and ``dual`` epochs are read
        from its ``stim`` signal.
    subset : Nx2 array
        Epochs representing the selected subset (e.g., from an est/val split).
    random_only : bool
        If True, remove the repeating epochs from the subset, i.e. return
        only the random (non-repeating) portion of the subset.
    dual_only : bool
        If True, return only the dual stream portion of the subset

    Returns
    -------
    The result of ``recording.select_times`` on the filtered subset.
    '''
    epochs = recording['stim'].epochs
    names = epochs['name']

    if random_only:
        repeating_epochs = epochs.loc[names == 'repeating',
                                      ['start', 'end']].values
        subset = epoch_difference(subset, repeating_epochs)

    if dual_only:
        dual_epochs = epochs.loc[names == 'dual', ['start', 'end']].values
        subset = epoch_intersection(subset, dual_epochs)

    return recording.select_times(subset)
Ejemplo n.º 3
0
def test_intersection_bug():
    # Copied from real data: ``a`` and ``b`` differ only by floating-point
    # noise (~1e-13); the expected result reuses the bounds from ``a``.
    a = np.array([
        [57.7, 61.4],
        [61.4, 66.6],
        [66.6, 70.9],
        [70.9, 75.80000000000001],
        [75.8, 81.0],
    ])

    b = np.array([
        [57.7000000000001, 61.4000000000001],
        [61.4000000000001, 66.6000000000001],
        [66.6000000000001, 70.9],
        [75.8000000000001, 81.0000000000001],
        [81.0000000000001, 85.6000000000001],
    ])

    expected = np.array([
        [57.7, 61.4],
        [61.4, 66.6],
        [66.6, 70.9],
        [75.8, 81.0],
    ])

    result = epoch_intersection(a, b)
    # array_equal also compares shapes, so a missing or extra row fails as a
    # plain assertion instead of depending on elementwise broadcasting.
    assert np.array_equal(result, expected), result
Ejemplo n.º 4
0
def test_intersection(epoch_a, epoch_b):
    # epoch_a / epoch_b are presumably pytest fixtures defined elsewhere.
    expected = np.array([
        [60, 70],
        [75, 76],
        [90, 95],
    ])
    result = epoch_intersection(epoch_a, epoch_b)
    # array_equal checks shape as well as values (no broadcasting).
    assert np.array_equal(result, expected), result
Ejemplo n.º 5
0
def test_empty_intersection():
    """Intersecting two epoch sets that never overlap emits RuntimeWarning."""
    first = np.array([[0, 20], [50, 70]])
    second = np.array([[25, 30], [30, 45]])

    with pytest.warns(RuntimeWarning):
        print(epoch_intersection(first, second))
Ejemplo n.º 6
0
def test_empty_intersection():
    """Intersecting two epoch sets that never overlap must raise.

    The ``message=`` keyword of ``pytest.raises`` was deprecated in pytest
    4.0 and removed in 5.0 (passing it now raises TypeError). On failure
    pytest reports "DID NOT RAISE RuntimeWarning" instead.
    """
    a = np.array([[0, 20], [50, 70]])

    b = np.array([[25, 30], [30, 45]])

    # Expected RuntimeWarning for a size-0 intersection. This variant expects
    # it to be *raised*; use pytest.warns if epoch_intersection only warns.
    with pytest.raises(RuntimeWarning):
        epoch_intersection(a, b)
Ejemplo n.º 7
0
def remove_invalid_segments(rec):
    """
    Currently a specialized function for removing incorrect trials from data
    collected using baphy during behavior.

    Keeps REFERENCE epochs that intersect HIT_TRIAL or PASSIVE_EXPERIMENT,
    merges epochs that directly abut each other, and returns
    ``rec.select_times`` of the result.

    Parameters
    ----------
    rec : nems.recording.Recording
        Recording with a ``resp`` signal (``stim`` is optional). NOTE:
        ``rec['resp']`` and, if present, ``rec['stim']`` are replaced in
        place by their rasterized versions.

    Returns
    -------
    newrec : Recording
        ``rec.select_times`` applied to the merged epochs, converted from
        bin indices to seconds by dividing by ``resp.fs``.

    TODO: Migrate to nems_lbhb or make a more generic version
    """

    # First, select the appropriate subset of data
    rec['resp'] = rec['resp'].rasterize()
    if 'stim' in rec.signals.keys():
        rec['stim'] = rec['stim'].rasterize()

    sig = rec['resp']

    # get list of start and stop indices (epoch bounds)
    epoch_indices = np.vstack(
        (ep.epoch_intersection(sig.get_epoch_indices('REFERENCE'),
                               sig.get_epoch_indices('HIT_TRIAL')),
         ep.epoch_intersection(sig.get_epoch_indices('REFERENCE'),
                               sig.get_epoch_indices('PASSIVE_EXPERIMENT'))))

    # Resolve overlapping epochs (the original author noted this may not be
    # strictly necessary).
    epoch_indices = ep.remove_overlap(epoch_indices)

    # Merge any epochs that are directly adjacent. Rows are collected in a
    # list and concatenated once: np.append inside the loop copied the whole
    # array on every call (O(n^2)). Rows are copied so that extending an
    # epoch never writes back into ``epoch_indices``.
    merged = [epoch_indices[0:1].copy()]
    for i in range(1, epoch_indices.shape[0]):
        if epoch_indices[i, 0] == merged[-1][-1, 1]:
            merged[-1][-1, 1] = epoch_indices[i, 1]
        else:
            merged.append(epoch_indices[i:(i + 1)].copy())
    epoch_indices2 = np.concatenate(merged, axis=0)

    # convert back to times
    epoch_times = epoch_indices2 / sig.fs

    # add adjusted signals to the recording
    newrec = rec.select_times(epoch_times)

    return newrec
Ejemplo n.º 8
0
def make_state_signal(rec,
                      state_signals=['pupil'],
                      permute_signals=[],
                      new_signalname='state'):
    """
    generate state signal for stategain.S/sdexp.S models

    valid state signals include (incomplete list):
        pupil, pupil_ev, pupil_bs, pupil_psd
        active, each_file, each_passive, each_half
        far, hit, lick, p_x_a
    Also handled below: pas, puretone_trials, easy_trials, hard_trials,
    hit_trials, miss_trials, prw, pup_x_prw.

    Parameters
    ----------
    rec : Recording
        Must contain ``resp``. ``pupil`` is read for pupil-based states and
        ``mask`` is read when any state is listed in ``permute_signals``.
    state_signals : list of str
        Signals to stack into the state signal, in order. The meta-names
        ``each_passive``, ``each_file`` and ``each_half`` are expanded into
        one indicator signal per file (or per file half). The caller's list
        is not modified.
    permute_signals : list of str
        Subset of ``state_signals`` that are time-shuffled before stacking.
        The caller's list is not modified.
    new_signalname : str
        Name of the stacked state signal. Every state is additionally stacked
        (unshuffled) into ``new_signalname + "_raw"``.

    Returns
    -------
    newrec : Recording
        Copy of ``rec`` with the requested signals and stacked state added.

    TODO: Migrate to nems_lbhb or make a more generic version
    """
    # Work on private copies. The meta-names are expanded in place below;
    # doing that on the arguments themselves rewrote the caller's lists, so
    # the same lists could not be reused for another call.
    state_signals = list(state_signals)
    permute_signals = list(permute_signals)

    newrec = rec.copy()
    resp = newrec['resp'].rasterize()

    # z-score the pupil trace (ignoring NaNs) if any pupil state is used
    if ('pupil' in state_signals) or ('pupil_ev' in state_signals) or \
       ('pupil_bs' in state_signals):
        p = newrec["pupil"].as_continuous().copy()
        p -= np.nanmean(p)
        p /= np.nanstd(p)
        newrec["pupil"] = newrec["pupil"]._modified_copy(p)

    if 'pupil_psd' in state_signals:
        pup = newrec['pupil'].as_continuous().copy()
        fs = newrec['pupil'].fs
        # spectrogram of pupil: 60 s windows advanced one bin at a time
        nperseg = int(60 * fs)
        noverlap = nperseg - 1
        _, _, Sxx = ss.spectrogram(pup.squeeze(),
                                   fs=fs,
                                   nperseg=nperseg,
                                   noverlap=noverlap)
        # keep only the lowest max_chan frequency bins
        max_chan = 4
        # pad with the first/last spectral column; intended to give (for an
        # even nperseg) as many columns as the pupil trace has bins
        pad1 = np.ones((max_chan, int(nperseg / 2))) * Sxx[:max_chan, [0]]
        pad2 = np.ones((max_chan, int(nperseg / 2 - 1))) * Sxx[:max_chan, [-1]]
        newspec = np.concatenate((pad1, Sxx[:max_chan, :], pad2), axis=1)

        # z-score each frequency bin independently
        newspec -= np.nanmean(newspec, axis=1, keepdims=True)
        newspec /= np.nanstd(newspec, axis=1, keepdims=True)

        spec_signal = newrec['pupil']._modified_copy(newspec)
        spec_signal.name = 'pupil_psd'
        spec_signal.chans = ['puppsd{0}'.format(chan)
                             for chan in range(newspec.shape[0])]

        newrec.add_signal(spec_signal)

    if ('pupil_ev' in state_signals) or ('pupil_bs' in state_signals):
        # split pupil into a per-trial baseline (mean of the first
        # spont_bins bins of each trial, spont_bins = PreStimSilence length)
        # and the evoked residual
        prestimsilence = newrec["pupil"].extract_epoch('PreStimSilence')
        spont_bins = prestimsilence.shape[2]
        pupil_trial = newrec["pupil"].extract_epoch('TRIAL')

        pupil_bs = np.zeros(pupil_trial.shape)
        for ii in range(pupil_trial.shape[0]):
            pupil_bs[ii, :, :] = np.mean(pupil_trial[ii, :, :spont_bins])
        pupil_ev = pupil_trial - pupil_bs

        newrec['pupil_ev'] = newrec["pupil"].replace_epoch('TRIAL', pupil_ev)
        newrec['pupil_ev'].chans = ['pupil_ev']
        newrec['pupil_bs'] = newrec["pupil"].replace_epoch('TRIAL', pupil_bs)
        newrec['pupil_bs'].chans = ['pupil_bs']

    if 'each_passive' in state_signals:
        # one indicator per passive FILE_ epoch, skipping the first passive
        # file (baseline)
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        found_passive1 = False
        for f in file_epochs:
            # test if passive expt
            epoch_indices = ep.epoch_intersection(
                resp.get_epoch_indices(f), passive_indices)
            if epoch_indices.size:
                if not found_passive1:
                    # skip first passive
                    found_passive1 = True
                else:
                    pset.append(f)
                    newrec[f] = resp.epoch_to_signal(f)
        state_signals.remove('each_passive')
        state_signals.extend(pset)
        if 'each_passive' in permute_signals:
            permute_signals.remove('each_passive')
            permute_signals.extend(pset)

    if 'each_file' in state_signals:
        # one indicator per file named PASSIVE_n / ACTIVE_n. PASSIVE_0 is the
        # baseline and gets no signal. Actives before the first passive are
        # numbered 0, -1, -2, ...; actives after it 1, 2, ...
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        pcount = 0
        acount = 0
        for f in file_epochs:
            # test if passive expt
            f_indices = resp.get_epoch_indices(f)
            epoch_indices = ep.epoch_intersection(f_indices, passive_indices)

            if epoch_indices.size:
                # this is a passive file
                name1 = "PASSIVE_{}".format(pcount)
                pcount += 1
                if pcount == 1:
                    acount = 1  # reset acount for actives after first passive
                else:
                    # use first passive part A as baseline - don't model
                    pset.append(name1)
                    newrec[name1] = resp.epoch_to_signal(name1,
                                                         indices=f_indices)
            else:
                name1 = "ACTIVE_{}".format(acount)
                pset.append(name1)
                newrec[name1] = resp.epoch_to_signal(name1, indices=f_indices)

                if pcount == 0:
                    acount -= 1
                else:
                    acount += 1

        state_signals.remove('each_file')
        state_signals.extend(pset)
        if 'each_file' in permute_signals:
            permute_signals.remove('each_file')
            permute_signals.extend(pset)

    if 'each_half' in state_signals:
        # like each_file, but each file is split at the midpoint between its
        # first trial start and last trial end into parts A and B
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        trial_indices = resp.get_epoch_indices('TRIAL')
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        pcount = 0
        acount = 0
        for f in file_epochs:
            # test if passive expt
            f_indices = resp.get_epoch_indices(f)
            epoch_indices = ep.epoch_intersection(f_indices, passive_indices)
            trial_intersect = ep.epoch_intersection(f_indices, trial_indices)
            _t1 = trial_intersect[0, 0]
            _t2 = trial_intersect[-1, 1]
            _split = int((_t1 + _t2) / 2)
            epoch1 = np.array([[_t1, _split]])
            epoch2 = np.array([[_split, _t2]])

            if epoch_indices.size:
                # this is a passive file
                name1 = "PASSIVE_{}_{}".format(pcount, 'A')
                name2 = "PASSIVE_{}_{}".format(pcount, 'B')
                pcount += 1
                if pcount == 1:
                    acount = 1  # reset acount for actives after first passive
                else:
                    # don't model PASSIVE_0 A -- baseline
                    pset.append(name1)
                    newrec[name1] = resp.epoch_to_signal(name1, indices=epoch1)

                # do include part B
                pset.append(name2)
                newrec[name2] = resp.epoch_to_signal(name2, indices=epoch2)
            else:
                name1 = "ACTIVE_{}_{}".format(acount, 'A')
                name2 = "ACTIVE_{}_{}".format(acount, 'B')
                pset.append(name1)
                newrec[name1] = resp.epoch_to_signal(name1, indices=epoch1)
                pset.append(name2)
                newrec[name2] = resp.epoch_to_signal(name2, indices=epoch2)

                if pcount == 0:
                    acount -= 1
                else:
                    acount += 1

        state_signals.remove('each_half')
        state_signals.extend(pset)
        if 'each_half' in permute_signals:
            permute_signals.remove('each_half')
            permute_signals.extend(pset)

    # generate task state signals
    if 'pas' in state_signals:
        fpre = (resp.epochs['name'] == "PRE_PASSIVE")
        fpost = (resp.epochs['name'] == "POST_PASSIVE")
        INCLUDE_PRE_POST = (np.sum(fpre) > 0) & (np.sum(fpost) > 0)
        if INCLUDE_PRE_POST:
            # only include pre-passive if post-passive also exists
            # otherwise the regression gets screwed up
            newrec['pre_passive'] = resp.epoch_to_signal('PRE_PASSIVE')
        else:
            # place-holder, all zeros
            newrec['pre_passive'] = resp.epoch_to_signal('XXX')
            newrec['pre_passive'].chans = ['PRE_PASSIVE']
    if 'puretone_trials' in state_signals:
        newrec['puretone_trials'] = resp.epoch_to_signal('PURETONE_BEHAVIOR')
        newrec['puretone_trials'].chans = ['puretone_trials']
    if 'easy_trials' in state_signals:
        newrec['easy_trials'] = resp.epoch_to_signal('EASY_BEHAVIOR')
        newrec['easy_trials'].chans = ['easy_trials']
    if 'hard_trials' in state_signals:
        newrec['hard_trials'] = resp.epoch_to_signal('HARD_BEHAVIOR')
        newrec['hard_trials'].chans = ['hard_trials']
    if ('active' in state_signals) or ('far' in state_signals):
        newrec['active'] = resp.epoch_to_signal('ACTIVE_EXPERIMENT')
        newrec['active'].chans = ['active']
    if (('hit_trials' in state_signals) or ('miss_trials' in state_signals)
            or ('far' in state_signals) or ('hit' in state_signals)):
        newrec['hit_trials'] = resp.epoch_to_signal('HIT_TRIAL')
        newrec['miss_trials'] = resp.epoch_to_signal('MISS_TRIAL')
        newrec['fa_trials'] = resp.epoch_to_signal('FA_TRIAL')

    # 180 s smoothing window in bins; int() because it is used as an array
    # shape and fs may be stored as a float
    sm_len = int(180 * newrec['resp'].fs)
    if 'far' in state_signals:
        a = newrec['active'].as_continuous()
        fa = newrec['fa_trials'].as_continuous().astype(float)
        c = np.ones((1, sm_len)) / sm_len

        fa = convolve2d(fa, c, mode='same')
        fa[a] -= 0.25  # np.nanmean(fa[a])
        fa[np.logical_not(a)] = 0

        s = newrec['fa_trials']._modified_copy(fa)
        s.chans = ['far']
        s.name = 'far'
        newrec.add_signal(s)

    if 'hit' in state_signals:
        a = newrec['active'].as_continuous()
        hr = newrec['hit_trials'].as_continuous().astype(float)
        ms = newrec['miss_trials'].as_continuous().astype(float)
        ht = hr - ms

        c = np.ones((1, sm_len)) / sm_len

        ht = convolve2d(ht, c, mode='same')
        ht[a] -= 0.1  # np.nanmean(ht[a])
        ht[np.logical_not(a)] = 0

        s = newrec['hit_trials']._modified_copy(ht)
        s.chans = ['hit']
        s.name = 'hit'
        newrec.add_signal(s)

    if 'lick' in state_signals:
        newrec['lick'] = resp.epoch_to_signal('LICK')

    # pupil interactions
    if 'p_x_a' in state_signals:
        # product of (z-scored) pupil and the active-experiment indicator
        p = newrec["pupil"].as_continuous()
        a = newrec["active"].as_continuous()
        newrec["p_x_a"] = newrec["pupil"]._modified_copy(p * a)
        newrec["p_x_a"].chans = ["p_x_a"]

    if 'prw' in state_signals:
        # add channel two of the resp to state and delete it from resp
        if len(rec['resp'].chans) != 2:
            raise ValueError("this is for pairwise fitting")
        ch1 = rec['resp'].chans[0]
        ch2 = rec['resp'].chans[1]

        newrec['prw'] = newrec['resp'].extract_channels([ch2]).rasterize()
        newrec['resp'] = newrec['resp'].extract_channels([ch1]).rasterize()

    if 'pup_x_prw' in state_signals:
        # interaction term between pupil and the other cell
        if 'prw' not in newrec.signals.keys():
            raise ValueError("Must include prw alone before using interaction")

        pup = newrec['pupil']._data
        prw = newrec['prw']._data
        sig = newrec['pupil']._modified_copy(pup * prw)
        sig.name = 'pup_x_prw'
        sig.chans = ['pup_x_prw']
        newrec.add_signal(sig)

    for i, x in enumerate(state_signals):
        if x in permute_signals:
            # kludge: fix random seed to index of state signal in list
            # this avoids using the same seed for each shuffled signal
            # but also makes shuffling reproducible
            newrec = concatenate_state_channel(
                newrec,
                newrec[x].shuffle_time(rand_seed=i, mask=newrec['mask']),
                state_signal_name=new_signalname)
        else:
            newrec = concatenate_state_channel(
                newrec, newrec[x], state_signal_name=new_signalname)

        # unshuffled copy of every state, for reference
        newrec = concatenate_state_channel(newrec,
                                           newrec[x],
                                           state_signal_name=new_signalname +
                                           "_raw")

    return newrec
Ejemplo n.º 9
0
def mask_all_but_correct_references(rec,
                                    balance_rep_count=False,
                                    include_incorrect=False):
    """
    Specialized function for removing incorrect trials from data
    collected using baphy during behavior.

    Parameters
    ----------
    rec : Recording
        Must contain ``resp``; ``stim`` is rasterized too if present.
    balance_rep_count : bool
        If True, build the mask from an equal number of PASSIVE_EXPERIMENT
        and HIT_TRIAL occurrences of each ``STIM_*`` epoch. The larger group
        is subsampled at (rounded) evenly spaced positions.
    include_incorrect : bool
        If True (and ``balance_rep_count`` is False), mask to REFERENCE
        epochs without restricting to passive / hit trials.

    Returns
    -------
    newrec : Recording
        Masked copy of ``rec``. If a ``state`` signal exists, its behavioral
        channels (``b_states`` below) are z-scored in ``state`` and
        ``state_raw``. Only bins inside the mask AND ACTIVE_EXPERIMENT are
        modified, using the mean/std of those same bins.

    TODO: Migrate to nems_lbhb and/or make a more generic version
    """

    newrec = rec.copy()
    newrec['resp'] = newrec['resp'].rasterize()
    if 'stim' in newrec.signals.keys():
        newrec['stim'] = newrec['stim'].rasterize()
    resp = newrec['resp']

    if balance_rep_count:

        epoch_regex = "^STIM_"
        epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch_regex)
        p = resp.get_epoch_indices("PASSIVE_EXPERIMENT")
        a = resp.get_epoch_indices("HIT_TRIAL")

        epoch_list = []
        for s in epochs_to_extract:
            e = resp.get_epoch_indices(s)
            pe = ep.epoch_intersection(e, p)
            ae = ep.epoch_intersection(e, a)
            if len(pe) > len(ae):
                # keep every active occurrence, subsample passive to match
                epoch_list.extend(ae)
                subset = np.round(np.linspace(0, len(pe),
                                              len(ae) + 1)).astype(int)
                for i in subset[:-1]:
                    epoch_list.append(pe[i])
            else:
                # keep every passive occurrence, subsample active to match
                subset = np.round(np.linspace(0, len(ae),
                                              len(pe) + 1)).astype(int)
                for i in subset[:-1]:
                    epoch_list.append(ae[i])
                epoch_list.extend(pe)

        newrec = newrec.create_mask(epoch_list)

    elif include_incorrect:
        log.info('INCLUDING ALL TRIALS (CORRECT AND INCORRECT)')
        newrec = newrec.and_mask(['REFERENCE'])

    else:
        newrec = newrec.and_mask(['PASSIVE_EXPERIMENT', 'HIT_TRIAL'])
        newrec = newrec.and_mask(['REFERENCE'])

    if 'state' in newrec.signals:
        b_states = [
            'far', 'hit', 'lick', 'puretone_trials', 'easy_trials',
            'hard_trials'
        ]
        trec = newrec.copy()
        trec = trec.and_mask(['ACTIVE_EXPERIMENT'])
        st = trec['state'].as_continuous().copy()
        st_raw = trec['state_raw'].as_continuous().copy()
        mask = trec['mask'].as_continuous()[0, :]
        for i, s in enumerate(trec['state'].chans):
            if s not in b_states:
                continue
            m = np.nanmean(st[i, mask])
            sd = np.nanstd(st[i, mask])
            st[i, mask] -= m
            st_raw[i, mask] -= m
            # A channel that is constant within the active mask (e.g.
            # puretone_trials in an all-puretone session) has sd == 0;
            # dividing would turn it into NaN, so only center it.
            if sd > 0:
                st[i, mask] /= sd
                st_raw[i, mask] /= sd
        newrec['state'] = newrec['state']._modified_copy(st)
        newrec['state_raw'] = newrec['state_raw']._modified_copy(st_raw)

    return newrec
Ejemplo n.º 10
0
def ev_pupil(cellid, batch, presilence=0.35):
    """
    Collect trial-aligned pupil traces for one cell/batch, grouped by
    behavioral block type and by trial outcome.

    NOTE(review): the returned names say "lickrate", but every value is
    computed from the ``pupil`` signal.

    Parameters
    ----------
    cellid, batch :
        Passed to ``nw.generate_recording_uri`` to locate the recording.
    presilence : float
        Duration (s) at the start of each trial used as the per-trial
        baseline for the baseline-subtracted ("_norm") averages.

    Returns
    -------
    beh_lickrate, beh_lickrate_norm : list of (label, ndarray)
        Per-block mean trace (raw / baseline-subtracted) in time order; label
        is one of 'passive', 'puretone', 'easy', 'hard'.
    beh_all : dict
        label -> flattened trial samples pooled over all blocks of that label.
    perf_lickrate, perf_lickrate_norm : list of (label, ndarray)
        Same as above for 'hits', 'misses' and 'fas' trials (an all-NaN trace
        if a category has no trials).
    perf_all : dict
        outcome label -> flattened extracted samples.
    """
    modelname = "psth.fs20.pup-st.pup.beh_stategain.3_init.st-basic"

    print('Finding recording for cell/batch {0}/{1}...'.format(cellid, batch))

    # only the loader keyword (first "_"-separated token) is needed here
    loader = modelname.split("_")[0]

    recording_uri = nw.generate_recording_uri(cellid, batch, loader)
    print(recording_uri)

    rec = recording.load_recording(recording_uri)

    pupil = rec['pupil']
    # Scale by the max and circularly shift 0.75 s later in time. np.int was
    # removed in NumPy 1.24; the builtin int is the exact equivalent.
    pshift = int(pupil.fs * 0.75)
    d = pupil._data
    pupil = pupil._modified_copy(np.roll(d / np.nanmax(d), (0, pshift)))

    trials = pupil.get_epoch_indices('TRIAL')
    n_bins = pupil.shape[1]
    trialbins = int(pupil.fs * 6)
    prebins = int(pupil.fs * presilence)

    def _fixed_length(block_trials):
        # Make every trial span trialbins bins from its onset, clipped to the
        # end of the recording (modifies and returns block_trials).
        if block_trials.shape[0]:
            block_trials[:, 1] = np.minimum(block_trials[:, 0] + trialbins,
                                            n_bins)
        return block_trials

    # [start, end, label] for every behavior/passive block, in time order
    blocks = []
    for epoch_name, block_label in (('PASSIVE_EXPERIMENT', 'passive'),
                                    ('PURETONE_BEHAVIOR', 'puretone'),
                                    ('EASY_BEHAVIOR', 'easy'),
                                    ('HARD_BEHAVIOR', 'hard')):
        for bounds in pupil.get_epoch_indices(epoch_name).tolist():
            blocks.append(bounds + [block_label])
    blocks.sort()

    beh_lickrate = []
    beh_lickrate_norm = []
    beh_all = {}
    for start, end, k in blocks:
        block_trials = _fixed_length(
            ep.epoch_intersection(trials, np.array([[start, end]])))
        tev = pupil.extract_epoch(block_trials)[:, 0, :]
        tev_norm = tev - np.mean(tev[:, :prebins], axis=1, keepdims=True)

        beh_all[k] = np.append(beh_all.get(k, np.array([])), tev.ravel())
        beh_lickrate.append((k, np.nanmean(tev, axis=0)))
        beh_lickrate_norm.append((k, np.nanmean(tev_norm, axis=0)))

    perf_blocks = {
        'hits': pupil.get_epoch_indices('HIT_TRIAL'),
        'misses': pupil.get_epoch_indices('MISS_TRIAL'),
        'fas': pupil.get_epoch_indices('FA_TRIAL')
    }

    perf_lickrate = []
    perf_lickrate_norm = []
    perf_all = {}
    for k, block in perf_blocks.items():
        block_trials = _fixed_length(ep.epoch_intersection(trials, block))
        extracted = pupil.extract_epoch(block_trials, allow_empty=True)
        if extracted.size:
            tev = extracted[:, 0, :]
        else:
            # no trials of this outcome: a single all-NaN placeholder trial
            tev = np.full((1, trialbins), np.nan)
        perf_all[k] = extracted.ravel()

        tev_norm = tev - np.mean(tev[:, :prebins], axis=1, keepdims=True)
        perf_lickrate.append((k, np.nanmean(tev, axis=0)))
        perf_lickrate_norm.append((k, np.nanmean(tev_norm, axis=0)))

    return (beh_lickrate, beh_lickrate_norm, beh_all,
            perf_lickrate, perf_lickrate_norm, perf_all)
Ejemplo n.º 11
0
def make_state_signal(rec,
                      state_signals=['pupil'],
                      permute_signals=[],
                      new_signalname='state'):
    """
    generate state signal for stategainX models

    Parameters
    ----------
    rec : Recording
        Must contain ``resp``; ``pupil`` is read for pupil-based states.
    state_signals : list of str
        Signals to stack into the state signal, in order. ``each_passive``
        is expanded into one indicator per passive FILE_ epoch. The caller's
        list is not modified.
    permute_signals : list of str
        Subset of ``state_signals`` that are time-shuffled before stacking.
        The caller's list is not modified.
    new_signalname : str
        Name of the stacked state signal.

    Returns
    -------
    newrec : Recording
        Copy of ``rec`` with the task signals (pre_passive, hit/miss/fa,
        puretone/easy/hard, active), any requested pupil signals and the
        stacked state added.

    TODO: Migrate to nems_lbhb or make a more generic version
    """
    # Work on private copies: 'each_passive' is expanded in place below and
    # previously rewrote the caller's lists.
    state_signals = list(state_signals)
    permute_signals = list(permute_signals)

    newrec = rec.copy()
    resp = newrec['resp'].rasterize()

    # z-score the pupil trace (ignoring NaNs) if any pupil state is used
    if ('pupil' in state_signals) or ('pupil_ev' in state_signals) or \
       ('pupil_bs' in state_signals):
        p = newrec["pupil"].as_continuous().copy()
        p -= np.nanmean(p)
        p /= np.nanstd(p)
        newrec["pupil"] = newrec["pupil"]._modified_copy(p)

    if 'pupil_psd' in state_signals:
        pup = newrec['pupil'].as_continuous().copy()
        fs = newrec['pupil'].fs
        # spectrogram of pupil: 60 s windows advanced one bin at a time
        nperseg = int(60 * fs)
        noverlap = nperseg - 1
        _, _, Sxx = ss.spectrogram(pup.squeeze(),
                                   fs=fs,
                                   nperseg=nperseg,
                                   noverlap=noverlap)
        # keep only the lowest max_chan frequency bins
        max_chan = 4
        # pad with the first/last spectral column; intended to give (for an
        # even nperseg) as many columns as the pupil trace has bins
        pad1 = np.ones((max_chan, int(nperseg / 2))) * Sxx[:max_chan, [0]]
        pad2 = np.ones((max_chan, int(nperseg / 2 - 1))) * Sxx[:max_chan, [-1]]
        newspec = np.concatenate((pad1, Sxx[:max_chan, :], pad2), axis=1)

        # z-score each frequency bin independently
        newspec -= np.nanmean(newspec, axis=1, keepdims=True)
        newspec /= np.nanstd(newspec, axis=1, keepdims=True)

        spec_signal = newrec['pupil']._modified_copy(newspec)
        spec_signal.name = 'pupil_psd'
        spec_signal.chans = ['puppsd{0}'.format(chan)
                             for chan in range(newspec.shape[0])]

        newrec.add_signal(spec_signal)

    if ('pupil_ev' in state_signals) or ('pupil_bs' in state_signals):
        # split pupil into a per-trial baseline (mean of the first
        # spont_bins bins of each trial, spont_bins = PreStimSilence length)
        # and the evoked residual
        prestimsilence = newrec["pupil"].extract_epoch('PreStimSilence')
        spont_bins = prestimsilence.shape[2]
        pupil_trial = newrec["pupil"].extract_epoch('TRIAL')

        pupil_bs = np.zeros(pupil_trial.shape)
        for ii in range(pupil_trial.shape[0]):
            pupil_bs[ii, :, :] = np.mean(pupil_trial[ii, :, :spont_bins])
        pupil_ev = pupil_trial - pupil_bs

        newrec['pupil_ev'] = newrec["pupil"].replace_epoch('TRIAL', pupil_ev)
        newrec['pupil_ev'].chans = ['pupil_ev']
        newrec['pupil_bs'] = newrec["pupil"].replace_epoch('TRIAL', pupil_bs)
        newrec['pupil_bs'].chans = ['pupil_bs']

    if 'each_passive' in state_signals:
        # one indicator signal per FILE_ epoch that overlaps a passive block
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        for f in file_epochs:
            epoch_indices = ep.epoch_intersection(
                resp.get_epoch_indices(f), passive_indices)
            if epoch_indices.size:
                pset.append(f)
                newrec[f] = resp.epoch_to_signal(f)
        state_signals.remove('each_passive')
        state_signals.extend(pset)
        if 'each_passive' in permute_signals:
            permute_signals.remove('each_passive')
            permute_signals.extend(pset)

    # generate task state signals
    fpre = (resp.epochs['name'] == "PRE_PASSIVE")
    fpost = (resp.epochs['name'] == "POST_PASSIVE")
    INCLUDE_PRE_POST = (np.sum(fpre) > 0) & (np.sum(fpost) > 0)
    if INCLUDE_PRE_POST:
        # only include pre-passive if post-passive also exists
        # otherwise the regression gets screwed up
        newrec['pre_passive'] = resp.epoch_to_signal('PRE_PASSIVE')
    else:
        # place-holder, all zeros
        newrec['pre_passive'] = resp.epoch_to_signal('XXX')
        newrec['pre_passive'].chans = ['PRE_PASSIVE']

    newrec['hit_trials'] = resp.epoch_to_signal('HIT_TRIAL')
    newrec['miss_trials'] = resp.epoch_to_signal('MISS_TRIAL')
    newrec['fa_trials'] = resp.epoch_to_signal('FA_TRIAL')
    newrec['puretone_trials'] = resp.epoch_to_signal('PURETONE_BEHAVIOR')
    newrec['puretone_trials'].chans = ['puretone_trials']
    newrec['easy_trials'] = resp.epoch_to_signal('EASY_BEHAVIOR')
    newrec['easy_trials'].chans = ['easy_trials']
    newrec['hard_trials'] = resp.epoch_to_signal('HARD_BEHAVIOR')
    newrec['hard_trials'].chans = ['hard_trials']
    newrec['active'] = resp.epoch_to_signal('ACTIVE_EXPERIMENT')
    newrec['active'].chans = ['active']
    if 'lick' in state_signals:
        newrec['lick'] = resp.epoch_to_signal('LICK')

    # pupil interactions
    if 'pupil' in state_signals:
        # product of (z-scored) pupil and the active-experiment indicator
        p = newrec["pupil"].as_continuous()
        a = newrec["active"].as_continuous()
        newrec["p_x_a"] = newrec["pupil"]._modified_copy(p * a)
        newrec["p_x_a"].chans = ["p_x_a"]

    for i, x in enumerate(state_signals):
        if x in permute_signals:
            # kludge: fix random seed to index of state signal in list
            # this avoids using the same seed for each shuffled signal
            # but also makes shuffling reproducible
            newrec = concatenate_state_channel(
                newrec,
                newrec[x].shuffle_time(rand_seed=i),
                state_signal_name=new_signalname)
        else:
            newrec = concatenate_state_channel(
                newrec, newrec[x], state_signal_name=new_signalname)

    return newrec
Ejemplo n.º 12
0
def _model_step_plot(cellid, batch, modelnames, factors, state_colors=None):
    """
    Plot a stepwise comparison of four state-dependent model fits for one
    cell and return summary statistics.

    Each model is loaded with ``xhelp.load_model_xform(..., eval_model=False)``
    and evaluated with ``xforms.evaluate(..., start=0, stop=-2)``.

    Parameters
    ----------
    cellid, batch :
        Passed straight to ``xhelp.load_model_xform``.
    modelnames : sequence of 4 str
        Model names in the order (p0b0, p0b, pb0, pb). ``pb`` is the full
        model plotted against the response. ``p0b`` / ``pb0`` are the control
        models (per the naming, a "0" presumably marks a shuffled state
        factor -- confirm against the model definitions).
    factors : sequence of 3 str
        factors[0] labels the top panel. factors[1:] are state channel names
        to plot, one column each. An entry 'each_passive' is replaced by every
        state channel whose name starts with 'FILE_'. The caller's sequence
        is NOT modified; the expanded list is returned in ``stats['factors']``.
    state_colors : N x 2 list, optional
       color spec for high/low lines in each of the N states
       (N = len(factors) - 1 after 'each_passive' expansion). Defaults to
       ``None`` for every color.

    Returns
    -------
    fh : matplotlib Figure
    stats : dict
        'r_test', 'se_test', 'r_floor', 'g', 'b' (phi['d']): one entry per
        model in the order (p0b0, p0b, pb0, pb).
        'pred_mod': rows [mod(pb) - mod(p0b), mod(pb) - mod(pb0)].
        'pred_mod_full': rows [mod(pb0), mod(p0b)].
        Both have one column per state channel and are computed over
        REFERENCE epochs. The '*_norm' variants are divided by each
        channel's std (divisor 1 for zero-std channels, whose entries stay 0).
        'ref_all_resp', 'ref_common_resp', 'tar_max_resp', 'tar_probe_resp':
        mean evoked responses in PASSIVE_EXPERIMENT epochs minus the
        pre-stimulus (spontaneous) mean. Target entries are NaN when absent.
    """
    modelname_p0b0, modelname_p0b, modelname_pb0, modelname_pb = modelnames
    # exactly three factors are expected; only the first is used by name
    factor0, _, _ = factors
    # work on a copy so expanding 'each_passive' never mutates the caller's list
    factors = list(factors)

    def _load_ctx(modelname):
        """Load a saved fit and evaluate its xforms with stop=-2."""
        xf, ctx = xhelp.load_model_xform(cellid, batch, modelname,
                                         eval_model=False)
        ctx, _log = xforms.evaluate(xf, ctx, start=0, stop=-2)
        return ctx

    def _meta(ctx, key):
        """First element of ``key`` in the first modelspec's meta dict."""
        return ctx['modelspecs'][0][0]['meta'][key][0]

    def _phi(ctx, key):
        """Parameter ``key`` of the first module of the first modelspec."""
        return ctx['modelspecs'][0][0]['phi'][key]

    def _drop_legend(ax):
        """Remove the axes legend, if one was drawn."""
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    ctx_p0b0 = _load_ctx(modelname_p0b0)
    ctx_p0b = _load_ctx(modelname_p0b)
    ctx_pb0 = _load_ctx(modelname_pb0)
    ctx_pb = _load_ctx(modelname_pb)
    all_ctx = (ctx_p0b0, ctx_p0b, ctx_pb0, ctx_pb)

    # organize predictions by different models
    val = ctx_pb['val'][0].copy()
    val['pred_p0b'] = ctx_p0b['val'][0]['pred'].copy()
    val['pred_pb0'] = ctx_pb0['val'][0]['pred'].copy()

    state_var_list = val['state'].chans

    def _ref_mod(psth_name, chan):
        """state_mod_index of ``psth_name`` for ``chan`` over REFERENCE."""
        return state_mod_index(val, epoch='REFERENCE', psth_name=psth_name,
                               state_chan=chan)

    pred_mod = np.zeros([len(state_var_list), 2])
    pred_mod_full = np.zeros([len(state_var_list), 2])

    state_std = np.nanstd(val['state'].as_continuous(), axis=1, keepdims=True)
    for i, var in enumerate(state_var_list):
        if not state_std[i]:
            # constant state channel: leave its modulation row at zero
            continue
        mod_p0b = _ref_mod('pred_p0b', var)
        mod_pb0 = _ref_mod('pred_pb0', var)
        mod_pb = _ref_mod('pred', var)
        pred_mod[i] = np.array([mod_pb - mod_p0b, mod_pb - mod_pb0])
        pred_mod_full[i] = np.array([mod_pb0, mod_p0b])

    # normalize by channel std; zero-std channels use a divisor of 1
    divisor = state_std + (state_std == 0).astype(float)
    pred_mod_norm = pred_mod / divisor
    pred_mod_full_norm = pred_mod_full / divisor

    if 'each_passive' in factors:
        factors.remove('each_passive')
        file_vars = [v for v in state_var_list if v.startswith('FILE_')]
        factors.extend(file_vars)
        psth_names_ctl = ["pred_p0b"] + ["pred_pb0"] * len(file_vars)
    else:
        psth_names_ctl = ["pred_p0b", "pred_pb0"]

    col_count = len(factors) - 1
    if state_colors is None:
        # one independent [high, low] pair per plotted state
        state_colors = [[None, None] for _ in range(col_count)]

    fh = plt.figure(figsize=(8, 8))
    ax = plt.subplot(4, 1, 1)
    nplt.state_vars_timeseries(val,
                               ctx_pb['modelspecs'][0],
                               state_colors=[s[1] for s in state_colors])
    ax.set_title("{}/{} - {}".format(cellid, batch, modelname_pb))
    ax.set_ylabel("{} r={:.3f}".format(factor0, _meta(ctx_p0b0, 'r_test')))
    nplt.ax_remove_box(ax)

    legend_labels = {
        'active': ('pas', 'act'),
        'pupil': ('small', 'large'),
        'PRE_PASSIVE': ('act+post', 'pre'),
    }

    for i, var in enumerate(factors[1:]):
        is_file = var.startswith('FILE_')
        varlbl = var[5:] if is_file else var
        ctl_name = psth_names_ctl[i]

        # row 2: response vs. control-model prediction
        ax = plt.subplot(4, col_count, col_count + i + 1)
        nplt.state_var_psth_from_epoch(val,
                                       epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2=ctl_name,
                                       state_chan=var,
                                       ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Control model")
            ctl_ctx = ctx_p0b
        else:
            ax.yaxis.label.set_visible(False)
            ctl_ctx = ctx_pb0
        ax.set_title("{} ctl r={:.3f}".format(varlbl.lower(),
                                               _meta(ctl_ctx, 'r_test')),
                     fontsize=6)
        _drop_legend(ax)
        ax.xaxis.label.set_visible(False)
        nplt.ax_remove_box(ax)

        # row 3: response vs. full-model prediction
        ax = plt.subplot(4, col_count, col_count * 2 + i + 1)
        nplt.state_var_psth_from_epoch(val,
                                       epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2="pred",
                                       state_chan=var,
                                       ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Full Model")
        else:
            ax.yaxis.label.set_visible(False)
        _drop_legend(ax)

        # column of pred_mod matching the control model: 0 -> p0b, 1 -> pb0
        j = 0 if ctl_name == "pred_p0b" else 1
        # pred_mod rows follow state_var_list, so look up this channel's row
        # instead of assuming factors[1:] lines up with state_var_list[1:]
        row = state_var_list.index(var) if var in state_var_list else i + 1
        ax.set_title("r={:.3f} rawmod={:.3f} umod={:.3f}".format(
            _meta(ctx_pb, 'r_test'),
            pred_mod_full_norm[row][j], pred_mod_norm[row][j]),
                     fontsize=6)

        labels = ('this', 'others') if is_file else legend_labels.get(var)
        if labels is not None:
            ax.legend(labels)
        nplt.ax_remove_box(ax)

    # EXTRA PANELS
    # figure out some basic aspects of tuning/selectivity for target vs.
    # reference, using passive-block data only
    r = ctx_pb['rec']['resp']
    e = r.epochs
    fs = r.fs
    passive_epochs = r.get_epoch_indices("PASSIVE_EXPERIMENT")

    def _passive_raster(epoch_name):
        """Extract ``epoch_name`` within passive epochs, scaled by fs."""
        t = ep.epoch_intersection(r.get_epoch_indices(epoch_name),
                                  passive_epochs)
        return r.extract_epoch(t) * fs

    tar_resp = {tarname: _passive_raster(tarname)
                for tarname in ep.epoch_names_matching(e, "^TAR_")}

    # only plot tar responses with max SNR or probe SNR
    keys = sorted(k for k in tar_resp if k.endswith(('0', '2')))

    # assume the reference with most reps is the one overlapping the target
    groups = ep.group_epochs_by_occurrence_counts(e, '^STIM_')
    ref_name = groups[max(groups)][0]
    ref_resp = _passive_raster(ref_name)
    all_ref_resp = _passive_raster('REFERENCE')

    prestimsilence = r.get_epoch_indices('PreStimSilence')
    prebins = prestimsilence[0, 1] - prestimsilence[0, 0]
    poststimsilence = r.get_epoch_indices('PostStimSilence')
    postbins = poststimsilence[0, 1] - poststimsilence[0, 0]

    def _evoked_mean(raster):
        """Mean of channel 0 between pre- and post-stimulus silence."""
        # End the window postbins before each raster's own end, so it
        # excludes post-stimulus silence even when pre/post silences differ
        # and when target and reference rasters have different lengths.
        stop = raster.shape[-1] - postbins
        return np.nanmean(raster[:, 0, prebins:stop])

    spont = np.nanmean(all_ref_resp[:, 0, :prebins])
    ref_mean = _evoked_mean(ref_resp) - spont
    all_ref_mean = _evoked_mean(all_ref_resp) - spont

    ax1 = plt.subplot(4, 2, 7)
    ref_psth = [
        np.nanmean(ref_resp[:, 0, :], axis=0),
        np.nanmean(all_ref_resp[:, 0, :], axis=0)
    ]
    ref_ll = [
        "{} {:.1f}".format(ref_name, ref_mean),
        "all refs {:.1f}".format(all_ref_mean)
    ]
    nplt.timeseries_from_vectors(ref_psth,
                                 fs=fs,
                                 legend=ref_ll,
                                 ax=ax1,
                                 time_offset=prebins / fs)

    ax2 = plt.subplot(4, 2, 8)
    # at least two slots so tar_max_resp / tar_probe_resp always exist
    # (NaN when the corresponding target is missing)
    tar_mean = np.full(max(2, len(keys)), np.nan)
    tar_psth = []
    tar_ll = []
    for ii, k in enumerate(keys):
        tar_psth.append(np.nanmean(tar_resp[k][:, 0, :], axis=0))
        tar_mean[ii] = _evoked_mean(tar_resp[k]) - spont
        tar_ll.append("{} {:.1f}".format(k, tar_mean[ii]))
    nplt.timeseries_from_vectors(tar_psth,
                                 fs=fs,
                                 legend=tar_ll,
                                 ax=ax2,
                                 time_offset=prebins / fs)

    # share y-limits between the reference and target panels
    ymin = min(ax1.get_ylim()[0], ax2.get_ylim()[0])
    ymax = max(ax1.get_ylim()[1], ax2.get_ylim()[1])
    ax1.set_ylim([ymin, ymax])
    ax2.set_ylim([ymin, ymax])
    nplt.ax_remove_box(ax1)
    nplt.ax_remove_box(ax2)

    plt.tight_layout()

    stats = {
        'cellid': cellid,
        'batch': batch,
        'modelnames': modelnames,
        'state_vars': state_var_list,
        'factors': factors,
        'r_test': np.array([_meta(c, 'r_test') for c in all_ctx]),
        'se_test': np.array([_meta(c, 'se_test') for c in all_ctx]),
        'r_floor': np.array([_meta(c, 'r_floor') for c in all_ctx]),
        'pred_mod': pred_mod.T,
        'pred_mod_full': pred_mod_full.T,
        'pred_mod_norm': pred_mod_norm.T,
        'pred_mod_full_norm': pred_mod_full_norm.T,
        'g': np.array([_phi(c, 'g') for c in all_ctx]),
        'b': np.array([_phi(c, 'd') for c in all_ctx]),
        'ref_all_resp': all_ref_mean,
        'ref_common_resp': ref_mean,
        'tar_max_resp': tar_mean[0],
        'tar_probe_resp': tar_mean[1],
    }

    return fh, stats
Ejemplo n.º 13
0
# Exploratory script: load one fitted model and draw dot rasters comparing
# the 'target_0' and 'target_0_repeating_dual' epochs of its response.
# NOTE(review): cellid, batch, modelname, browse_results, gui, xhelp, epoch,
# np and plt are not defined in this snippet -- they must already exist.
xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)

# Either open the xforms fit browser or draw the modelspec quickplot.
if browse_results:
    # presumably kept in `ex` so the GUI object stays referenced -- confirm
    ex = gui.browse_xform_fit(ctx, xfspec)
else:
    ctx['modelspec'].quickplot()


r=ctx['rec']['resp']
# Epoch index arrays for the stimulus conditions of interest.
dual = r.get_epoch_indices('dual')
rrnd9 = r.get_epoch_indices('Stim , 10 , Reference')
rnd9 = r.get_epoch_indices('Stim , 10 , Target')
rep9 = r.get_epoch_indices('Stim , 10 , TargetRep')
# NOTE(review): this reassignment discards the 'Stim , 10 , TargetRep'
# lookup on the previous line -- confirm which epoch name is intended.
rep9 = r.get_epoch_indices('target_1_repeating_dual')

# Intersection of the repeating-target epochs with the 'dual' epochs.
# NOTE(review): `a`, `rrnd9` and `rnd9` are not used in the lines visible
# here; `epoch` is a different module alias than `ep` used elsewhere.
a = epoch.epoch_intersection(rep9, dual)

# Per-repetition rasters; indexed below as 3-D arrays, presumably
# (repetition, channel, time) -- confirm against Signal.extract_epoch.
raster_rnd = r.extract_epoch('target_0')
raster_rep = r.extract_epoch('target_0_repeating_dual')
plt.figure()
plt.subplot(2,2,1)
# (repetition, time-bin) positions of nonzero bins in channel 0;
# repetitions are shifted to start at 1 for plotting.
i, j = np.where(raster_rnd[:,0,:])
i += 1
# black dot markers, no connecting line
plt.plot(j, i,' k.')
plt.title('rand')

plt.subplot(2,2,2)
# same raster for the repeating-dual version of the target
i, j = np.where(raster_rep[:,0,:])
i += 1
plt.plot(j, i,' k.')
plt.title('rep')