Example #1
def _epoch_name_handler(rec_or_sig, epoch_names):
    '''
    Helper function to transform heterogeneous specifications of epochs (a single epoch name, a regexp,
    a list of epoch names, or a keyword) into the corresponding list of epoch names.
    :param rec_or_sig: NEMS recording or signal object
    :param epoch_names: epoch name (str), regexp, list of epoch names, 'single', 'pair'. The keywords 'single' and
    'pair' correspond to all single vocalizations and to all vocalization pairs, respectively.
    :return: a list with the appropriate epoch names as found in signal.epochs.name
    '''
    if epoch_names == 'single':  # get eps matching 'voc_x' where x is a positive integer
        reg_ex = r'\Avoc_\d'
        epoch_names = nep.epoch_names_matching(rec_or_sig.epochs, reg_ex)
    elif epoch_names == 'pair':  # get eps matching 'Cx_Py' where x and y are positive integers
        reg_ex = r'\AC\d_P\d'
        epoch_names = nep.epoch_names_matching(rec_or_sig.epochs, reg_ex)
    elif isinstance(epoch_names, str):  # get eps matching the specified regexp
        reg_ex = epoch_names
        epoch_names = nep.epoch_names_matching(rec_or_sig.epochs, reg_ex)
    elif isinstance(epoch_names, list):  # use epoch_names as a literal list of epoch names
        ep_intersection = set(epoch_names).intersection(set(rec_or_sig.epochs.name.unique()))
        if len(ep_intersection) == 0:
            raise AttributeError("specified eps are not contained in sig")

    if len(epoch_names) == 0:
        raise AttributeError("no eps match regex '{}'".format(reg_ex))

    return epoch_names
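A minimal usage sketch (the recording path is hypothetical; any NEMS recording or signal whose epochs include CPN vocalization names should work):

from nems.recording import load_recording

rec = load_recording('/path/to/CPN_recording.tgz')  # hypothetical path
singles = _epoch_name_handler(rec['resp'], 'single')      # all 'voc_N' epochs
pairs = _epoch_name_handler(rec['resp'], 'pair')          # all 'CN_PM' epochs
explicit = _epoch_name_handler(rec['resp'], r'\AC0_P\d')  # any regexp also works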
Example #2
def r_ceiling_test(result, fullrec, pred_name='pred', resp_name='resp', N=100):
    """
    Compute noise-corrected correlation coefficient based on single-trial
    correlations in the actual response.
    """
    epoch_regex = '^STIM_'
    epochs_to_extract = ep.epoch_names_matching(result[resp_name].epochs,
                                                epoch_regex)
    folded_resp = result[resp_name].extract_epochs(epochs_to_extract)

    epochs_to_extract = ep.epoch_names_matching(result[pred_name].epochs,
                                                epoch_regex)
    folded_pred = result[pred_name].extract_epochs(epochs_to_extract)

    resp = fullrec[resp_name].rasterize()

    X = np.array([])
    Y = np.array([])
    for k, d in folded_resp.items():
        if np.sum(np.isfinite(d)) > 0:

            x = resp.extract_epoch(k)
            X = np.concatenate((X, x.flatten()))

            p = folded_pred[k]
            if p.shape[0] < x.shape[0]:
                p = np.tile(p, (x.shape[0], 1, 1))
            Y = np.concatenate((Y, p.flatten()))

    # exclude nan values of X or Y
    gidx = (np.isfinite(X) & np.isfinite(Y))
    X = X[gidx]
    Y = Y[gidx]

    sx = X  # observed single-trial responses, used to parameterize the simulated trials

    fit_alpha, fit_loc, fit_beta = stats.gamma.fit(X)

    mu = [
        np.reshape(
            stats.gamma.rvs(fit_alpha + sx,
                            loc=fit_loc,
                            scale=fit_beta / (1 + fit_beta)), (1, -1))
        for a in range(10)
    ]
    mu = np.concatenate(mu)
    mu[mu > np.max(X)] = np.max(X)
    xc_set = [np.corrcoef(mu[i, :], X)[0, 1] for i in range(10)]
    log.info("Simulated r_single: %.3f +/- %.3f", np.mean(xc_set),
             np.std(xc_set) / np.sqrt(10))

    xc_act = np.corrcoef(Y, X)[0, 1]
    log.info("actual r_single: %.03f", xc_act)

    # normalize the actual single-trial correlation by the simulated noise ceiling
    rnorm = xc_act / np.mean(xc_set)

    return rnorm
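A usage sketch, assuming `val` holds trial-averaged validation data with matched 'pred'/'resp' signals and `rec` is the original, non-averaged recording:

rnorm = r_ceiling_test(val, rec, pred_name='pred', resp_name='resp')
print('noise-corrected prediction correlation: {:.3f}'.format(rnorm))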
Example #3
def compute_snr_multi(resp, frac_total=True):
    epochs = resp.epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    resp_dict = resp.extract_epochs(stim_epochs)

    chan_count = resp.shape[0]
    snr = np.zeros(chan_count)
    for cidx in range(chan_count):
        per_stim_snrs = []
        for stim, r in resp_dict.items():
            repcount = r.shape[0]
            if repcount > 2:
                _r = r[:, cidx, :]
                # pairwise dot products between repetitions of this stim
                products = np.dot(_r, _r.T)
                per_rep_snrs = []
                for i in range(repcount):
                    total_power = products[i, i]
                    signal_powers = np.delete(products[i], i)
                    if frac_total:
                        rep_snr = np.nanmean(signal_powers) / total_power
                    else:
                        rep_snr = np.nanmean(signal_powers /
                                             (total_power - signal_powers))

                    per_rep_snrs.append(rep_snr)
                per_stim_snrs.append(np.nanmean(per_rep_snrs))
        snr[cidx] = np.nanmean(per_stim_snrs)
        #print(resp.chans[cidx], snr[cidx])
    return snr
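A usage sketch for the per-channel SNR estimate (assumes `rec` holds a population recording with repeated 'STIM_' epochs):

resp = rec['resp'].rasterize()
snr = compute_snr_multi(resp)                  # one value per channel
for chan, s in zip(resp.chans, snr):
    print('{}: snr = {:.3f}'.format(chan, s))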
Example #4
def est_halved(half=1, seed_idx=0, random_seed=1234, **ctx):
    '''
    Use only one half of estimation data, for comparing a model to itself.
    '''
    est = ctx['est']
    epochs = est['stim'].epochs
    stims = np.array(ep.epoch_names_matching(epochs, 'STIM_'))
    indices = np.arange(len(stims))

    st0 = np.random.get_state()
    random_seed += seed_idx
    np.random.seed(random_seed)
    set1_idx = np.random.choice(indices, round(len(stims) / 2), replace=False)
    np.random.set_state(st0)

    mask = np.zeros(len(stims), dtype=bool)
    mask[set1_idx] = True
    set1_stims = stims[mask].tolist()
    set2_stims = stims[~mask].tolist()

    est1, est2 = est.split_by_epochs(set1_stims, set2_stims)
    if half == 1:
        est = est1
    else:
        est = est2

    return {'est': est}
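est_halved is written as an xforms-style step that reads 'est' from a context dict; a sketch of calling it directly (the context contents are assumptions):

ctx = {'est': est}  # est: a NEMS recording whose 'stim' signal has 'STIM_' epochs
half1 = est_halved(half=1, seed_idx=0, **ctx)['est']
half2 = est_halved(half=2, seed_idx=0, **ctx)['est']
# the same seed_idx yields the two complementary halves of one random split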
Example #5
def compute_snr(resp, frac_total=True):
    epochs = resp.epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    resp_dict = resp.extract_epochs(stim_epochs)

    per_stim_snrs = []
    for stim, r in resp_dict.items():
        r = r.squeeze()
        if r.ndim == 1:
            # only one rep of this stim; add back in the axis for number of reps
            r = np.expand_dims(r, 0)
        products = np.dot(r, r.T)
        per_rep_snrs = []

        for i, _ in enumerate(r):
            total_power = products[i, i]
            signal_powers = np.delete(products[i], i)
            if frac_total:
                rep_snr = np.nanmean(signal_powers) / total_power
            else:
                rep_snr = np.nanmean(signal_powers /
                                     (total_power - signal_powers))

            per_rep_snrs.append(rep_snr)
        per_stim_snrs.append(np.nanmean(per_rep_snrs))

    return np.nanmean(per_stim_snrs)
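A usage sketch for the single-signal variant (resp is assumed to be a rasterized response with repeated 'STIM_' epochs):

resp = rec['resp'].rasterize()
snr = compute_snr(resp)                          # fraction-of-total power (default)
snr_ratio = compute_snr(resp, frac_total=False)  # signal/noise ratio instead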
Example #6
def generate_average_sig(signal_to_average,
                         new_signalname='respavg',
                         epoch_regex='^STIM_'):
    '''
    Returns a new signal created by replacing every epoch matched by
    "epoch_regex" with the average over every occurrence of that epoch.
    This is often used to make a response-average signal that is the same
    length as the original signal_to_average, usually for plotting.

    Optional arguments:
       signal_to_average   The signal from which you want to create an
                           average signal. It will not be modified.
       new_signalname      The name of the new, average signal.
       epoch_regex         A regex to match which epochs to average across.
    '''

    # 1. Fold matrix over all stimuli, returning a dict where keys are stimuli
    #    and each value in the dictionary is (reps X cell X bins)
    epochs_to_extract = ep.epoch_names_matching(signal_to_average.epochs,
                                                epoch_regex)
    folded_matrices = signal_to_average.extract_epochs(epochs_to_extract)

    # 2. Average over all reps of each stim and save into dict called psth.
    per_stim_psth = dict()
    for k in folded_matrices.keys():
        per_stim_psth[k] = np.nanmean(folded_matrices[k], axis=0)

    # 3. Invert the folding to unwrap the psth into a predicted spike_dict by
    #   replacing all epochs in the signal with their average (psth)
    respavg = signal_to_average.replace_epochs(per_stim_psth)
    respavg.name = new_signalname

    return respavg
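A usage sketch with the default arguments:

respavg = generate_average_sig(rec['resp'].rasterize(),
                               new_signalname='respavg',
                               epoch_regex='^STIM_')
rec.add_signal(respavg)  # same length as the original signal, so easy to plot together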
Example #7
def strf_tensor(mfilename,cellid,plot=False,linalg=False,real=False):
    '''
     Creates matrices used as features and labels of a tensor
     :param mfilename: file with your data in it
     :param cellid: name of the cell
     :param plot: if True, plots the STRF fit by linear regression (fs is fixed at 1000 Hz inside the function)
     :param linalg: if True, will perform and display an STRF generated from the inputs using linear algebra, as a check
     :param real: if True, will run the full tor_tuning function to give the actual STRF output for the given input
     :return: matrix of stimulus with time delay, and matrix of response
     '''
    fs=1000
    rec = nb.baphy_load_recording_file(mfilename=mfilename, cellid=cellid, fs=fs, stim=False)
    globalparams, exptparams, exptevents = nio.baphy_parm_read(mfilename)
    signal = rec['resp'].rasterize(fs=fs)

    epoch_regex = "^STIM_TORC_.*"  # pick all epochs that have STIM_TORC_...
    epochs_to_extract = ep.epoch_names_matching(signal.epochs, epoch_regex)  # find those epochs
    # extract them; r.keys() yields the names of the TORCs, which can be looked up as r['name']
    r = signal.extract_epochs(epochs_to_extract)

    all_arr = list()
    for val in r.values():
        fval = np.swapaxes(np.squeeze(val), 0, 1)
        all_arr.append(fval)
    stacked = np.stack(all_arr,axis=2)

    TorcObject = exptparams["TrialObject"][1]["ReferenceHandle"][1]
    PreStimbin = int(TorcObject['PreStimSilence'] * fs)
    PostStimbin = int(TorcObject['PostStimSilence'] * fs)
    numbin = stacked.shape[0]
    stacked = stacked[PreStimbin:(numbin - PostStimbin), :, :]
    stimall,avgResp,Params = strf_input_gen(stacked,TorcObject,exptparams,fs)

    fitH = model_strf(stimall,avgResp)

    if plot:
        #bf = strf_plot_prepare(fitH,Params)
        [_,_] = strfplot(fitH, Params['lfreq'], Params['basep'],smooth=1, noct=Params['octaves'])
        plt.title('%s - Linear Regression' % (os.path.basename(mfilename)), fontweight='bold')

    if linalg:
        #based on idea that H = Y*Xtranspose*inverse of correlation matrix
        X = stimall
        XT = np.swapaxes(stimall, 0, 1)
        Y = np.swapaxes(np.expand_dims(avgResp, axis=1), 0, 1)
        C = np.dot(X, XT)
        Cinv = np.linalg.pinv(C)

        H = np.reshape((np.dot(np.dot(Y, XT), Cinv)),(15,25),order='F')

        #bfl = strf_plot_prepare(H,Params)
        [_,_] = strfplot(H, Params['lfreq'], Params['basep'], 1, Params['octaves'])
        plt.title('%s - Control' % (os.path.basename(mfilename)), fontweight='bold')

    if real:
        _ = tor_tuning(mfilename,cellid,plot=True)

    return fitH
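A usage sketch with a hypothetical file/cell pairing (fs is fixed at 1000 Hz inside the function):

fitH = strf_tensor('/path/to/site.m', 'TAR010c-18-1', plot=True)
# fitH: STRF fit by linear regression; pass linalg=True for the linear-algebra check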
Example #8
def generate_stim_from_epochs(rec,
                              new_signal_name='stim',
                              epoch_regex='^STIM_',
                              epoch_shift=0,
                              epoch2_regex=None,
                              epoch2_shift=0,
                              epoch2_shuffle=False,
                              onsets_only=True):

    rec = rec.copy()
    resp = rec['resp'].rasterize()

    epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch_regex)
    sigs = []
    for e in epochs_to_extract:
        log.info('Adding to %s: %s with shift = %d', new_signal_name, e,
                 epoch_shift)
        s = resp.epoch_to_signal(e, onsets_only=onsets_only, shift=epoch_shift)
        if epoch_shift:
            s.chans[0] = "{}{:+d}".format(s.chans[0], epoch_shift)
        sigs.append(s)

    if epoch2_regex is not None:
        epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch2_regex)
        for e in epochs_to_extract:
            log.info('Adding to %s: %s with shift = %d', new_signal_name, e,
                     epoch2_shift)
            s = resp.epoch_to_signal(e,
                                     onsets_only=onsets_only,
                                     shift=epoch2_shift)
            if epoch2_shuffle:
                log.info('Shuffling %s', e)
                s = s.shuffle_time()
                s.chans[0] = "{}_shf".format(s.chans[0])
            if epoch_shift:
                s.chans[0] = "{}{:+d}".format(s.chans[0], epoch2_shift)
            sigs.append(s)

    stim = sigs[0].concatenate_channels(sigs)
    stim.name = new_signal_name

    # add_signal operates in place
    rec.add_signal(stim)

    return rec
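A usage sketch that builds event-timing regressors from epochs (the shift value is arbitrary, in bins):

rec2 = generate_stim_from_epochs(rec, new_signal_name='stim',
                                 epoch_regex='^STIM_',
                                 epoch_shift=5, onsets_only=True)
print(rec2['stim'].chans)  # one boolean channel per matched epoch, with a '+5' suffix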
Example #9
def all_nwb_epochs(resp, epoch_regex):
    # return epoch data as a single array - helpful if the recording includes multiple units
    epoch_list = ep.epoch_names_matching(resp.epochs, epoch_regex)
    epoch_dict = resp.extract_epochs(epoch_list)
    all_epochs = [epoch_dict[key] for key in epoch_dict.keys()]
    e_size = np.max([i.shape[2] for i in all_epochs])
    # merge everything together, assuming 50 reps per stim - resulting array has shape (trials*stims, units, samples)
    all_epochs = np.concatenate([np.resize(i, (50, len(resp.chans), e_size)) for i in all_epochs], 0)
    return all_epochs
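A usage sketch (note the hardcoded assumption of 50 presentations per stimulus):

all_eps = all_nwb_epochs(rec['resp'].rasterize(), "^natural_scene")
print(all_eps.shape)  # (trials*stims, units, samples)

Example #10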
def psth_per_file(rec):

    raise NotImplementedError
    # NOTE: everything below the raise is unreachable draft code; `r` is still a
    # list (not an array) and `s`, `stim`, and `cellid` are undefined here.

    resp = rec['resp'].rasterize()

    file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")

    epoch_regex = "^STIM_"
    stim_epochs = ep.epoch_names_matching(resp.epochs, epoch_regex)

    r = []
    max_rep_id = np.zeros(len(file_epochs))
    for f in file_epochs:

        r.append(resp.as_matrix(stim_epochs, overlapping_epoch=f) * resp.fs)

    repcount = np.sum(np.isfinite(r[:, :, 0, 0]), axis=1)
    max_rep_id, = np.where(repcount == np.max(repcount))

    t = np.arange(r.shape[-1]) / resp.fs

    plt.figure()

    ax = plt.subplot(3, 1, 1)
    nplt.plot_spectrogram(s[max_rep_id[-1], 0, :, :],
                          fs=stim.fs,
                          ax=ax,
                          title="cell {} - stim".format(cellid))

    ax = plt.subplot(3, 1, 2)
    nplt.raster(t, r[max_rep_id[-1], :, 0, :], ax=ax, title='raster')

    ax = plt.subplot(3, 1, 3)
    nplt.psth_from_raster(t,
                          r[max_rep_id[-1], :, 0, :],
                          ax=ax,
                          title='raster',
                          ylabel='spk/s')

    plt.tight_layout()
Example #11
def _split_signal(signal):
    # finds in epochs the transition between one experiment and the next
    if isinstance(signal, PointProcess):
        pass
    elif isinstance(signal, (RasterizedSignal, TiledSignal)):
        raise NotImplementedError('signal must be a PointProcess')
    else:
        raise ValueError('First argument must be a NEMS signal')

    epochs = signal.epochs
    epoch_names = nep.epoch_names_matching(signal.epochs, r'\AFILE_[a-zA-Z]{3}\d{3}[a-z]\d{2}_[ap]_CPN\Z')
    if len(epoch_names) == 0:
        raise ValueError('Epochs do not contain files matching CPN experiments.')
    file_epochs = epochs.loc[epochs.name.isin(epoch_names), :]

    sub_signals = dict()
    trip_counter = 0
    perm_counter = 0

    for ff, (_, file) in enumerate(file_epochs.iterrows()):

        # extract relevant epochs and data
        sub_epochs = epochs.loc[(epochs.start >= file.start) & (epochs.end <= file.end), :].copy()
        sub_epochs[['start', 'end']] = sub_epochs[['start', 'end']] - file.start

        sub_data = {cell: spikes[np.logical_and(spikes >= file.start, spikes < file.end)] - file.start
                    for cell, spikes in signal._data.copy().items()}

        meta = signal.meta.copy()
        meta['rawid'] = [meta['rawid'][ff]]

        sub_signal = signal._modified_copy(data=sub_data, epochs=sub_epochs, meta=meta)

        # checks names of epochs to define triplets or permutations;
        # keeps track of the number of trip or perm experiments;
        # names the signal with the experiment type and number in case of repeated trip and/or perm
        exp_type = _detect_type(sub_epochs)
        if exp_type == 'perm':
            exp_type = f'{exp_type}{perm_counter}'
            perm_counter += 1
        elif exp_type == 'trip':
            exp_type = f'{exp_type}{trip_counter}'
            trip_counter += 1
        else:
            raise ValueError('not Permutations or Triplets')

        sub_signals[exp_type] = sub_signal

    return sub_signals
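A usage sketch (the signal must be a PointProcess whose epochs include CPN 'FILE_' entries):

sub_signals = _split_signal(rec['resp'])
for exp_name, sig in sub_signals.items():
    print(exp_name)  # e.g. 'perm0', 'trip0', ...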
Example #12
def generate_psth_from_est_for_both_est_and_val(est,
                                                val,
                                                epoch_regex='^STIM_'):
    '''
    Estimates a PSTH from the est set and returns two signals, based on est
    and val, in which each repetition of a stim is replaced by the est PSTH.

    Subtracts the spont rate, based on pre-stim silence, from ALL estimation data.
    '''

    resp_est = est['resp'].rasterize()
    resp_val = val['resp'].rasterize()

    # compute PSTH response and spont rate during those valid trials
    prestimsilence = resp_est.extract_epoch('PreStimSilence')
    if len(prestimsilence.shape) == 3:
        spont_rate = np.nanmean(prestimsilence, axis=(0, 2))
    else:
        spont_rate = np.nanmean(prestimsilence)

    # 1. Fold the response over all matching epochs
    epochs_to_extract = ep.epoch_names_matching(resp_est.epochs, epoch_regex)
    folded_matrices = resp_est.extract_epochs(epochs_to_extract,
                                              mask=est['mask'])

    # 2. Average over all reps of each stim and save into dict called psth.
    per_stim_psth = dict()
    for k in folded_matrices.keys():
        per_stim_psth[k] = np.nanmean(folded_matrices[k], axis=0) - \
            spont_rate[:, np.newaxis]

    # 3. Invert the folding to unwrap the psth into a predicted spike_dict by
    #   replacing all epochs in the signal with their average (psth)
    respavg_est = resp_est.replace_epochs(per_stim_psth)
    respavg_est.name = 'psth'

    # add signal to the est recording
    est.add_signal(respavg_est)

    respavg_val = resp_val.replace_epochs(per_stim_psth)
    respavg_val.name = 'psth'

    # add signal to the val recording
    val.add_signal(respavg_val)

    return est, val
Example #13
def spike_dev_by_pupil(rec, rras):
    # mean deviation from PSTH per trial and mean pupil per trial across all epochs - returns df of correlation coeffs
    # will return data for multiple units if the signal contains multiple channels
    epochs = rec.epochs
    if rras is None:
        rras = rec['resp'].rasterize()
    # extract all epochs for pupil
    epoch_regex = "^natural_scene"
    epoch_list = ep.epoch_names_matching(epochs, epoch_regex)
    pupil_epochs = rec['pupil'].extract_epochs(epoch_list)
    pupil_epochs = [pupil_epochs[key] for key in pupil_epochs.keys()]
    e_size = np.max([i.shape[2] for i in pupil_epochs])
    # assumes 119 stims x 50 trials, hardcoded for this dataset
    pupil_epochs = np.squeeze(np.concatenate([np.resize(i, (50, 1, e_size)) for i in pupil_epochs], 0), 1)
    # mean pupil over time for each stim, trial
    meanpupil = np.nanmean(pupil_epochs, 1).reshape(119, 50)
    
    # extract all natural_scene epochs then merge dict into array along trial num axis
    epoch_dict = rras.extract_epochs(epoch_list)
    all_epochs = [epoch_dict[key] for key in epoch_dict.keys()]
    e_size = np.max([i.shape[2] for i in all_epochs])
    # merge everything together - resulting array has shape (trials*stims, units, samples)
    all_epochs = np.concatenate([np.resize(i, (50, len(rras.chans), e_size)) for i in all_epochs], 0)

    
    from scipy.stats import pearsonr

    # iterate over cells in signal
    respdev_dict = {}
    for i in range(all_epochs.shape[1]):
        chan = all_epochs[:, i, :].reshape(119, 50, -1)
        meanresp_pertrial = np.nanmean(chan, 1)  # avg across trials for each epoch
        respdev_dict[rras.chans[i]] = np.nanmean((chan - np.expand_dims(meanresp_pertrial, 1)), 2)
    
    pupilnan = (~np.isnan(meanpupil))
    dev_corr = {}
    # df of pupil / resp-deviation correlation
    for key in respdev_dict.keys():
        resp = respdev_dict[key]
        corr_index = ((~np.isnan(resp)) & pupilnan)  # remove nans for correlation
        (coeff, pval) = pearsonr(meanpupil[corr_index], resp[corr_index])
        dev_corr[key] = {'coeff': coeff, 'pval': pval}
    dev_corr = pd.DataFrame.from_dict(dev_corr)
    return dev_corr
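A usage sketch; passing rras=None rasterizes the response inside the function:

dev_corr = spike_dev_by_pupil(rec, None)
print(dev_corr)  # one column per unit, rows 'coeff' and 'pval'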
Example #14
def stim_resp_per_epoch(rec):
    stim = copy.deepcopy(rec['stim'].as_continuous())
    resp = copy.deepcopy(rec['resp'].as_continuous())
    fs = rec['stim'].fs
    epochs = rec.epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    pre_silence = _silence_duration(epochs, 'PreStimSilence')
    post_silence = _silence_duration(epochs, 'PostStimSilence')

    stims = []
    resps = []
    for s in stim_epochs:
        row = epochs[epochs.name == s]
        start = int((row['start'].values[0] + pre_silence)*fs)
        end = int((row['end'].values[0] - post_silence)*fs)

        stims.append(stim[:, start:end])
        resps.append(resp[:, start:end])  # index time on axis 1, matching stim

    return stims, resps
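A usage sketch; the helper trims pre- and post-stimulus silence from each 'STIM_' epoch:

stims, resps = stim_resp_per_epoch(rec)
print(len(stims), stims[0].shape)  # one (channels x time) array per epoch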
Example #15
def mean_sd_per_stim_by_cellid(cellid, batch, loadkey='ozgf.fs100.ch18',
                               max_db_scale=65, pre_log_floor=1,
                               stims_to_skip=None):
    rec_path = xwrap.generate_recording_uri(cellid, batch, loadkey=loadkey)
    rec = nems.recording.load_recording(rec_path)
    stim = copy.deepcopy(rec['stim'].as_continuous())
    fs = rec['stim'].fs
    epochs = rec.epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    stims_to_skip = stims_to_skip or []  # None default avoids a mutable default argument
    stim_epochs = [s for s in stim_epochs if s not in stims_to_skip]
    pre_silence = silence_duration(epochs, 'PreStimSilence')
    post_silence = silence_duration(epochs, 'PostStimSilence')

    results = {}
    for s in stim_epochs:
        row = epochs[epochs.name == s]
        start = int((row['start'].values[0] + pre_silence)*fs)
        end = int((row['end'].values[0] - post_silence)*fs)
        results[s] = spectrogram_mean_sd(stim[:, start:end],
                                         max_db_scale=max_db_scale,
                                         pre_log_floor=pre_log_floor)

    return results
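A usage sketch with a hypothetical cell/batch pair:

results = mean_sd_per_stim_by_cellid('TAR010c-18-1', 289,
                                     loadkey='ozgf.fs100.ch18')
for stim_name, stats in results.items():
    print(stim_name, stats)  # per-stimulus spectrogram mean/sd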
Example #16
def compute_snr(resp, frac_total=True):
    epochs = resp.epochs
    stim_epochs = ep.epoch_names_matching(epochs, 'STIM_')
    resp_dict = resp.extract_epochs(stim_epochs)

    per_stim_snrs = []
    for stim, r in resp_dict.items():
        r = r.squeeze()
        products = np.dot(r, r.T)
        per_rep_snrs = []
        for i, _ in enumerate(r):
            total_power = products[i, i]
            signal_powers = np.delete(products[i], i)
            if frac_total:
                rep_snr = np.nanmean(signal_powers) / total_power
            else:
                rep_snr = np.nanmean(signal_powers /
                                     (total_power - signal_powers))

            per_rep_snrs.append(rep_snr)
        per_stim_snrs.append(np.nanmean(per_rep_snrs))

    return np.nanmean(per_stim_snrs)
Example #17
def nwb_resp_psth(rec, epoch_regex):
    # intended to give output similar to generate_psth_from_resp from the preprocessing
    # module, but works better with the structure of Neuropixels data
    newrec = rec.copy()

    resp = newrec['resp'].rasterize()

    # extract all matching epochs (e.g. epoch_regex="^natural_scene"), then merge dict and avg - add new signal
    epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch_regex)
    epoch_dict = resp.extract_epochs(epochs_to_extract)

    # no pre/post stim silence, but spontaneous intervals work instead (or use psth without mean subtracted)
    spont = resp.extract_epoch('spontaneous')
    spont_rate = np.nanmean(spont)
    
    per_stim_psth_spont = {}
    per_stim_psth = {}
    for k, v in epoch_dict.items():
        per_stim_psth_spont[k] = np.nanmean(v, axis=0)
        per_stim_psth[k] = np.nanmean(v, axis=0) - spont_rate
        
    respavg = resp.replace_epochs(per_stim_psth)
    respavg.name = 'psth'
    respavg_data = respavg.as_continuous().copy()

    respavg_with_spont = resp.replace_epochs(per_stim_psth_spont)
    respavg_with_spont.name = 'psth_sp'
    respavg_spont_data = respavg_with_spont.as_continuous().copy()
    
    respavg = respavg._modified_copy(respavg_data)
    respavg_with_spont = respavg_with_spont._modified_copy(respavg_spont_data)
    
    newrec.add_signal(respavg)
    newrec.add_signal(respavg_with_spont)
    return newrec
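A usage sketch for Neuropixels-style recordings:

newrec = nwb_resp_psth(rec, epoch_regex="^natural_scene")
psth = newrec['psth']        # spont-subtracted PSTH signal
psth_sp = newrec['psth_sp']  # PSTH with spont rate retained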
Example #18
def test_plots():
    # recording_file, uri, and cellid are assumed to be defined at module level
    recording.get_demo_recordings(name=recording_file)
    rec = recording.load_recording(uri)

    resp = rec['resp'].rasterize()
    stim = rec['stim'].rasterize()

    epoch_regex = "^STIM_"

    stim_epochs = ep.epoch_names_matching(resp.epochs, epoch_regex)

    r = resp.as_matrix(stim_epochs) * resp.fs
    s = stim.as_matrix(stim_epochs)
    repcount = np.sum(np.isfinite(r[:, :, 0, 0]), axis=1)
    max_rep_id, = np.where(repcount == np.max(repcount))

    t = np.arange(r.shape[-1]) / resp.fs

    plt.figure()

    ax = plt.subplot(3, 1, 1)
    nplt.plot_spectrogram(s[max_rep_id[-1], 0, :, :],
                          fs=stim.fs,
                          ax=ax,
                          title="cell {} - stim".format(cellid))

    ax = plt.subplot(3, 1, 2)
    nplt.raster(t, r[max_rep_id[-1], :, 0, :], ax=ax, title='raster')

    ax = plt.subplot(3, 1, 3)
    nplt.psth_from_raster(t,
                          r[max_rep_id[-1], :, 0, :],
                          ax=ax,
                          title='raster',
                          ylabel='spk/s')

    plt.tight_layout()
Example #19
# (snippet truncated: the opening of the `states` list, `modelname_base`, `cellid`,
#  and `batch` are defined earlier in the script)
        'st.pca.pup+r1'
    ]
    resp_modelname = f"psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{states[-1]}-plgsm.p2-aev-rd.resp"+\
                "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3"+\
                "_tfinit.xx0.n.lr1e4.cont.et4.i20-lvnoise.r4-aev-ccnorm.md.t1.f0.ss3"

modelnames = [resp_modelname] + [modelname_base.format(s) for s in states]

xf, ctx = load_model_xform(cellid=cellid,
                           batch=batch,
                           modelname=modelnames[-1])

val = ctx['val'].copy()
resp = val['resp'].rasterize()
epoch_regex = "^STIM_"
epochs = ep.epoch_names_matching(resp.epochs, regex_str=epoch_regex)

input_name = 'pred0'
pred0 = val[input_name].extract_epochs(epochs, mask=val['mask'])
pred = val['pred'].extract_epochs(epochs, mask=val['mask'])
resp = val['resp'].extract_epochs(epochs, mask=val['mask'])
pupil = val['pupil'].extract_epochs(epochs, mask=val['mask'])
pmedian = np.nanmedian(val['pupil'].as_continuous())

epochs = list(resp.keys())

#
#ncells, nreps, nstim, nbins = X.shape

e1 = epochs[0]
Example #20
def generate_psth_from_resp(rec, epoch_regex='^STIM_', smooth_resp=False):
    '''
    Estimates a PSTH from all responses to each regex match in a recording

    Subtracts the spont rate, based on pre-stim silence, from ALL estimation data.

    if rec['mask'] exists, uses rec['mask'] == True to determine valid epochs
    '''
    newrec = rec.copy()
    resp = newrec['resp'].rasterize()

    # compute spont rate during valid (non-masked) trials
    if 'mask' in newrec.signals.keys():
        prestimsilence = resp.extract_epoch('PreStimSilence',
                                            mask=newrec['mask'])
    else:
        prestimsilence = resp.extract_epoch('PreStimSilence')

    if len(prestimsilence.shape) == 3:
        spont_rate = np.nanmean(prestimsilence, axis=(0, 2))
    else:
        spont_rate = np.nanmean(prestimsilence)

    idx = resp.get_epoch_indices('PreStimSilence')
    prebins = idx[0][1] - idx[0][0]
    idx = resp.get_epoch_indices('PostStimSilence')
    postbins = idx[0][1] - idx[0][0]

    # compute PSTH response during valid trials
    if isinstance(epoch_regex, list):
        epochs_to_extract = []
        for rx in epoch_regex:
            eps = ep.epoch_names_matching(resp.epochs, rx)
            epochs_to_extract += eps

    elif isinstance(epoch_regex, str):
        epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch_regex)

    if 'mask' in newrec.signals.keys():
        folded_matrices = resp.extract_epochs(epochs_to_extract,
                                              mask=newrec['mask'])
    else:
        folded_matrices = resp.extract_epochs(epochs_to_extract)

    # 2. Average over all reps of each stim and save into dict called psth.
    per_stim_psth = dict()
    per_stim_psth_spont = dict()
    for k, v in folded_matrices.items():
        if smooth_resp:
            # replace each segment (pre, onset, sustained, offset, post) with its average
            v[:, :, :prebins] = np.nanmean(v[:, :, :prebins],
                                           axis=2, keepdims=True)
            v[:, :, prebins:(prebins + 2)] = np.nanmean(
                v[:, :, prebins:(prebins + 2)], axis=2, keepdims=True)
            v[:, :, (prebins + 2):-postbins] = np.nanmean(
                v[:, :, (prebins + 2):-postbins], axis=2, keepdims=True)
            v[:, :, -postbins:(-postbins + 2)] = np.nanmean(
                v[:, :, -postbins:(-postbins + 2)], axis=2, keepdims=True)
            v[:, :, (-postbins + 2):] = np.nanmean(
                v[:, :, (-postbins + 2):], axis=2, keepdims=True)

        per_stim_psth[k] = np.nanmean(v, axis=0) - spont_rate[:, np.newaxis]
        per_stim_psth_spont[k] = np.nanmean(v, axis=0)
        folded_matrices[k] = v

    # 3. Invert the folding to unwrap the psth into a predicted spike_dict by
    #   replacing all epochs in the signal with their average (psth)
    respavg = resp.replace_epochs(per_stim_psth)
    respavg_with_spont = resp.replace_epochs(per_stim_psth_spont)
    respavg.name = 'psth'
    respavg_with_spont.name = 'psth_sp'

    # Fill in all non-masked periods with 0 (presumably, these are spont
    # periods not contained within stimulus epochs), or with the spont rate
    # (for the signal containing spont rate)
    respavg_data = respavg.as_continuous().copy()
    respavg_spont_data = respavg_with_spont.as_continuous().copy()

    if 'mask' in newrec.signals.keys():
        mask_data = newrec['mask']._data
    else:
        mask_data = np.ones(respavg_data.shape).astype(bool)

    spont_periods = (np.isnan(respavg_data) & mask_data)

    respavg_data[:, spont_periods[0, :]] = 0
    # respavg_spont_data[:, spont_periods[0,:]] = spont_rate[:, np.newaxis]

    respavg = respavg._modified_copy(respavg_data)
    respavg_with_spont = respavg_with_spont._modified_copy(respavg_spont_data)

    # add the new signals to the recording
    newrec.add_signal(respavg)
    newrec.add_signal(respavg_with_spont)

    if smooth_resp:
        log.info('Replacing resp with smoothed resp')
        resp = resp.replace_epochs(folded_matrices, mask=newrec['mask'])
        newrec.add_signal(resp)

    return newrec
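A usage sketch; masking and spont-rate subtraction happen inside:

newrec = generate_psth_from_resp(rec, epoch_regex='^STIM_')
psth = newrec['psth'].as_continuous()       # spont-subtracted, 0 outside stim epochs
psth_sp = newrec['psth_sp'].as_continuous()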
Example #21
def mask_all_but_correct_references(rec,
                                    balance_rep_count=False,
                                    include_incorrect=False):
    """
    Specialized function for removing incorrect trials from data
    collected using baphy during behavior.

    TODO: Migrate to nems_lbhb and/or make a more generic version
    """

    newrec = rec.copy()
    newrec['resp'] = newrec['resp'].rasterize()
    if 'stim' in newrec.signals.keys():
        newrec['stim'] = newrec['stim'].rasterize()
    resp = newrec['resp']

    if balance_rep_count:

        epoch_regex = "^STIM_"
        epochs_to_extract = ep.epoch_names_matching(resp.epochs, epoch_regex)
        p = resp.get_epoch_indices("PASSIVE_EXPERIMENT")
        a = resp.get_epoch_indices("HIT_TRIAL")

        epoch_list = []
        for s in epochs_to_extract:
            e = resp.get_epoch_indices(s)
            pe = ep.epoch_intersection(e, p)
            ae = ep.epoch_intersection(e, a)
            if len(pe) > len(ae):
                epoch_list.extend(ae)
                subset = np.round(np.linspace(0, len(pe),
                                              len(ae) + 1)).astype(int)
                for i in subset[:-1]:
                    epoch_list.append(pe[i])
            else:
                subset = np.round(np.linspace(0, len(ae),
                                              len(pe) + 1)).astype(int)
                for i in subset[:-1]:
                    epoch_list.append(ae[i])
                epoch_list.extend(pe)

        newrec = newrec.create_mask(epoch_list)

    elif include_incorrect:
        log.info('INCLUDING ALL TRIALS (CORRECT AND INCORRECT)')
        newrec = newrec.and_mask(['REFERENCE'])

    else:
        newrec = newrec.and_mask(['PASSIVE_EXPERIMENT', 'HIT_TRIAL'])
        newrec = newrec.and_mask(['REFERENCE'])

    # figure out if some actives should be masked out


#    t = ep.epoch_names_matching(resp.epochs, "^TAR_")
#    tm = [tt[:-2] for tt in t]  # trim last digits
#    active_epochs = resp.get_epoch_indices("ACTIVE_EXPERIMENT")
#    if len(set(tm)) > 1 and len(active_epochs) > 1:
#        print('Multiple targets: ', tm)
#        files = ep.epoch_names_matching(resp.epochs, "^FILE_")
#        keep_files = files
#        e = active_epochs[1]
#        for i,f in enumerate(files):
#            fi = resp.get_epoch_indices(f)
#            if any(ep.epoch_contains([e], fi, 'both')):
#                keep_files = files[:i]
#
#        print('Keeping files: ', keep_files)
#        newrec = newrec.and_mask(keep_files)

    if 'state' in newrec.signals:
        b_states = [
            'far', 'hit', 'lick', 'puretone_trials', 'easy_trials',
            'hard_trials'
        ]
        trec = newrec.copy()
        trec = trec.and_mask(['ACTIVE_EXPERIMENT'])
        st = trec['state'].as_continuous().copy()
        st_raw = trec['state_raw'].as_continuous().copy()
        mask = trec['mask'].as_continuous()[0, :]
        for s in trec['state'].chans:
            if s in b_states:
                i = trec['state'].chans.index(s)
                m = np.nanmean(st[i, mask])
                sd = np.nanstd(st[i, mask])
                # print("{} {}: m={}, std={}".format(s, i, m, sd))
                # print(np.sum(mask))
                st[i, mask] -= m
                st[i, mask] /= sd
                st_raw[i, mask] -= m
                st_raw[i, mask] /= sd
        newrec['state'] = newrec['state']._modified_copy(st)
        newrec['state_raw'] = newrec['state_raw']._modified_copy(st_raw)

    return newrec
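A usage sketch for the behavior-masking helper above:

newrec = mask_all_but_correct_references(rec)  # keep passive + hit trials only
newrec_all = mask_all_but_correct_references(rec, include_incorrect=True)

Example #22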
def average_away_epoch_occurrences(rec, epoch_regex='^STIM_'):
    '''
    Returns a recording with _all_ signals averaged across epochs that
    match epoch_regex, shortening them so that each epoch occurs only
    once in the new signals. i.e. unlike 'add_average_sig', the new
    recording will have signals 3x shorter if there are 3 occurrences of
    every epoch.

    This has advantages:
    1. Averaging the value of a signal (such as a response) in different
       occurrences will make it behave more like a linear variable with
       gaussian noise, which is advantageous in many circumstances.
    2. There will be less computation needed because the signal is shorter.

    It also has disadvantages:
    1. Stateful filters (FIR, IIR) will be subtly wrong near epoch boundaries
    2. Any ordering of epochs is essentially lost, unless all epochs appear
       in a perfectly repeated order.

    To avoid accidentally averaging away differences in responses to stimuli
    that are based on behavioral state, you may need to create new epochs
    (based on stimulus and behaviorial state, for example) and then match
    the epoch_regex to those.
    '''

    # Create new recording
    newrec = rec.copy()

    counter = 0

    # iterate through each signal
    for signal_name, signal_to_average in rec.signals.items():
        # TODO: for TiledSignals, there is a much simpler way to do this!

        # 0. rasterize
        signal_to_average = signal_to_average.rasterize()

        # 1. Find matching epochs
        epochs_to_extract = ep.epoch_names_matching(rec.epochs, epoch_regex)

        # 2. Fold over all stimuli, returning a dict where keys are stimuli
        #    and each value in the dictionary is (reps X cell X bins)
        folded_matrices = signal_to_average.extract_epochs(epochs_to_extract)

        # force a standard list of sorted keys for all signals
        if counter == 0:
            sorted_keys = list(folded_matrices.keys())
            sorted_keys.sort()
        counter += 1

        # 3. Average over all occurrences of each epoch, and append to data
        fs = signal_to_average.fs
        data = np.zeros([signal_to_average.nchans, 0])
        current_time = 0
        epochs = None
        for k in sorted_keys:
            # Suppress warnings about all-nan matrices
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                per_stim_psth = np.nanmean(folded_matrices[k], axis=0)
            data = np.concatenate((data, per_stim_psth), axis=1)
            epoch = current_time + np.array([[0, per_stim_psth.shape[1] / fs]])
            df = pd.DataFrame(np.tile(epoch, [2, 1]), columns=['start', 'end'])
            df['name'] = k
            df.at[1, 'name'] = 'TRIAL'
            if epochs is not None:
                epochs = pd.concat([epochs, df], ignore_index=True)
            else:
                epochs = df
            current_time = epoch[0, 1]
            #print("{0} epoch: {1}-{2}".format(k,epoch[0,0],epoch[0,1]))

        avg_signal = signal.RasterizedSignal(
            fs=fs,
            data=data,
            name=signal_to_average.name,
            recording=signal_to_average.recording,
            chans=signal_to_average.chans,
            epochs=epochs,
            meta=signal_to_average.meta)
        newrec.add_signal(avg_signal)

    return newrec
Example #23
def generate_psth_from_est_for_both_est_and_val(est, val):
    '''
    Estimates a PSTH from the est set and returns two signals, based on est
    and val, in which each repetition of a stim is replaced by the est PSTH.

    Subtracts the spont rate, based on pre-stim silence, from ALL estimation data.
    '''

    epoch_regex = '^STIM_'
    resp_est = est['resp'].copy()
    resp_val = val['resp']

    # find all valid references in est data-- passive or correct trials
    ref_phase = resp_est.epoch_to_signal('REFERENCE')
    active_phase = resp_est.epoch_to_signal('ACTIVE_EXPERIMENT')
    correct_phase = resp_est.epoch_to_signal('HIT_TRIAL')
    valid_phase = np.logical_and(
        ref_phase.as_continuous(),
        np.logical_or(np.logical_not(active_phase.as_continuous()),
                      correct_phase.as_continuous()))
    ref_phase = ref_phase._modified_copy(valid_phase)
    resp_est = resp_est.nan_mask(ref_phase.as_continuous())

    # compute PSTH response and spont rate during those valid trials
    prestimsilence = resp_est.extract_epoch('PreStimSilence')
    if len(prestimsilence.shape) == 3:
        spont_rate = np.nanmean(prestimsilence, axis=(0, 2))
    else:
        spont_rate = np.nanmean(prestimsilence)

    epochs_to_extract = ep.epoch_names_matching(resp_est.epochs, epoch_regex)
    folded_matrices = resp_est.extract_epochs(epochs_to_extract)

    # 2. Average over all reps of each stim and save into dict called psth.
    per_stim_psth = dict()
    for k in folded_matrices.keys():
        per_stim_psth[k] = np.nanmean(folded_matrices[k],
                                      axis=0) - spont_rate[:, np.newaxis]

    # 3. Invert the folding to unwrap the psth into a predicted spike_dict by
    #   replacing all epochs in the signal with their average (psth)
    respavg_est = resp_est.replace_epochs(per_stim_psth)
    respavg_est.name = 'stim'  # TODO: SVD suggests rename 2018-03-08

    # mark invalid phases as nan
    respavg_est = respavg_est.nan_mask(ref_phase.as_continuous())

    # add signal to the recording
    est.add_signal(respavg_est)

    respavg_val = resp_val.replace_epochs(per_stim_psth)
    respavg_val.name = 'stim'  # TODO: SVD suggests rename 2018-03-08
    ref_phase = val['resp'].epoch_to_signal('REFERENCE')
    active_phase = val['resp'].epoch_to_signal('ACTIVE_EXPERIMENT')
    correct_phase = val['resp'].epoch_to_signal('HIT_TRIAL')
    valid_phase = np.logical_and(
        ref_phase.as_continuous(),
        np.logical_or(np.logical_not(active_phase.as_continuous()),
                      correct_phase.as_continuous()))
    ref_phase = ref_phase._modified_copy(valid_phase)

    respavg_val = respavg_val.nan_mask(ref_phase.as_continuous())

    # add signal to the recording
    val.add_signal(respavg_val)

    return est, val
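A usage sketch, splitting the data first as in the next example:

est, val = rec.split_using_epoch_occurrence_counts(epoch_regex="^STIM_")
est, val = generate_psth_from_est_for_both_est_and_val(est, val)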
Example #24
#est, val = estimation, validation data sets
est, val = rec.split_using_epoch_occurrence_counts(epoch_regex="^STIM_")
est = average_away_epoch_occurrences(est, epoch_regex="^STIM_")
val = average_away_epoch_occurrences(val, epoch_regex="^STIM_")

# get matrices for fitting:
X_est = est['stim'].apply_mask().as_continuous()  # frequency x time
Y_est = est['resp'].apply_mask().as_continuous()  # neuron x time

# get matrices for testing model predictions:
X_val = val['stim'].apply_mask().as_continuous()
Y_val = val['resp'].apply_mask().as_continuous()

# find a stimulus to display
epoch_regex = '^STIM_'
epochs_to_extract = ep.epoch_names_matching(val.epochs, epoch_regex)
epoch = epochs_to_extract[0]

plt.figure()
ax = plt.subplot(3, 1, 1)
nplt.spectrogram_from_epoch(val['stim'], epoch, ax=ax, time_offset=2)

ax = plt.subplot(3, 1, 2)
nplt.timeseries_from_epoch([val['resp']], epoch, ax=ax)

raster = rec['resp'].extract_epoch(epoch)
ax = plt.subplot(3, 1, 3)
plt.imshow(raster[:, 0, :])

plt.tight_layout()
Example #25
cells_sets = dict()
for cellid in first_cell:

    # cellid='bbl099g-04-1'
    try:
        ctx = baphy_load_wrapper(cellid=cellid, batch=batch, loadkey=loadkey, siteid=None)
        rec = load_recording(ctx['recording_uri_list'][0])
    except Exception:
        print('could not load site for cell {}'.format(cellid))
        bad_cells.append(cellid)
        continue

    sig = rec['resp'].rasterize()

    #wav_names = ne.epoch_names_matching(sig.epochs, r'^STIM_00') # for validation set only
    wav_names = ne.epoch_names_matching(sig.epochs, r'^STIM_')
    epre = sig.get_epoch_indices('PreStimSilence')
    epost = sig.get_epoch_indices('PostStimSilence')
    prebins = epre[0][1] - epre[0][0]
    postbins = epost[0][1] - epost[0][0]

    mtx_dict = sig.extract_epochs(wav_names)
    # orders the response of all the cells in the site to all the sounds in a 3d array of shape Sound x Cells x Time,
    # then normalizes the response of each cell across all stimuli
    shape = [len(mtx_dict.keys()),
             list(mtx_dict.values())[0].shape[1],
             np.max([val.shape[2] for val in mtx_dict.values()])]
    site_arr = np.full(shape, np.nan)
    for ss, (sound, response) in enumerate(mtx_dict.items()):
        PSTH = np.mean(response, axis=0)
        site_arr[ss, :, :] = PSTH
Example #26
import os

import numpy as np

import nems.epoch as ep
import nems.plots.api as nplt
from nems.recording import Recording

# specify directories for loading data and fitted modelspec
signals_dir = '../signals'
#signals_dir = '/home/jacob/auto/data/batch271_fs100_ozgf18/'
modelspecs_dir = '../modelspecs'

# load the data
rec = Recording.load(os.path.join(signals_dir, 'TAR010c-18-1'))

# Add a new signal, respavg, to the recording, in 4 steps

# 1. Fold matrix over all stimuli, returning a dictionary where keys are stimuli
#    and each value in the dictionary is (reps X cell X bins)
epochs_to_extract = ep.epoch_names_matching(rec.epochs, '^STIM_')
folded_matrix = rec['resp'].extract_epochs(epochs_to_extract)

# 2. Average over all reps of each stim and save into dict called psth.
per_stim_psth = dict()
for k in folded_matrix.keys():
    per_stim_psth[k] = np.nanmean(folded_matrix[k], axis=0)

# 3. Invert the folding to unwrap the psth back out into a predicted spike_dict by
# simply replacing all epochs in the signal with their psth
respavg = rec['resp'].replace_epochs(per_stim_psth)
respavg.name = 'respavg'

# 4. Now add the signal to the recording
rec.add_signal(respavg)
Example #27
def r_ceiling(result, fullrec, pred_name='pred', resp_name='resp', N=100):
    """
    parameter:
        result : recording
            validation data containing resp_name and pred_name signals
        fullrec : orginal recording that isn't averaged across reps
        N : int
            number of random single trial pairs to test

    returns:
        rnorm: nparray
           corrected ceiling measure for each response channel (ie,
           there should be support for multiple neural channels)

    Compute noise-corrected correlation coefficient based on single-trial
    correlations in the actual response. Based on method in
    Hsu and Theusnissen (2004) Network.

    SVD revised 2018-08-30 to hopefully make more stable. Instead of computing
    average single-trial corr from separate per-stimulus measurements, now
    concatenates one rep of each validation stimulus into a long vector for
    calculating a corr coeff across all stimuli. Still repeats this for a
    bunch of pairs to get a good estimate of correlation between single trials
    """

    epoch_regex = '^STIM_'
    epochs_to_extract = ep.epoch_names_matching(result[resp_name].epochs,
                                                epoch_regex)
    folded_resp = result[resp_name].extract_epochs(epochs_to_extract)

    epochs_to_extract = ep.epoch_names_matching(result[pred_name].epochs,
                                                epoch_regex)
    folded_pred = result[pred_name].extract_epochs(epochs_to_extract)

    resp = fullrec[resp_name].rasterize()

    chancount = fullrec[resp_name].shape[0]

    rnorm = np.zeros(chancount)
    for chanidx in range(chancount):
        Xall = []
        p = []
        reps = []
        preps = []
        for k, d in folded_resp.items():
            if np.sum(np.isfinite(d)) > 0:

                Xall.append(resp.extract_epoch(k)[:, chanidx, :])
                p.append(folded_pred[k][:, chanidx, :])
                reps.append(Xall[-1].shape[0])
                preps.append(p[-1].shape[0])

        if not Xall:
            # no finite data in any epoch
            return 0

        minreps = np.min(reps)
        X = [x[:minreps, :] for x in Xall]
        X = np.concatenate(X, axis=1)

        minpreps = np.min(preps)
        p = [p0[:minpreps, :] for p0 in p]
        p = np.concatenate(p, axis=1)
        if minreps > 1:
            rac = _r_single(X, N)

            repcount = X.shape[0]
            rs = np.zeros(repcount)
            for nn in range(repcount):
                X1 = X[nn, :]
                X2 = p[0, :]

                # remove all nans from pred and resp
                ff = np.isfinite(X1) & np.isfinite(X2)
                X1 = X1[ff]
                X2 = X2[ff]

                if (np.sum(ff) == 0) or \
                   (np.sum(X1) == 0) or \
                   (len(np.unique(X2)) == 1):
                    rs[nn] = 0
                else:
                    rs[nn] = np.corrcoef(X1, X2)[0, 1]

            rnorm[chanidx] = np.mean(rs) / np.sqrt(rac)
        else:
            rnorm[chanidx] = 0
    return rnorm
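A usage sketch; rnorm has one entry per response channel:

rnorm = r_ceiling(val, rec, pred_name='pred', resp_name='resp', N=100)
print(rnorm)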
Example #28
def make_state_signal(rec,
                      state_signals=['pupil'],
                      permute_signals=[],
                      new_signalname='state'):
    """
    generate state signal for stategain.S/sdexp.S models

    valid state signals include (incomplete list):
        pupil, pupil_ev, pupil_bs, pupil_psd
        active, each_file, each_passive, each_half
        far, hit, lick, p_x_a

    TODO: Migrate to nems_lbhb or make a more generic version
    """

    newrec = rec.copy()
    resp = newrec['resp'].rasterize()

    # normalize mean/std of pupil trace if being used
    if ('pupil' in state_signals) or ('pupil_ev' in state_signals) or \
       ('pupil_bs' in state_signals):
        # normalize to zero mean, unit variance
        p = newrec["pupil"].as_continuous().copy()
        # p[p < np.nanmax(p)/5] = np.nanmax(p)/5
        p -= np.nanmean(p)
        p /= np.nanstd(p)
        newrec["pupil"] = newrec["pupil"]._modified_copy(p)

    if 'pupil_psd' in state_signals:
        pup = newrec['pupil'].as_continuous().copy()
        fs = newrec['pupil'].fs
        # get spectrogram of pupil
        nperseg = int(60 * fs)
        noverlap = nperseg - 1
        f, time, Sxx = ss.spectrogram(pup.squeeze(),
                                      fs=fs,
                                      nperseg=nperseg,
                                      noverlap=noverlap)
        max_chan = 4  # (np.abs(f - 0.1)).argmin()
        # Keep only first five channels of spectrogram
        #f = interpolate.interp1d(np.arange(0, Sxx.shape[1]), Sxx[:max_chan, :], axis=1)
        #newspec = f(np.linspace(0, Sxx.shape[-1]-1, pup.shape[-1]))
        pad1 = np.ones((max_chan, int(nperseg / 2))) * Sxx[:max_chan, [0]]
        pad2 = np.ones((max_chan, int(nperseg / 2 - 1))) * Sxx[:max_chan, [-1]]
        newspec = np.concatenate((pad1, Sxx[:max_chan, :], pad2), axis=1)

        # = np.concatenate((Sxx[:max_chan, :], np.tile(Sxx[:max_chan,-1][:, np.newaxis], [1, noverlap])), axis=1)
        newspec -= np.nanmean(newspec, axis=1, keepdims=True)
        newspec /= np.nanstd(newspec, axis=1, keepdims=True)

        spec_signal = newrec['pupil']._modified_copy(newspec)
        spec_signal.name = 'pupil_psd'
        chan_names = []
        for chan in range(0, newspec.shape[0]):
            chan_names.append('puppsd{0}'.format(chan))
        spec_signal.chans = chan_names

        newrec.add_signal(spec_signal)

    if ('pupil_ev' in state_signals) or ('pupil_bs' in state_signals):
        # generate separate pupil baseline and evoked signals

        prestimsilence = newrec["pupil"].extract_epoch('PreStimSilence')
        spont_bins = prestimsilence.shape[2]
        pupil_trial = newrec["pupil"].extract_epoch('TRIAL')

        pupil_bs = np.zeros(pupil_trial.shape)
        for ii in range(pupil_trial.shape[0]):
            pupil_bs[ii, :, :] = np.mean(pupil_trial[ii, :, :spont_bins])
        pupil_ev = pupil_trial - pupil_bs

        newrec['pupil_ev'] = newrec["pupil"].replace_epoch('TRIAL', pupil_ev)
        newrec['pupil_ev'].chans = ['pupil_ev']
        newrec['pupil_bs'] = newrec["pupil"].replace_epoch('TRIAL', pupil_bs)
        newrec['pupil_bs'].chans = ['pupil_bs']

    if ('each_passive' in state_signals):
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        pset = []
        found_passive1 = False
        for f in file_epochs:
            # test if passive expt
            epoch_indices = ep.epoch_intersection(
                resp.get_epoch_indices(f),
                resp.get_epoch_indices('PASSIVE_EXPERIMENT'))
            if epoch_indices.size:
                if not found_passive1:
                    # skip first passive
                    found_passive1 = True
                else:
                    pset.append(f)
                    newrec[f] = resp.epoch_to_signal(f)
        state_signals.remove('each_passive')
        state_signals.extend(pset)
        if 'each_passive' in permute_signals:
            permute_signals.remove('each_passive')
            permute_signals.extend(pset)

    if ('each_file' in state_signals):
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        trial_indices = resp.get_epoch_indices('TRIAL')
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        pcount = 0
        acount = 0
        for f in file_epochs:
            # test if passive expt
            f_indices = resp.get_epoch_indices(f)
            epoch_indices = ep.epoch_intersection(f_indices, passive_indices)

            if epoch_indices.size:
                # this is a passive file
                name1 = "PASSIVE_{}".format(pcount)
                pcount += 1
                if pcount == 1:
                    acount = 1  # reset acount for actives after first passive
                else:
                    # use first passive part A as baseline - don't model
                    pset.append(name1)
                    newrec[name1] = resp.epoch_to_signal(name1,
                                                         indices=f_indices)

            else:
                name1 = "ACTIVE_{}".format(acount)
                pset.append(name1)
                newrec[name1] = resp.epoch_to_signal(name1, indices=f_indices)

                if pcount == 0:
                    acount -= 1
                else:
                    acount += 1

            # test if passive expt


#            epoch_indices = ep.epoch_intersection(
#                    resp.get_epoch_indices(f),
#                    resp.get_epoch_indices('PASSIVE_EXPERIMENT'))
#            if epoch_indices.size and not(found_passive1):
#                # skip first passive
#                found_passive1 = True
#            else:
#                pset.append(f)
#                newrec[f] = resp.epoch_to_signal(f)
        state_signals.remove('each_file')
        state_signals.extend(pset)
        if 'each_file' in permute_signals:
            permute_signals.remove('each_file')
            permute_signals.extend(pset)

    if ('each_half' in state_signals):
        file_epochs = ep.epoch_names_matching(resp.epochs, "^FILE_")
        trial_indices = resp.get_epoch_indices('TRIAL')
        passive_indices = resp.get_epoch_indices('PASSIVE_EXPERIMENT')
        pset = []
        pcount = 0
        acount = 0
        for f in file_epochs:
            # test if passive expt
            f_indices = resp.get_epoch_indices(f)
            epoch_indices = ep.epoch_intersection(f_indices, passive_indices)
            trial_intersect = ep.epoch_intersection(f_indices, trial_indices)
            #trial_count = trial_intersect.shape[0]
            #_split = int(trial_count/2)
            _t1 = trial_intersect[0, 0]
            _t2 = trial_intersect[-1, 1]
            _split = int((_t1 + _t2) / 2)
            epoch1 = np.array([[_t1, _split]])
            epoch2 = np.array([[_split, _t2]])

            if epoch_indices.size:
                # this is a passive file
                name1 = "PASSIVE_{}_{}".format(pcount, 'A')
                name2 = "PASSIVE_{}_{}".format(pcount, 'B')
                pcount += 1
                if pcount == 1:
                    acount = 1  # reset acount for actives after first passive
                else:
                    # don't model PASSIVE_0 A -- baseline
                    pset.append(name1)
                    newrec[name1] = resp.epoch_to_signal(name1, indices=epoch1)

                # do include part B
                pset.append(name2)
                newrec[name2] = resp.epoch_to_signal(name2, indices=epoch2)
            else:
                name1 = "ACTIVE_{}_{}".format(acount, 'A')
                name2 = "ACTIVE_{}_{}".format(acount, 'B')
                pset.append(name1)
                newrec[name1] = resp.epoch_to_signal(name1, indices=epoch1)
                pset.append(name2)
                newrec[name2] = resp.epoch_to_signal(name2, indices=epoch2)

                if pcount == 0:
                    acount -= 1
                else:
                    acount += 1

        state_signals.remove('each_half')
        state_signals.extend(pset)
        if 'each_half' in permute_signals:
            permute_signals.remove('each_half')
            permute_signals.extend(pset)

    # generate task state signals
    if 'pas' in state_signals:
        fpre = (resp.epochs['name'] == "PRE_PASSIVE")
        fpost = (resp.epochs['name'] == "POST_PASSIVE")
        INCLUDE_PRE_POST = (np.sum(fpre) > 0) & (np.sum(fpost) > 0)
        if INCLUDE_PRE_POST:
            # only include pre-passive if post-passive also exists
            # otherwise the regression gets screwed up
            newrec['pre_passive'] = resp.epoch_to_signal('PRE_PASSIVE')
        else:
            # place-holder, all zeros
            newrec['pre_passive'] = resp.epoch_to_signal('XXX')
            newrec['pre_passive'].chans = ['PRE_PASSIVE']
    if 'puretone_trials' in state_signals:
        newrec['puretone_trials'] = resp.epoch_to_signal('PURETONE_BEHAVIOR')
        newrec['puretone_trials'].chans = ['puretone_trials']
    if 'easy_trials' in state_signals:
        newrec['easy_trials'] = resp.epoch_to_signal('EASY_BEHAVIOR')
        newrec['easy_trials'].chans = ['easy_trials']
    if 'hard_trials' in state_signals:
        newrec['hard_trials'] = resp.epoch_to_signal('HARD_BEHAVIOR')
        newrec['hard_trials'].chans = ['hard_trials']
    if ('active' in state_signals) or ('far' in state_signals):
        newrec['active'] = resp.epoch_to_signal('ACTIVE_EXPERIMENT')
        newrec['active'].chans = ['active']
    if (('hit_trials' in state_signals) or ('miss_trials' in state_signals)
            or ('far' in state_signals) or ('hit' in state_signals)):
        newrec['hit_trials'] = resp.epoch_to_signal('HIT_TRIAL')
        newrec['miss_trials'] = resp.epoch_to_signal('MISS_TRIAL')
        newrec['fa_trials'] = resp.epoch_to_signal('FA_TRIAL')

    sm_len = int(180 * newrec['resp'].fs)  # 180-s smoothing window, in bins
    if 'far' in state_signals:
        a = newrec['active'].as_continuous()
        fa = newrec['fa_trials'].as_continuous().astype(float)
        #c = np.concatenate((np.zeros((1,sm_len)), np.ones((1,sm_len+1))),
        #                   axis=1)
        # boxcar kernel for a running average over the last ~180 s
        c = np.ones((1, sm_len)) / sm_len

        fa = convolve2d(fa, c, mode='same')
        fa[a] -= 0.25  # subtract nominal baseline FA rate (~np.nanmean(fa[a]))
        fa[np.logical_not(a)] = 0

        s = newrec['fa_trials']._modified_copy(fa)
        s.chans = ['far']
        s.name = 'far'
        newrec.add_signal(s)

    if 'hit' in state_signals:
        a = newrec['active'].as_continuous()
        hr = newrec['hit_trials'].as_continuous().astype(float)
        ms = newrec['miss_trials'].as_continuous().astype(float)
        ht = hr - ms

        c = np.ones((1, sm_len)) / sm_len  # same boxcar kernel as for 'far'

        ht = convolve2d(ht, c, mode='same')
        ht[a] -= 0.1  # subtract nominal baseline hit-miss rate (~np.nanmean(ht[a]))
        ht[np.logical_not(a)] = 0

        s = newrec['hit_trials']._modified_copy(ht)
        s.chans = ['hit']
        s.name = 'hit'
        newrec.add_signal(s)

    if 'lick' in state_signals:
        newrec['lick'] = resp.epoch_to_signal('LICK')

    # pupil interactions
    if ('p_x_a' in state_signals):
        # interaction term: pupil signal gated by the active-state indicator
        p = newrec["pupil"].as_continuous()
        a = newrec["active"].as_continuous()
        newrec["p_x_a"] = newrec["pupil"]._modified_copy(p * a)
        newrec["p_x_a"].chans = ["p_x_a"]

    if ('prw' in state_signals):
        # move the second response channel into the state signals and keep
        # only the first channel in resp
        if len(rec['resp'].chans) != 2:
            raise ValueError("'prw' state signal requires exactly two response channels")
        ch1, ch2 = rec['resp'].chans

        newrec['prw'] = newrec['resp'].extract_channels([ch2]).rasterize()
        newrec['resp'] = newrec['resp'].extract_channels([ch1]).rasterize()

    if ('pup_x_prw' in state_signals):
        # interaction term between pupil and the other cell
        if 'prw' not in newrec.signals.keys():
            raise ValueError("Must include prw alone before using interaction")
        pup = newrec['pupil']._data
        prw = newrec['prw']._data
        sig = newrec['pupil']._modified_copy(pup * prw)
        sig.name = 'pup_x_prw'
        sig.chans = ['pup_x_prw']
        newrec.add_signal(sig)

    for i, x in enumerate(state_signals):
        if x in permute_signals:
            # kludge: fix random seed to index of state signal in list
            # this avoids using the same seed for each shuffled signal
            # but also makes shuffling reproducible
            newrec = concatenate_state_channel(
                newrec,
                newrec[x].shuffle_time(rand_seed=i, mask=newrec['mask']),
                state_signal_name=new_signalname)
        else:
            newrec = concatenate_state_channel(
                newrec, newrec[x], state_signal_name=new_signalname)

        # always keep an unshuffled copy of the channel in '<name>_raw'
        newrec = concatenate_state_channel(
            newrec, newrec[x], state_signal_name=new_signalname + "_raw")

    return newrec
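
The running false-alarm ('far') and hit ('hit') channels built above are
boxcar-smoothed binary trial indicators with a fixed baseline subtracted
during active blocks. Below is a minimal self-contained sketch of that
smoothing step using synthetic data; the 0.25 baseline and 180-s window
mirror the code above, while the sampling rate and signals are illustrative
assumptions:

import numpy as np
from scipy.signal import convolve2d

fs = 20                                         # sampling rate (Hz), assumed
sm_len = int(180 * fs)                          # 180-s boxcar window, in bins
n = 10 * sm_len
rng = np.random.default_rng(0)

active = np.zeros((1, n), dtype=bool)
active[0, n // 2:] = True                       # second half is "active"
fa = (rng.random((1, n)) < 0.25).astype(float)  # synthetic FA indicator

c = np.ones((1, sm_len)) / sm_len               # boxcar smoothing kernel
far = convolve2d(fa, c, mode='same')            # running FA rate
far[active] -= 0.25                             # remove nominal baseline
far[~active] = 0.0                              # zero outside active blocks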
Example #29
def average_away_epoch_occurrences(recording, epoch_regex='^STIM_'):
    '''
    Returns a recording with _all_ signals averaged across epochs that
    match epoch_regex, shortening them so that each epoch occurs only
    once in the new signals. Unlike 'add_average_sig', the new
    recording's signals will be, e.g., 3x shorter if every epoch
    occurs 3 times.

    This has advantages:
    1. Averaging the value of a signal (such as a response) in different
       occurrences will make it behave more like a linear variable with
       gaussian noise, which is advantageous in many circumstances.
    2. There will be less computation needed because the signal is shorter.

    It also has disadvantages:
    1. Stateful filters (FIR, IIR) will be subtly wrong near epoch boundaries.
    2. Any ordering of epochs is essentially lost, unless all epochs appear
       in a perfectly repeated order.

    To avoid accidentally averaging away differences in responses to stimuli
    that are based on behavioral state, you may need to create new epochs
    (based on stimulus and behavioral state, for example) and then match
    the epoch_regex to those.
    '''
    epochs = recording.epochs
    epoch_names = sorted(set(ep.epoch_names_matching(epochs, epoch_regex)))

    offset = 0
    new_epochs = []
    fs = recording[list(recording.signals.keys())[0]].fs
    d = int(np.ceil(np.log10(fs)) + 1)
    for epoch_name in epoch_names:
        common_epochs = ep.find_common_epochs(epochs, epoch_name, d=d)
        query = 'name == "{}"'.format(epoch_name)
        end = common_epochs.query(query).iloc[0]['end']
        common_epochs[['start', 'end']] += offset
        offset += end
        new_epochs.append(common_epochs)

    new_epochs = pd.concat(new_epochs, ignore_index=True)

    averaged_recording = recording.copy()

    for signal_name, signal in recording.signals.items():
        # TODO: this may be better done as a method in signal subclasses since
        # some subclasses may have more efficient approaches (e.g.,
        # TiledSignal)

        # Extract all occurrences of each epoch, returning a dict where keys are
        # stimuli and each value in the dictionary is (reps X cell X bins)
        epoch_data = signal.rasterize().extract_epochs(epoch_names)

        # Average over all occurrences of each epoch
        for epoch_name, epoch in epoch_data.items():
            # TODO: fix empty matrix error. do epochs align properly?
            if np.sum(np.isfinite(epoch)):
                epoch_data[epoch_name] = np.nanmean(epoch, axis=0)
            else:
                epoch_data[epoch_name] = epoch[0, ...]
        data = [epoch_data[epoch_name] for epoch_name in epoch_names]
        data = np.concatenate(data, axis=-1)
        if data.shape[-1] != round(signal.fs * offset):
            raise ValueError('Misalignment issue in averaging signal')

        averaged_signal = signal._modified_copy(data, epochs=new_epochs)
        averaged_recording.add_signal(averaged_signal)


#        # TODO: Eventually need a smarter check for this in case it's named
#        #       something else. Basically just want to preserve spike data.
#        if signal.name == 'resp':
#            spikes = signal.copy()
#            spikes.name = signal.name + ' spikes'
#            averaged_recording.add_signal(spikes)

    return averaged_recording
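
A hedged usage sketch for the function above, assuming a NEMS recording rec
whose epochs include names matching '^STIM_' (the recording itself and the
epoch name used below are illustrative, not taken from the code above):

# collapse every STIM epoch to a single averaged occurrence
avg_rec = average_away_epoch_occurrences(rec, epoch_regex='^STIM_')

# the averaged response to one stimulus can then be extracted directly
# (epoch name is hypothetical)
avg_resp = avg_rec['resp'].rasterize().extract_epoch('STIM_example')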
Example #30
def spo_dstrf_per_stream_condition(n_pc=2,
                                   memory=12,
                                   recname='val',
                                   cellids=None,
                                   **ctx):

    rec = ctx[recname].apply_mask()
    modelspec = ctx['modelspec']
    print(ctx['modelspec'].meta['modelname'])
    print(ctx['modelspec'].meta['cellid'])

    if cellids is None:
        # analyze all output channels
        cellids = rec['resp'].chans
        siteids = [c.split("-")[0] for c in cellids]

    out_channel = list(np.arange(len(cellids)))
    channel_count = len(out_channel)

    # figure out epoch bounds
    e = rec['resp'].epochs
    enames = ep.epoch_names_matching(e, '^STIM_')

    set_names = ['stream 1', 'stream 2', 'coh', 'inc']
    esets = [['STIM_T+si464+null', 'STIM_T+si516+null'],
             ['STIM_T+null+si464', 'STIM_T+null+si516'],
             ['STIM_T+si464+si464', 'STIM_T+si516+si516'],
             ['STIM_T+si464+si516', 'STIM_T+si516+si464']]
    index_sets = []
    for _es in esets:
        this_index = np.array([], dtype=int)
        for e in _es:
            x = rec['resp'].get_epoch_indices(e)
            print(f'{e}: {x}')
            this_index = np.concatenate(
                (this_index, np.arange(x[0, 0], x[0, 1], dtype=int)))
        index_sets.append(this_index)

    # per-condition dSTRF principal components and their magnitudes
    pcs = [None] * 4
    pc_mag = [None] * 4

    for s in range(4):
        index_range = index_sets[s]

        # skip silent bins
        stim_mag = rec['stim'].as_continuous().sum(axis=0)
        stim_big = stim_mag > np.max(stim_mag) / 1000
        index_range = index_range[(index_range > memory)
                                  & stim_big[index_range.astype(int)]]
        print(
            f'Calculating dstrf for {channel_count} channels, {len(index_range)} timepoints, memory={memory}'
        )

        pcs[s], pc_mag[s] = dstrf_pca(modelspec,
                                      rec,
                                      pc_count=n_pc,
                                      out_channel=out_channel,
                                      index_range=index_range,
                                      memory=memory)

    f2, axs = plt.subplots(8, 10, figsize=(20, 12))
    cmax = min(channel_count, 10)  # plot at most 10 cells
    for c in range(cmax):
        cellid = cellids[c]
        for s in range(4):
            for i in range(2):
                mm = np.max(np.abs(pcs[s][i, :, :, c]))
                _p = pcs[s][i, :, :, c] * pc_mag[s][i, c] / pc_mag[s][0, c]
                _p *= np.sign(_p.sum())
                _row = s * 2 + i
                _col = c
                axs[_row, _col].imshow(_p,
                                       aspect='auto',
                                       origin='lower',
                                       clim=[-mm, mm])
                if i + s == 0:
                    axs[_row, _col].set_title(f'{cellid} {pc_mag[s][i,c]:.3f}',
                                              fontsize=8)
                else:
                    axs[_row, _col].set_title(f'{pc_mag[s][i,c]:.3f}',
                                              fontsize=8)
                if _row < 7:
                    # x ticks only on the bottom row
                    axs[_row, _col].set_xticks([])
                if c == 0:
                    axs[_row, _col].set_ylabel(set_names[s])
                else:
                    axs[_row, _col].set_yticks([])
                ax_remove_box(axs[_row, _col])
    return f2, pcs, pc_mag
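
A usage sketch, assuming ctx is a fitted NEMS analysis context containing a
'val' recording and a 'modelspec' (both produced by an earlier model fit and
assumed here, not shown above):

# summary figure of per-stream dSTRF principal components, plus the PCs
f2, pcs, pc_mag = spo_dstrf_per_stream_condition(n_pc=2, memory=12, **ctx)
f2.savefig('dstrf_per_stream.pdf')  # hypothetical output path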