def get_data(file_name, bin_size=0.05):
    """Return smoothed firing rates and downsampled continuous data.

    Parameters
    ----------
    file_name : path handed to ``my_nex_class``.
    bin_size : bin width passed to ``bin_spike_data`` and
        ``smooth_firing_rate``; the continuous data is downsampled with
        ``1 / bin_size``. Defaults to 0.05, the value that used to be
        hard-coded in three places.

    Returns
    -------
    tuple
        ``(spike, kin_p)``: the result of ``smooth_firing_rate`` and the
        result of ``cont_downsample``.
    """
    # The file used to be read a second time through a separate
    # nexfile.Reader whose result was never used; that read is dropped.
    myNex = my_nex_class(file_name)

    # Return values are not needed here, but the calls are kept in the
    # original order in case they populate state used by the steps below
    # -- TODO confirm against my_nex_class.
    myNex.grab_spike_data()
    myNex.grab_spike_names()
    myNex.grab_cont_data()
    myNex.grab_cont_names()

    firing_rate = myNex.bin_spike_data(bin_size)
    spike = myNex.smooth_firing_rate(firing_rate, bin_size)
    kin_p = myNex.cont_downsample(1 / bin_size)
    return spike, kin_p
Ejemplo n.º 2
0
 def __init__(self, file_name):
     """Read ``file_name`` as a NEX5 file and keep its data and end time."""
     recording = nexfile.Reader(useNumpy=True).ReadNex5File(file_name)
     self.data = recording
     # 'End' of the file header is stored as the total recording time.
     self.total_time = recording['FileHeader']['End']
Ejemplo n.º 3
0
"""

import nexfile as nex
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import glob
from scipy import stats

#%%
# Session folder: its last path component is used as the date and the name of
# its parent folder as the animal ID.
pathname=r'C:\Users\Qixin\XuChunLab\nexdata\BaiTao\191545\06072019'
date=os.path.split(pathname)[1]
animal=os.path.split(os.path.split(pathname)[0])[1]
# Load the first .nex file found in the folder (IndexError if there is none).
filepath=glob.glob(os.path.join(pathname,'*.nex'))[0]
nexin=nex.Reader(useNumpy=True).ReadNexFile(filepath)
# Sort the file's variables into buckets by the 'Type' code in their header.
neurons=[]
waveforms=[]
events=[]
markers=[]
for var in nexin['Variables']:
    # Type 0 is collected as a neuron.
    if var['Header']['Type']==0:
        neurons.append(var)
        #print('neuron',len(neurons))
    # Type 1 is collected as an event.
    if var['Header']['Type']==1:
        events.append(var)
        #print('events',len(events))
    # Type 3 is collected as a waveform.
    if var['Header']['Type']==3:
        waveforms.append(var)
        #print('waveforms',len(waveforms))
    if var['Header']['Type'] == 6 and len(var['Timestamps']) != 0:
Ejemplo n.º 4
0
def main(filename, prevTime, numStims):
    """Parse a NEX recording into per-stimulus timing and spike structures.

    The file is read with ``nexfile``. Spike trains come from every variable
    whose type is 0, event codes from the ``'Strobed'`` marker channel and
    photodiode transitions from the continuous channel ``'AD14'``. After the
    header codes at the start of the marker stream are read, the stream is
    walked one ``startITICode`` at a time and every completed stimulus
    presentation is stored in ``stimStructs``.

    Parameters
    ----------
    filename : path passed to ``Reader.ReadNexFile``.
    prevTime : amount of time kept before stimulus onset when collecting
        spikes; the same amount is kept after the end of the presentation
        (``postTime``).
    numStims : number of entries allocated in ``stimStructs``.

    Returns
    -------
    dict
        ``experiment`` with the keys ``iti_start``, ``iti_end``,
        ``numNeurons``, ``neuronid``, ``prevTime``, ``postTime``,
        ``stimStructs``, ``errors`` and, for every header code that was
        found, ``rfx``, ``rfy``, ``StimSize``, ``iti``, ``StimDur``, ``isi``.
        Each ``stimStructs`` entry holds ``numInstances``, ``timeOn``,
        ``timeOff``, ``pdOn``, ``pdOff`` and ``neurons`` (one dict per
        neuron with a ``'spikes'`` list, one array per presentation).

    Notes
    -----
    ``pdOn``/``pdOff`` only receive an entry when a photodiode on/off pair is
    found, so they can be shorter than ``timeOn``/``timeOff``.
    """
    reader = nexfile.Reader(useNumpy=True)
    fData = reader.ReadNexFile(filename)

    nvar, names, types = nex_info(fData)

    #%% spikets
    neuronid = names[np.where(types == 0)].tolist()
    numNeurons = len(neuronid)
    spikets = [nex_ts(fData, name) for name in neuronid]

    #%% markerts, markervals
    Strobed_chIdx = np.where(names == 'Strobed')[0][0]
    markerts = fData['Variables'][Strobed_chIdx]['Timestamps']
    markervals = np.array(fData['Variables'][Strobed_chIdx]['Markers'],
                          dtype=int)[0]

    #%% get parseParams
    parseParams = get_parseParams()

    #%% photodiode from CRT (disabled variant, kept for reference)
    # ts = nex_ts(fData, parseParams['pdiodeChannel']);
    # pdAboveDists = ts[1:] - ts[:-1];
    # pdOnTS = np.append(ts[0],
    #                    ts[np.where(pdAboveDists > parseParams['pdiodeDistanceThreshold'])[0]+1]);
    # pdOffTS = np.append(ts[np.where(pdAboveDists > parseParams['pdiodeDistanceThreshold'])[0]],ts[-1]);

    #%% photodiode from LCD
    adfreq, ts, fn, d = nex_cont(fData, 'AD14')
    ## different depending on Rigs
    d2 = d[1:] - d[:-1]
    # Build one time value per sample: for every i, fn[i] values starting at
    # ts_old[i] and spanning fn[i] / adfreq.
    # presumably ts_old holds fragment start times and fn fragment sample
    # counts -- TODO confirm against nex_cont
    ts_old = ts[:]
    ts = []
    for i in np.arange(len(fn)):
        ts = np.append(ts,
                       np.linspace(0, fn[i] / adfreq, fn[i]) + ts_old[i])

    # Sample-to-sample jumps above 0.2 / below -0.2 are taken as photodiode
    # on / off transitions.
    pdOnTS_raw = ts[np.where(d2 > 0.2)[0] + 1]
    pdOffTS_raw = ts[np.where(d2 < -0.2)[0] + 1]

    # Keep the first transition of every run: a transition only counts as a
    # new one when it is more than 0.02 after the previous raw transition.
    pdOn_dist = pdOnTS_raw[1:] - pdOnTS_raw[:-1]
    pdOnTS = np.append(pdOnTS_raw[0],
                       pdOnTS_raw[np.where(pdOn_dist > 0.02)[0] + 1])

    pdOff_dist = pdOffTS_raw[1:] - pdOffTS_raw[:-1]
    pdOffTS = np.append(pdOffTS_raw[0],
                        pdOffTS_raw[np.where(pdOff_dist > 0.02)[0] + 1])

    del pdOnTS_raw, pdOffTS_raw, pdOn_dist, pdOff_dist

    #%% Get experiment parameters (Task parameters) sent from Pype to Plexon
    counter = 0
    experiment = dict()
    postTime = prevTime
    experiment['iti_start'] = []
    experiment['iti_end'] = []
    experiment['numNeurons'] = numNeurons
    experiment['neuronid'] = neuronid
    experiment['prevTime'] = prevTime
    experiment['postTime'] = postTime

    # Header layout: (code name in parseParams, key written to `experiment`
    # or None when the value is skipped, name of the offset subtracted from
    # the value or None).
    # NOTE(review): 'rfx' and 'StimSize' subtract 'yOffset' exactly as the
    # previous code did -- confirm that no separate x offset is intended.
    header_fields = (
        ('rfxCode', 'rfx', 'yOffset'),
        ('rfyCode', 'rfy', 'yOffset'),
        ('stimWidthCode', 'StimSize', 'yOffset'),
        ('itiCode', 'iti', None),
        ('stim_timeCode', 'StimDur', None),
        ('isiCode', 'isi', None),
        ('stim_numCode', None, None),  # Number of stimuli per trial
    )
    for code_name, out_key, offset_name in header_fields:
        if markervals[counter] != parseParams[code_name]:
            # A missing code is reported and the counter is NOT advanced.
            print('markerval #' + str(counter) + ' was not ' + code_name)
            continue
        if out_key is not None:
            value = markervals[counter + 1]
            if offset_name is not None:
                value = value - parseParams[offset_name]
            experiment[out_key] = value
        counter = counter + 2

    #%% Prepare StimStructs
    stimStructs = []
    for _ in np.arange(numStims):
        stimStructs.append({
            'numInstances': 0,
            'timeOn': [],
            'timeOff': [],
            'pdOn': [],
            'pdOff': [],
            'neurons': [dict() for _ in range(numNeurons)],
        })

    #%% Prepare to get stimulus information parameters
    stimITIOns = np.where(markervals == parseParams['startITICode'])[0]
    if len(stimITIOns) == 0:
        # Previously an IndexError; with no start_iti there is nothing to
        # parse and the loop below does not run.
        print('No start_iti code found')
    elif stimITIOns[0] != counter:
        print('The first start_iti code is offset')

    error_indices = []

    #%% Get stimuli
    for i in np.arange(len(stimITIOns) - 1):  # the file should end with
        # a startITI that we don't care about
        if stimITIOns[i] < counter:
            continue

        index = stimITIOns[i] + 1
        next_code = markervals[index]

        if next_code == parseParams['endITICode']:
            experiment['iti_start'].append(markerts[stimITIOns[i]])
            experiment['iti_end'].append(markerts[index])

        elif next_code == parseParams['pauseCode']:
            if markervals[index + 1] != parseParams['unpauseCode']:
                print('Found pause, but no unpause at ' + str(index + 1))
                print('continuing from next start_iti')
                error_indices.append(index)
                continue
            # Skip the pause/unpause pair and expect end_iti after it.
            index = index + 2
            next_code = markervals[index]

            if next_code == parseParams['endITICode']:
                experiment['iti_start'].append(markerts[stimITIOns[i]])
                experiment['iti_end'].append(markerts[index])
            else:
                print('Found bad code ' + str(next_code) +
                      ' after start_iti at index ' + str(index))
                print('continuing from next start_iti')
                error_indices.append(index)
                continue

        else:
            print('Found bad code ' + str(next_code) +
                  ' after start_iti at index ' + str(index))
            print('continuing from next start_iti')
            error_indices.append(index)
            continue

        next_code2 = markervals[index + 1]
        if next_code2 == parseParams['fixAcquiredCode']:
            pass
        elif next_code2 == parseParams['UninitiatedTrialCode']:
            if markervals[index + 2] != parseParams['startITICode']:
                error_indices.append(index + 2)
                print('Found non start_iti code ' +
                      str(markervals[index + 2]) +
                      ' after Uninitiated trial at ' + str(index + 2))
            continue
        else:
            print('Found bad code ' + str(next_code2) +
                  ' after end_iti at index ' + str(stimITIOns[i] + 2))
            error_indices.append(index)
            continue

        ndex = index + 2
        trialCode = []

        # One pass per stimulus presentation; the loop ends as soon as
        # trialCode is set to a terminal code (break-fix, error, correct).
        while (trialCode == []):
            optionalCode = 0

            stimCode = markervals[ndex + optionalCode]

            # NOTE(review): in each fixLost check below, when
            # hasValidBreakFix(...) is false, execution falls through to the
            # following checks -- confirm that this is intended.
            if stimCode == parseParams['fixLost']:
                if hasValidBreakFix(ndex + optionalCode, markervals,
                                    parseParams):
                    trialCode = parseParams['breakFixCode']
                    continue
            elif stimCode != parseParams['stimIDCode']:
                print('Found ' + str(stimCode) +
                      ' as a stimID or breakfix code at stim time ' +
                      str(markerts[ndex + optionalCode]) + ' at index ' +
                      str(ndex + optionalCode))
                print('continuing from next start_iti')
                error_indices.append(ndex + optionalCode)
                trialCode = parseParams['codeError']
                continue

            if markervals[ndex + 1 + optionalCode] == parseParams['fixLost']:
                if hasValidBreakFix(ndex + 1 + optionalCode, markervals,
                                    parseParams):
                    trialCode = parseParams['breakFixCode']
                    continue
            elif ((markervals[ndex + 1 + optionalCode] >=
                   parseParams['stimIDOffset'])
                  and (markervals[ndex + 1 + optionalCode] <
                       parseParams['stimRotOffset'])):
                stimIDCodeToStore = markervals[ndex + 1 + optionalCode]
            else:
                print('Found ' + str(markervals[ndex + 1]) +
                      ' as a stimulus code at stim time ' +
                      str(markerts[ndex + 1 + optionalCode]) + ' at index ' +
                      str(ndex + 1 + optionalCode))
                print('continuing from next start_iti')
                error_indices.append(ndex + optionalCode + 1)
                trialCode = parseParams['codeError']
                continue

            ## next code is either fixlost or stimOn
            codeIndex = ndex + 2 + optionalCode
            code = markervals[codeIndex]
            if code == parseParams['fixLost']:
                if hasValidBreakFix(codeIndex, markervals, parseParams):
                    trialCode = parseParams['breakFixCode']
                    continue
            elif code != parseParams['stimOnCode']:
                print('Missing StimOn or fixlost code, found ' + str(code) +
                      ' at ' + str(codeIndex))
                print('continuing from next start_iti')
                error_indices.append(codeIndex)
                trialCode = parseParams['codeError']
                continue
            else:
                stimOnTime = markerts[codeIndex]

            ## next code is either fixlost or stimOff
            codeIndex = ndex + 3 + optionalCode
            code = markervals[codeIndex]
            if code == parseParams['fixLost']:
                if hasValidBreakFix(codeIndex, markervals, parseParams):
                    trialCode = parseParams['breakFixCode']
                    continue
            elif code != parseParams['stimOffCode']:
                print('Missing StimOff or fixlost code, found ' + str(code) +
                      ' at ' + str(codeIndex))
                print('continuing from next start_iti')
                error_indices.append(codeIndex)
                trialCode = parseParams['codeError']
                continue
            else:
                stimOffTime = markerts[codeIndex]

            ## having made it here, we can now call this a completed stimulus presentation and record the results
            sIndex = stimIDCodeToStore - parseParams['stimIDOffset']
            sIndex = sIndex - 1
            # for zero-based indexing in python (Matlab doesn't need this);
            stim = stimStructs[sIndex]
            # numInstances starts at 0, so a plain increment is enough.
            stim['numInstances'] = stim['numInstances'] + 1

            inst = stim['numInstances'] - 1
            # for zero-based indexing in python (Matlab doesn't need this);
            stim['timeOn'].append(stimOnTime)
            stim['timeOff'].append(stimOffTime)

            ## now find the pdiode events associated with
            # Photodiode-off time of THIS presentation, or None if no
            # photodiode on/off pair follows the stimulus-on marker.
            pdOffTime = None
            pdOnsAfter = np.where(pdOnTS > stimOnTime)[0]
            if len(pdOnsAfter) == 0:
                print(
                    'Error, did not find a photodiode on code after stimon at time '
                    + str(stimOnTime))
                print('Ignoring... Continuing')
            else:
                pdOffsAfter = np.where(pdOffTS > pdOnTS[pdOnsAfter[0]])[0]
                if len(pdOffsAfter) == 0:
                    print(
                        'Error, did not find a photodiode on code after stimon at time '
                        + str(pdOnTS[pdOnsAfter[0]]))
                    print('Ignoring... Continuing')
                else:
                    pdOffTime = pdOffTS[pdOffsAfter[0]]
                    stim['pdOn'].append(pdOnTS[pdOnsAfter[0]])
                    stim['pdOff'].append(pdOffTime)

            ## now get neural data
            # The window ends at the photodiode-off time when one was found
            # for this presentation and at the stimOff marker otherwise.
            # Indexing stim['pdOff'][inst] instead (as before) raised
            # IndexError as soon as one presentation of the stimulus had no
            # photodiode match.
            if pdOffTime is not None:
                windowEnd = pdOffTime + postTime
            else:
                windowEnd = stimOffTime + postTime
            windowStart = stimOnTime - prevTime

            for j in range(numNeurons):
                mySpikes = np.array([])
                spikeIndices = np.where((spikets[j] >= windowStart)
                                        & (spikets[j] <= windowEnd))[0]

                if len(spikeIndices) > 0:
                    mySpikes = np.append(mySpikes, spikets[j][spikeIndices])
                if inst == 0:
                    stim['neurons'][j]['spikes'] = []
                stim['neurons'][j]['spikes'].append(mySpikes)

            ## next code is either fixlost, an object code or correct_response
            codeIndex = ndex + 4 + optionalCode
            code = markervals[codeIndex]

            if code == parseParams['fixLost']:
                if hasValidBreakFix(codeIndex, markervals, parseParams):
                    trialCode = parseParams['breakFixCode']
                    continue
            elif code == parseParams['correctCode']:  # end of trial
                if markervals[codeIndex + 1] != parseParams['startITICode']:
                    print('Missing startITI after ' +
                          str(markervals[codeIndex + 1]) + ' at ' +
                          str(markerts[codeIndex + 1]) + ' at index ' +
                          str(codeIndex + 1))
                    error_indices.append(codeIndex)
                    trialCode = parseParams['codeError']
                    continue
                else:
                    trialCode = parseParams['correctCode']
                    continue
            elif code != parseParams['stimIDCode']:
                # Report the code actually found here (this used to print
                # the stale stimCode from the start of the presentation).
                print('Found ' + str(code) +
                      ' as a stim ID code at stim time ' +
                      str(markerts[codeIndex]))
                print('continuing from next start_iti')
                error_indices.append(codeIndex)
                trialCode = parseParams['codeError']
                continue
            else:
                # Another stimulus follows within the same trial.
                ndex = ndex + 4 + optionalCode

    #%% add stimStructs to experiment output, then return
    experiment['stimStructs'] = stimStructs
    experiment['errors'] = error_indices

    return experiment
import nexfile
import numpy as np
import scipy.sparse as ss

# Read the recording with the default Reader settings.
# NOTE(review): the path literal contains '\\\C', i.e. an escaped backslash
# followed by the unrecognised escape '\C'; consider a raw string.
reader = nexfile.Reader()
fileData = reader.ReadNexFile(
    'C:\\Users\\siddh\\PycharmProjects\\\CNN_RodentData\\1065\\1065u065merge-clean.nex'
)
# Wrap every top-level value of fileData in a one-element list, keeping the
# dict's iteration order.
dictlist = []
for key, value in dict.items(fileData):
    temp = [value]
    dictlist.append(temp)
# The recording end is recovered by text-splitting the string form of the
# first top-level entry around "'End':" (presumably the file header -- TODO
# confirm); temp2[0] is the text up to the next comma.
# NOTE(review): parsing str() output depends on the exact repr layout.
temp = str(dictlist[0])
temp1 = temp.split('\'End\':')
temp2 = temp1[1].split(',')
print("Total Recording time:" + str(temp2[0]) + " sec" + "   " +
      str(int((float(temp2[0]) * 1000))) + " msec")

# Second top-level entry (presumably the variable list -- TODO confirm),
# again handled as text: its string form is split on "]},".
# NOTE(review): `vars` shadows the builtin of the same name.
vars = dictlist[1]
var_str = str(vars).split("]},")
all_val = []
k = 0
for j in range(len(var_str)):
    # The name is the text between "Name': '" and "', 'DataOffset':";
    # temp[1] raises IndexError for a chunk without a "Name': '" part.
    temp = var_str[j].split('Name\': \'')
    name = temp[1].split('\', \'DataOffset\':')
    # Only variables whose name contains "CA1" are processed.
    if (name[0].__contains__("CA1")):
        # Zero vector with one element per millisecond of the recording.
        ms = np.zeros((int((float(temp2[0]) * 1000))))
        k = k + 1  # number of CA1 variables seen so far
        # Timestamp values as strings: the text after "'Timestamps': ["
        # split on ', '.
        temp = name[1].split('\'Timestamps\': [')
        values = temp[1].split(', ')
        #print(values)
Ejemplo n.º 6
0
def buildneurons(pathname=r'C:\Users\Qixin\XuChunLab\nexdata\192043',
                 file_type='nex',
                 build_tracking=False,
                 arena_reset=False,
                 body_part='Body',
                 bootstrap=False):
    """Build a list of ``Unit`` objects from a recording folder.

    Parameters
    ----------
    pathname : session folder. Its last component is used as the experiment
        date and the name of its parent folder as the animal ID.
    file_type : ``'nex'`` reads the first ``*.nex`` file in the folder,
        ``'mat'`` reads every ``*.mat`` file (sorted by name, one unit per
        file). Any other value yields an empty ensemble.
    build_tracking : when true, ``bt.build_pos`` is called to build the
        position data; otherwise ``pos`` is an empty list.
    arena_reset, body_part : forwarded to ``bt.build_pos``.
    bootstrap : forwarded to ``Unit``.

    Returns
    -------
    tuple
        ``(ensemble, pos)``: the list of ``Unit`` objects and the position
        data.

    Raises
    ------
    IndexError
        If no ``*.nex`` file is found, or fewer than two event variables
        exist in the file (``file_type == 'nex'``).
    """
    experiment_date = os.path.split(pathname)[1]
    animal_id = os.path.split(os.path.split(pathname)[0])[1]
    if build_tracking:
        #build position
        pos = bt.build_pos(pathname,
                           reset_arena=arena_reset,
                           body_part=body_part)
    else:
        pos = []
    ensemble = []

    if file_type == 'nex':
        filepath = glob.glob(os.path.join(pathname, '*.nex'))[0]
        nexin = nex.Reader(useNumpy=True).ReadNexFile(filepath)

        # Sort the variables by the 'Type' code of their header.
        neurons = []
        waveforms = []
        events = []
        markers = []
        for var in nexin['Variables']:
            var_type = var['Header']['Type']
            if var_type == 0:
                neurons.append(var)
            elif var_type == 1:
                events.append(var)
            elif var_type == 3:
                waveforms.append(var)
            elif var_type == 6 and len(var['Timestamps']) != 0:
                # Markers without any timestamp are ignored.
                markers.append(var)

        # Context protocol: taken from the 'order' column of a supplementary
        # file in the folder, otherwise asked from the user.
        # NOTE(review): the glob pattern is '*csv' but the file is opened
        # with pd.read_excel -- confirm the file format; a plain CSV makes
        # this fall back to the prompt.
        try:
            suppm = pd.read_excel(glob.glob(os.path.join(pathname, '*csv'))[0])
            protocol = (suppm['order'])
            protocol = protocol[~protocol.isnull()]
            input_protocol = protocol.values
        except Exception:
            # `Exception` instead of a bare except, so Ctrl-C / SystemExit
            # are no longer swallowed while the file is being read.
            answer = input('Enter the order of context protocol: ').split()
            input_protocol = [
                str(x) for x in answer or 'A B A B A B B A'.split()
            ]
            print(input_protocol)

        # The first two event variables are used as the record markers.
        record_marker = [events[0]['Timestamps'], events[1]['Timestamps']]
        door_marker = []
        for mrker in markers:
            if mrker['Header']['Name'] == 'KBD1':
                door_marker.insert(0, mrker['Timestamps'])
            elif mrker['Header']['Name'] == 'KBD3':
                door_marker.insert(1, mrker['Timestamps'])
        door_marker = door_marker[0:2]
        allmarker = Marker(record_marker, door_marker, input_protocol)

        # NOTE(review): the loop starts at index 1, so the first neuron and
        # the first waveform variable are skipped; neurons and waveforms are
        # paired by position -- confirm both are intended.
        for i in range(1, len(neurons)):
            wave = Waveform(waveforms[i]['Timestamps'],
                            waveforms[i]['WaveformValues'],
                            waveforms[i]['Header']['SamplingRate'])
            ensemble.append(
                Unit(neurons[i]['Timestamps'], allmarker, wave, pos,
                     neurons[i]['Header']['Name'], experiment_date, animal_id,
                     bootstrap))
    elif file_type == 'mat':
        # Imported here so this branch does not depend on the module having
        # imported the scipy.io submodule.
        import scipy.io

        filepath = sorted(glob.glob(os.path.join(pathname, '*.mat')))
        for matfile in filepath:
            data = scipy.io.loadmat(matfile)
            record_marker = [
                data['record_start'].T[0], data['record_end'].T[0]
            ]
            door_marker = [data['door_open'].T[0], data['door_close'].T[0]]
            # Protocol codes 1 and 2 map to 'A' and 'B'; other codes are
            # dropped.
            input_protocol = []
            for c in data['input_protocol'][0]:
                if c == 1:
                    input_protocol.append('A')
                elif c == 2:
                    input_protocol.append('B')
            allmarker = Marker(record_marker, door_marker, input_protocol)
            wave = Waveform(data['Timestamps'], data['WaveformValues'],
                            data['SamplingRate'])
            # NOTE(review): the last three positional arguments are
            # (bootstrap, experiment_date, animal_id) here but
            # (experiment_date, animal_id, bootstrap) in the 'nex' branch;
            # one of the two orders is presumably wrong -- verify against
            # Unit's signature.
            ensemble.append(
                Unit(data['Timestamps'], allmarker, wave, pos,
                     data['name'][0], bootstrap, experiment_date, animal_id))

    return ensemble, pos
Ejemplo n.º 7
0
    return filtered_EMG


def EMG_downsample(new_fs, data, fs=2010):
    """Downsample multi-channel data by keeping every n-th sample.

    ``data`` is transposed first, so it is expected as
    ``(channels, samples)`` and the result is ``(kept_samples, channels)``.
    With ``n = floor(fs / new_fs)`` the kept samples are those at indices
    ``n, 2n, 3n, ...`` (sample 0 is not kept). No filtering is applied.

    Parameters
    ----------
    new_fs : target sampling rate.
    data : 2-D array-like, one row per channel.
    fs : sampling rate of ``data``; defaults to 2010, the value that used to
        be hard-coded.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(kept_samples, channels)``; integer and float input
        comes back as float64.

    Raises
    ------
    ValueError
        If ``new_fs`` is larger than ``fs`` or negative, so that the
        decimation factor would be smaller than 1.
    """
    samples = np.asarray(data).T
    n = int(np.floor(fs / new_fs))
    if n < 1:
        raise ValueError(
            'new_fs must be positive and not larger than fs '
            '(fs=%r, new_fs=%r)' % (fs, new_fs))

    # One strided slice replaces the row-by-row vstack loop. The slice stops
    # before the end of the data, so a sample count that is an exact
    # multiple of n no longer indexes one row past the end (IndexError).
    picked = samples[n::n, :]

    # Stacking onto an empty float array keeps the dtype promotion (and the
    # (0, channels) result for too-short input) of the loop-based version.
    return np.vstack((np.empty((0, np.size(samples, 1))), picked))


# Recording to process.
file_name = "Z:/data/Greyson_17L2/NeuroexplorerFile/20180831_Greyson_PG_003.nex5"

# Raw read of the NEX5 file, followed by the project wrapper built from the
# same path.
reader = nexfile.Reader(useNumpy=True)
data = reader.ReadNex5File(file_name)
myNex = my_nex_class(file_name)
spike_data = myNex.grab_spike_data()
spike_names = myNex.grab_spike_names()
# Bin the spikes with a bin size of 0.05 and smooth with the same value.
firing_rate = myNex.bin_spike_data(0.05)
spike = myNex.smooth_firing_rate(firing_rate, 0.05)
# Drop the first five columns of the smoothed firing rates.
spike = spike[:, 5:]

# NOTE(review): despite its name this is a set literal, so the iteration
# order of the names is not defined; if the names are meant to pair
# positionally with ch1/ch2 (13 entries each) this should be a list -- confirm.
EMG_list = {
    'FCR1', 'FCR2', 'FCU1', 'FCU2', 'FDP2', 'FDP3', 'FDS1', 'FDS2', 'FPB',
    'PT', 'LUM', 'FDI', 'EDC3'
}
# Two channel-index lists with 13 entries each.
# presumably the two channels recorded for each EMG name -- TODO confirm
ch1 = [28, 25, 26, 27, 24, 0, 30, 31, 3, 1, 4, 5, 6]
ch2 = [19, 22, 21, 20, 23, 15, 17, 16, 12, 14, 11, 10, 9]
Ejemplo n.º 8
0
def data_prep(
        pathname=r'C:\Users\Qixin\XuChunLab\nexdata\BaiTao\191545\06052019'):
    """Build a per-trial table and per-unit spike lists from a .nex file.

    The first ``*.nex`` file in ``pathname`` is read. Variables of type 0
    are used as neurons, type 1 as events (the first two give trial start
    and trial end) and non-empty type 6 variables as markers, which are
    matched by name against a fixed channel table (licks, pumps, go cue,
    enter context).

    Parameters
    ----------
    pathname : folder containing the ``*.nex`` file.

    Returns
    -------
    tuple
        ``(trial_df, Units)``. ``trial_df`` is a ``pandas.DataFrame`` with
        one row per trial start; ``Units`` is a list of dicts with the keys
        ``'name'`` and ``'spkt'`` (one entry per trial).

    Raises
    ------
    IndexError
        If no ``*.nex`` file is found, fewer than two event variables exist,
        or there are fewer trial ends than trial starts.
    KeyError
        If a marker channel other than 'Enter context' is missing.
    """
    filepath = glob.glob(os.path.join(pathname, '*.nex'))[0]
    nexin = nex.Reader(useNumpy=True).ReadNexFile(filepath)

    # Sort the variables by the 'Type' code of their header.
    neurons = []
    events = []
    markers = []
    for var in nexin['Variables']:
        var_type = var['Header']['Type']
        if var_type == 0:
            neurons.append(var)
        elif var_type == 1:
            events.append(var)
        elif var_type == 6 and len(var['Timestamps']) != 0:
            markers.append(var)

    #%%
    marker_dict = {
        'Lick left': 'EVT09',
        'Lick right': 'EVT10',
        'Enter context': 'EVT11',
        'Pump right': 'EVT05',
        'Pump left': 'EVT06',
        'Go cue': 'EVT07'
    }
    # Inverted table so every marker is looked up by its channel name once.
    label_by_channel = {
        channel: label
        for label, channel in marker_dict.items()
    }
    marker_ts = {}
    for mrker in markers:
        label = label_by_channel.get(mrker['Header']['Name'])
        if label is not None:
            marker_ts[label] = mrker['Timestamps']
    marker_ts['Trial start'] = events[0]['Timestamps']
    marker_ts['Trial end'] = events[1]['Timestamps']

    #%% set up marker dataframes
    trial_data = {
        'trial start': [],
        'door open': [],
        'enter context': [],
        'enter context2': [],
        'go cue': [],
        'lick time': [],
        'lick choice': [],
        'pump': [],
        'trial end': [],
        'correctness': [],
        'context': [],
        'lick delay': []
    }
    Units = [{'name': neu['Header']['Name'], 'spkt': []} for neu in neurons]

    for t, t_start in enumerate(marker_ts['Trial start']):
        t_end = marker_ts['Trial end'][t]
        trial_data['trial start'].append(t_start)
        trial_data['trial end'].append(t_end)
        # 'door open' is defined as one time unit after trial start.
        trial_data['door open'].append(t_start + 1)

        # First 'Enter context' marker inside the trial, NaN if unavailable.
        try:
            enter_context2 = slice_ts(marker_ts['Enter context'], t_start,
                                      t_end)[0]
        except Exception:
            enter_context2 = np.nan
        trial_data['enter context2'].append(enter_context2)

        go_cue = slice_ts(marker_ts['Go cue'], t_start, t_end)
        trial_data['go cue'].append(go_cue)
        # 'enter context' is derived from the go cue (1.5 earlier).
        trial_data['enter context'].append(go_cue - 1.5)

        lickl = slice_ts2(marker_ts['Lick left'], t_start, t_end)
        lickr = slice_ts2(marker_ts['Lick right'], t_start, t_end)
        lick_compare = [np.min(lickl), np.min(lickr)]
        first_lick = np.min(lick_compare)
        lick_choice = np.argmin(lick_compare)  # 0: left first, 1: right first
        #we estimated that there was a 0.2s delay of the lick signal
        lick_time = first_lick + 0.2
        trial_data['lick time'].append(lick_time)
        try:
            lick_delay = lick_time - go_cue
        except Exception:
            lick_delay = np.nan
        trial_data['lick delay'].append(lick_delay)
        trial_data['lick choice'].append(lick_choice)

        pumpl = slice_ts2(marker_ts['Pump left'], t_start, t_end)
        pumpr = slice_ts2(marker_ts['Pump right'], t_start, t_end)
        #correctness: 0 is false, 1 is true; context: 0 is A, 1 is B
        if ~np.isnan(go_cue):
            if ~np.isinf(pumpl):  #correctA
                pump, correctness, context = 'left', 1, 0
            elif ~np.isinf(pumpr):  #correctB
                pump, correctness, context = 'right', 1, 1
            elif not np.isfinite(first_lick):  #no lick: miss trial
                # This case used to be unreachable: argmin over the two
                # lick times is always 0 or 1, so a trial without any lick
                # was labelled as an incorrect left lick.
                # presumably slice_ts2 yields inf when no event falls in
                # the window, as the np.isinf() pump checks suggest
                pump, correctness, context = np.nan, np.nan, np.nan
            elif lick_choice == 0:  #incorrectA: licked left but no reward
                pump, correctness, context = np.nan, 0, 1
            else:  #incorrectB: licked right but no reward
                pump, correctness, context = np.nan, 0, 0
        else:  #invalid case
            pump, correctness, context = np.nan, np.nan, np.nan
        trial_data['pump'].append(pump)
        trial_data['correctness'].append(correctness)
        trial_data['context'].append(context)

        # Spikes of every unit restricted to this trial.
        for n, neu in enumerate(neurons):
            Units[n]['spkt'].append(slice_ts(neu['Timestamps'], t_start,
                                             t_end))

    trial_df = pd.DataFrame(data=trial_data)
    print('done preparing data')
    #trial_df=trial_df.dropna(subset=['go cue']) #drop the invalid trials
    return trial_df, Units