コード例 #1
0
def populatebehavior(paralel = True,drop_last_session_for_mice_in_training = True):
    """Populate behavior experiments for every water-restricted subject.

    With ``paralel`` true, ``populatebehavior_core`` is called once per
    subject with a ``{water_restriction_number: subject_id}`` dict (the
    ray-based parallel execution is currently disabled, so the calls run
    sequentially). Before that, if ``drop_last_session_for_mice_in_training``
    is true, the highest-numbered ``experiment.Session`` of every subject that
    is listed in ``Surgery.csv`` with a status other than ``'sacrificed'`` is
    deleted, so that it is ingested again.

    With ``paralel`` false, ``populatebehavior_core`` is called once with
    ``{'display_progress': True}``.

    Parameters
    ----------
    paralel : bool
        Process subjects one by one (see above) instead of a single call.
    drop_last_session_for_mice_in_training : bool
        Delete the last session of animals still in training before ingest.
    """
    print('adding behavior experiments')
    if not paralel:
        arguments = {'display_progress' : True}
        populatebehavior_core(arguments)
        return

    #ray.init()
    result_ids = []  # kept for the (disabled) ray-based execution
    IDs = {k: v for k, v in zip(*lab.WaterRestriction().fetch('water_restriction_number', 'subject_id'))}
    df_surgery = pd.read_csv(dj.config['locations.metadata_behavior'] + 'Surgery.csv')

    if drop_last_session_for_mice_in_training:
        # remove the last session of animals that are still in training
        for subject_now, subject_id_now in IDs.items():
            if subject_now not in df_surgery['ID'].values:
                continue
            status = df_surgery['status'][df_surgery['ID'] == subject_now].values[0]
            if status == 'sacrificed':
                continue
            print(status)
            subject_sessions = experiment.Session() & 'subject_id = "' + str(subject_id_now) + '"'
            session_numbers = subject_sessions.fetch('session')
            if len(session_numbers) == 0:
                continue
            session_todel = subject_sessions & 'session = ' + str(np.max(session_numbers))
            print('deleting last session of ' + subject_now)
            # switch safemode off only for this delete, and restore the
            # previous setting even when the delete raises
            previous_safemode = dj.config['safemode']
            dj.config['safemode'] = False
            try:
                session_todel.delete()
            finally:
                dj.config['safemode'] = previous_safemode

    for subject_now, subject_id_now in IDs.items():  # iterating over subjects
        result_ids.append(populatebehavior_core({subject_now: subject_id_now}))
    #ray.get(result_ids)
    #ray.shutdown()
コード例 #2
0
    def make(self, key):
        '''
        Ephys .make() function

        Looks up the session's ephys directory
        (``<data_path>/<water_restriction_number>/<YYYYmmdd>``), finds the
        sorted-spike files in its single-digit probe subdirectories and loads
        each with ``self._load_v3`` (``*_jrc.mat``) or ``self._load_v4``
        (``*.ap_res.mat``). Logs a warning and returns without loading when
        there is no BehaviorIngest row for ``key`` or when both or neither
        file flavours are present.
        '''
        log.info('EphysIngest().make(): key: {k}'.format(k=key))

        # Times and sessions are keyed from the behavior ingest, so a
        # matching BehaviorIngest row is required.
        try:
            behavior = (behavior_ingest.BehaviorIngest() & key).fetch1()
        except dj.DataJointError:
            log.warning('EphysIngest().make(): skip - behavior ingest error')
            return
        log.info('behavior for ephys: {b}'.format(b=behavior))

        # Resolve the directory that holds this session's ephys recording.
        session_key = (experiment.Session & key).fetch1()
        sinfo = ((lab.WaterRestriction() * lab.Subject().proj()
                  * experiment.Session()) & session_key).fetch1()
        rigpath = EphysDataPath().fetch1('data_path')
        h2o = sinfo['water_restriction_number']
        dpath = pathlib.Path(rigpath, h2o,
                             session_key['session_date'].strftime('%Y%m%d'))

        def probe_files(spec):
            # files live one level down, in probe directories named 0-9
            return list(dpath.glob('[0-9]/{}'.format(spec)))

        # old v3 pattern: '{}_g0_*.imec.ap_imec3_opt3_jrc.mat'
        v3files = probe_files('{}_*_jrc.mat'.format(h2o))
        # old v4 pattern: '{}_g0_*.imec?.ap_res.mat'  # TODO v4ify
        v4files = probe_files('{}_*.ap_res.mat'.format(h2o))

        # exactly one of the two file flavours must be present
        if bool(v3files) == bool(v4files):
            log.warning(
                'Error - v3files ({}) + v4files ({}). Skipping.'.format(
                    v3files, v4files))
            return

        if v3files:
            files, loader = v3files, self._load_v3
        else:
            files, loader = v4files, self._load_v4

        for f in files:
            self._load(loader(sinfo, rigpath, dpath, f.relative_to(dpath)))
コード例 #3
0
ファイル: helper_functions.py プロジェクト: ixcat/map-ephys
def subject2water(subject, session_num):
    """Map a subject id and session number to water-restriction info.

    Returns ``(water, date)``: the ``water_restriction_number`` values of the
    subject and the ``session_date`` values of the given session, each as
    returned by ``fetch`` (one entry per matching row).
    """
    by_subject = {'subject_id': subject}
    water = (lab.WaterRestriction & by_subject).fetch(
        'water_restriction_number')
    by_session = dict(by_subject, session=session_num)
    date = ((experiment.Session() * lab.WaterRestriction)
            & by_session).fetch('session_date')
    return water, date
コード例 #4
0
ファイル: helper_functions.py プロジェクト: ixcat/map-ephys
def water2subject(water, date):
    """Map a water-restriction number and session date to subject info.

    Returns ``(subject_id, session_num)``: the ``subject_id`` values for the
    water-restriction number and the ``session`` values recorded on ``date``,
    each as returned by ``fetch`` (one entry per matching row).
    """
    by_water = {'water_restriction_number': water}
    subject_id = (lab.WaterRestriction & by_water).fetch('subject_id')
    by_date = dict(by_water, session_date=date)
    session_num = ((experiment.Session() * lab.WaterRestriction)
                   & by_date).fetch('session')
    return subject_id, session_num
コード例 #5
0
ファイル: publication.py プロジェクト: nwtien/map-ephys
    def retrieve1(self, key):
        '''
        retrieve related files for a given key

        For each archived-file part table (AP/LF channel data and metadata)
        that has a record matching ``key``, copies the file
        ``<h2o>/<date>/<electrode_group>/<h2o>_<date>_<eg>_g0_t<trial><suffix>``
        from the remote Globus endpoint of ``key`` to the local endpoint.
        Raises ``dj.DataJointError`` when a transfer fails.
        '''
        # NOTE(review): key is expected to hold subject_id, session, trial,
        # electrode_group and a globus alias field (the original inline list
        # was truncated) -- confirm against callers.
        log.debug(key)
        lep, lep_sub, lep_dir = GlobusStorageLocation().local_endpoint
        log.info('local_endpoint: {}:{} -> {}'.format(lep, lep_sub, lep_dir))

        # session information needed to build file names
        sinfo = ((lab.WaterRestriction
                  * lab.Subject.proj()
                  * experiment.Session()
                  * experiment.SessionTrial) & key).fetch1()
        h2o = sinfo['water_restriction_number']
        sdate = sinfo['session_date']
        eg, trial = key['electrode_group'], key['trial']

        # base file name of this trial, and its globus path prefix
        fpat = '{}_{}_{}_g0_t{}'.format(h2o, sdate, eg, trial)
        gbase = '/'.join([h2o, str(sdate), str(eg), fpat])

        repname, rep, rep_sub = (GlobusStorageLocation() & key).fetch()[0]

        gsm = self.get_gsm()
        for endpoint in (lep, rep):
            gsm.activate_endpoint(endpoint)  # XXX: cache this / prevent duplicate RPC?

        sfxmap = {'.imec.ap.bin': ArchivedRawEphysTrial.ArchivedApChannel,
                  '.imec.ap.meta': ArchivedRawEphysTrial.ArchivedApMeta,
                  '.imec.lf.bin': ArchivedRawEphysTrial.ArchivedLfChannel,
                  '.imec.lf.meta': ArchivedRawEphysTrial.ArchivedLfMeta}

        for sfx, cls in sfxmap.items():
            if not (cls & key):
                continue
            log.debug('record found for {} & {}'.format(cls.__name__, key))
            gname = gbase + sfx
            srcp = '{}:/{}/{}'.format(rep, rep_sub, gname)
            dstp = '{}:/{}/{}'.format(lep, lep_sub, gname)
            log.info('transferring {} to {}'.format(srcp, dstp))

            # XXX: check if exists 1st? (manually or via API copy-checksum)
            if gsm.cp(srcp, dstp):
                continue
            emsg = "couldn't transfer {} to {}".format(srcp, dstp)
            log.error(emsg)
            raise dj.DataJointError(emsg)
コード例 #6
0
def load_all_sessions(subject_id):
    """Ingest every on-disk session of ``subject_id`` that is not yet stored.

    ``loader.load_sessions(subject_id)`` returns one dict per session; per
    the original inline note each holds subject_name, session_date,
    session_time, session_basename, session_files, username and rig.

    Sessions already matched by ``experiment.Session`` are skipped. Each new
    session gets a number one above the subject's current maximum session
    number, and its ``experiment.Session``, ``InsertedSession`` and
    ``InsertedSession.SessionFile`` rows are inserted in one transaction.

    A ``FileNotFoundError`` from the loader is logged and nothing is
    ingested. Sessions with no files are skipped with a warning; previously
    they raised IndexError mid-transaction and aborted the remaining sessions.
    """
    loader = get_loader()
    # ---- parse data dir and load all sessions ----
    try:
        sessions_to_ingest = loader.load_sessions(subject_id)
    except FileNotFoundError as e:
        log.warning(str(e))
        return

    # ---- work on each session ----
    for sess in sessions_to_ingest:
        session_files = sess.pop('session_files')

        if experiment.Session & sess:
            log.info(f'Session {sess} already exists. Skipping...')
            continue

        if not session_files:
            # the session data directory is taken from the first file
            log.warning(f'Session {sess} has no session files. Skipping...')
            continue

        # ---- synthesize session number ----
        sess_num = (dj.U().aggr(experiment.Session()
                                & {
                                    'subject_id': subject_id
                                },
                                n='max(session)').fetch1('n') or 0) + 1
        # ---- insert ----
        sess_key = {**sess, 'session': sess_num}

        with dj.conn().transaction:
            experiment.Session.insert1(sess_key)
            InsertedSession.insert1(
                {
                    **sess_key, 'loader_method': loader.loader_name,
                    'sess_data_dir': session_files[0].parent.as_posix()
                },
                allow_direct_insert=True,
                ignore_extra_fields=True)
            InsertedSession.SessionFile.insert(
                [{
                    **sess_key, 'filepath': f.as_posix()
                } for f in session_files],
                allow_direct_insert=True,
                ignore_extra_fields=True)
            log.info(f'Inserted new session: {sess}')
コード例 #7
0
ファイル: temperature.py プロジェクト: xibby/pipeline
    def notify(self, key):
        """Plot the temperature trace of ``key`` and send it via Slack.

        The figure is written to ``/tmp/<key_hash(key)>.png`` and passed to
        ``notify`` on the ``notify.SlackUser`` rows matched by the session
        of ``key``.
        """
        ts, temperatures = (self & key).fetch1('temp_time', 'temperatures')

        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 5))
        ax = fig.gca()
        ax.plot(ts, temperatures)
        ax.set_ylabel('Temperature (C)')
        ax.set_xlabel('Seconds')
        png_path = '/tmp/' + key_hash(key) + '.png'
        fig.savefig(png_path)
        plt.close(fig)

        title = 'temperature for {animal_id}-{session}-{scan_idx}'.format(**key)
        recipients = notify.SlackUser() & (experiment.Session() & key)
        recipients.notify(file=png_path, file_title=title)
コード例 #8
0
 def make(self, key):
     """Store the sweeps of a patched cell that overlap its imaging movie.

     Sweep times (stored relative to the cell recording start) are shifted to
     session time. The movie of ``key`` whose start lies strictly inside the
     sweeps' overall span is looked up, and every sweep that starts or ends
     inside the movie's frame-time span, or spans it entirely, is selected.
     If any are selected, ``key`` plus ``sweep_numbers`` is inserted
     (duplicates skipped).

     When no unique movie with frame times can be found, the key is skipped
     with a logged warning. Previously a bare ``except: pass`` hid every
     error, including KeyboardInterrupt.
     """
     import logging

     # example key: {'subject_id': 462149, 'session': 1, 'cell_number': 1, 'movie_number': 11}
     session_time = (experiment.Session() & key).fetch('session_time')[0]
     cell_time = (ephys_patch.Cell() & key).fetch('cell_recording_start')[0]
     cell_seconds = cell_time.total_seconds()
     session_seconds = session_time.total_seconds()

     # fetch the sweeps once and shift them from cell-recording time to
     # session time
     sweep_start_times, sweep_end_times, sweep_nums = (
         ephys_patch.Sweep() & key).fetch('sweep_start_time',
                                          'sweep_end_time',
                                          'sweep_number')
     sweep_start_times = np.asarray(sweep_start_times,
                                    float) + cell_seconds - session_seconds
     sweep_end_times = np.asarray(sweep_end_times,
                                  float) + cell_seconds - session_seconds
     time_start = float(np.min(sweep_start_times))
     time_end = float(np.max(sweep_end_times))

     try:
         # fetch1 raises unless exactly one movie matches
         movie = (imaging.Movie() & key
                  & ('movie_start_time > ' + str(time_start))
                  & ('movie_start_time < ' + str(time_end)))
         frametimes = (imaging.MovieFrameTimes
                       & movie).fetch1('frame_times')
         needed_start_time = frametimes[0]
         needed_end_time = frametimes[-1]
     except Exception:
         # best effort as before: skip this key, but leave a trace
         logging.getLogger(__name__).warning(
             'no unique movie with frame times for %s; skipping', key,
             exc_info=True)
         return

     overlapping = (
         ((sweep_start_times > needed_start_time) &
          (sweep_start_times < needed_end_time)) |
         ((sweep_end_times > needed_start_time) &
          (sweep_end_times < needed_end_time)) |
         ((sweep_end_times > needed_end_time) &
          (sweep_start_times < needed_start_time)))
     sweep_nums_needed = sweep_nums[overlapping]
     if len(sweep_nums_needed) > 0:
         # build a new row instead of mutating the caller's key
         self.insert1(dict(key, sweep_numbers=sweep_nums_needed),
                      skip_duplicates=True)
コード例 #9
0
def test_paths():
    """Check that the eye video and behavior HDF5 file of a random scan exist.

    The stored behavior filename is a pattern. When its stem contains '%d',
    the last two characters of the stem are replaced by '0'; otherwise the
    last character is. This yields the name of the first file.
    """
    rel = experiment.Session() * experiment.Scan.EyeVideo(
    ) * experiment.Scan.BehaviorFile().proj(hdf_file='filename')

    path_info = random.choice(rel.fetch.as_dict())

    # split off only the extension; split('.') dropped the middle parts of
    # stems that themselves contain dots
    tmp = path_info['hdf_file'].rsplit('.', 1)
    n_placeholder = 2 if '%d' in tmp[0] else 1  # '%d' = new version
    path_info['hdf_file'] = tmp[0][:-n_placeholder] + '0.' + tmp[-1]

    hdf_path = lab.Paths().get_local_path(
        '{behavior_path}/{hdf_file}'.format(**path_info))
    avi_path = lab.Paths().get_local_path(
        '{behavior_path}/{filename}'.format(**path_info))

    assert_true(os.path.isfile(avi_path) and os.path.isfile(hdf_path))
コード例 #10
0
ファイル: temperature.py プロジェクト: xibby/pipeline
    def session_plot(self):
        """ Do a plot of how temperature progress through a session

        Each scan's trace is placed at the scan's offset from the session
        start (``scan_ts - session_ts``) and plotted in hours.

        Returns:
            The matplotlib figure.

        Raises:
            PipelineException: if the relation is empty or spans more than
                one session.
        """
        import matplotlib.pyplot as plt
        import matplotlib.ticker as ticker

        # Check that plot is restricted to a single session
        first_keys = self.fetch('KEY', limit=1)
        if len(first_keys) == 0:
            raise PipelineException('No temperature data to plot')
        session_key = first_keys[0]
        session_key.pop('scan_idx')
        if len(self & session_key) != len(self):
            raise PipelineException(
                'Plot can only be generated for one session at a '
                'time')

        # Get times and timestamps, scan_ts
        scan_indices, ts, temperatures = self.fetch('scan_idx',
                                                    'temp_time',
                                                    'temperatures',
                                                    order_by='scan_idx')
        session_ts = (experiment.Session() & self).fetch1('session_ts')
        scan_ts = (experiment.Scan() & self).fetch('scan_ts',
                                                   order_by='scan_idx')
        # timedelta.seconds is only the 0..86399 seconds component (wrong for
        # offsets over a day or negative offsets); total_seconds() is the
        # full offset
        abs_ts = [(sts - session_ts).total_seconds() + (t - t[0])
                  for sts, t in zip(scan_ts, ts)]

        # Plot
        fig = plt.figure(figsize=(10, 5))
        for abs_ts_, temp_, scan_idx in zip(abs_ts, temperatures,
                                            scan_indices):
            plt.plot(abs_ts_ / 3600, temp_,
                     label='Scan {}'.format(scan_idx))  # in hours
        plt.title(
            'Temperature for {animal_id}-{session} starting at {session_ts}'.
            format(session_ts=session_ts, **session_key))
        plt.ylabel('Temperature (Celsius)')
        plt.xlabel('Hour')
        plt.legend()

        # Plot formatting
        plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(0.5))
        plt.grid(linestyle='--', alpha=0.8)

        return fig
コード例 #11
0
ファイル: temperature.py プロジェクト: xibby/pipeline
    def _make_tuples(self, key):
        """Compute, insert and report the temperature trace of one scan.

        Reads the scan's behavior file, converts its temperature channel with
        ``(raw * 100 - 32) / 1.8``, low-pass filters it (1 Hz cutoff) and
        keeps every ``sampling_rate``-th sample, where ``sampling_rate`` is
        the rounded inverse of the median timestamp step. The result and its
        NaN-ignoring median are inserted, then ``self.notify(key)`` is called.

        Raises:
            PipelineException: if the behavior file has no temperature data.
        """
        # Locate and read the behavior file of this scan
        local_dir = lab.Paths().get_local_path(
            (experiment.Session() & key).fetch1('behavior_path'))
        behavior_file = (experiment.Scan.BehaviorFile() & key).fetch1('filename')
        data = h5.read_behavior_file(os.path.join(local_dir, behavior_file))

        # counter timestamps -> seconds
        ts = h5.ts2sec(data['ts'], is_packeted=True)

        temp_raw = data.get('temperature', None)
        if temp_raw is None:
            raise PipelineException(
                'Scan {animal_id}-{session}-{scan_idx} does not have '
                'temperature data'.format(**key))
        # samples whose timestamp is unreliable (NaN) are invalidated
        temp_raw[np.isnan(ts)] = float('nan')

        celsius = (temp_raw * 100 - 32) / 1.8  # F to C
        rate = int(round(1 / np.nanmedian(np.diff(ts))))  # samples per second
        smoothed = signal.low_pass_filter(celsius,
                                          rate,
                                          cutoff_freq=1,
                                          filter_size=2 * rate)

        # keep one sample per `rate` samples, i.e. resample at ~1 Hz
        every_second = slice(None, None, rate)
        ts_1hz = ts[every_second]
        temp_1hz = smoothed[every_second]

        self.insert1({
            **key,
            'temp_time': ts_1hz,
            'temperatures': temp_1hz,
            'median_temperature': np.nanmedian(temp_1hz),
        })
        self.notify(key)
コード例 #12
0
def plot_ephys_ophys_trace(key_cell,time_to_plot=None,trace_window = 1,show_stimulus = False,show_e_ap_peaks = False, show_o_ap_peaks = False):
    """Plot the ephys (patch-clamp) and ophys (imaging) traces of one cell.

    The x axis is seconds from the start of the first movie that has a
    ground-truth ROI for ``key_cell``. The window shown is ``trace_window``
    seconds wide and centred on ``time_to_plot``.

    Parameters
    ----------
    key_cell : dict
        Restriction selecting the cell, e.g. ``{'session': 1,
        'subject_id': 456462, 'cell_number': 3,
        'motion_correction_method': 'VolPy', 'roi_type': 'VolPy'}``.
    time_to_plot : float or None
        Window centre. ``None`` means ``trace_window / 2``, so the window
        starts at the first movie start. An explicit 0 is now honoured.
    trace_window : float
        Window width in seconds.
    show_stimulus : bool
        Add a third axis with the stimulus trace (scaled by 1e12, labelled pA).
    show_e_ap_peaks : bool
        Mark ``ephysanal.ActionPotential`` peaks on the ephys trace.
    show_o_ap_peaks : bool
        Mark ``roi_spike_indices`` of the ground-truth ROI on the dF/F trace.

    Returns
    -------
    dict
        ``{'ephys': {...}, 'ophys': {...}, 'figure_handle': fig}``. The
        nested dicts hold the full (not windowed) traces and time vectors of
        every sweep / movie that overlaps the window.
    """
    fig = plt.figure()
    ax_ophys = fig.add_axes([0, 0, 2, .8])
    ax_ephys = fig.add_axes([0, -1, 2, .8])
    if show_stimulus:
        ax_ephys_stim = fig.add_axes([0, -1.5, 2, .4])

    session_time, cell_recording_start = (
        experiment.Session() * ephys_patch.Cell() & key_cell).fetch1(
            'session_time', 'cell_recording_start')
    gt_movies = (imaging.Movie() * imaging_gt.GroundTruthROI()) & key_cell
    first_movie_start_time = np.min(
        np.asarray(gt_movies.fetch('movie_start_time'), float))
    first_movie_start_time_real = (first_movie_start_time
                                   + session_time.total_seconds())
    if time_to_plot is None:
        time_to_plot = trace_window / 2
    session_time_to_plot = time_to_plot + first_movie_start_time  # time relative to session start
    cell_time_to_plot = (session_time_to_plot + session_time.total_seconds()
                         - cell_recording_start.total_seconds())  # time relative to recording start
    # plotted window, in seconds from the first movie start
    window_start = time_to_plot - trace_window / 2
    window_end = time_to_plot + trace_window / 2

    # -- ephys: sweeps overlapping the window (cell-recording time) --
    sweep_start_times, sweep_end_times, sweep_nums = (
        ephys_patch.Sweep() & key_cell).fetch(
            'sweep_start_time', 'sweep_end_time', 'sweep_number')
    needed_start_time = cell_time_to_plot - trace_window / 2
    needed_end_time = cell_time_to_plot + trace_window / 2
    sweep_nums = sweep_nums[
        ((sweep_start_times > needed_start_time) & (sweep_start_times < needed_end_time))
        | ((sweep_end_times > needed_start_time) & (sweep_end_times < needed_end_time))
        | ((sweep_end_times > needed_end_time) & (sweep_start_times < needed_start_time))]

    ephys_traces = []
    ephys_trace_times = []
    ephys_traces_stim = []
    for sweep_num in sweep_nums:
        sweep = ephys_patch.Sweep() & key_cell & 'sweep_number = %d' % sweep_num
        trace, sr = (ephys_patch.SweepMetadata() * ephys_patch.SweepResponse()
                     & sweep).fetch1('response_trace', 'sample_rate')
        trace = trace * 1000  # plotted as mV
        # fetch1: the restriction is a single sweep (float() of a size-1
        # array, as returned by fetch, is deprecated in recent NumPy)
        sweep_start_time = float(sweep.fetch1('sweep_start_time'))
        trace_time = (np.arange(len(trace)) / sr + sweep_start_time
                      + cell_recording_start.total_seconds()
                      - first_movie_start_time_real)
        trace_idx = (window_start < trace_time) & (window_end > trace_time)
        ax_ephys.plot(trace_time[trace_idx], trace[trace_idx], 'k-')

        ephys_traces.append(trace)
        ephys_trace_times.append(trace_time)

        if show_e_ap_peaks:
            ap_max_index = np.asarray(
                (ephysanal.ActionPotential() & sweep).fetch('ap_max_index'),
                int)
            aptimes = trace_time[ap_max_index]
            apVs = trace[ap_max_index]
            ap_needed = (window_start < aptimes) & (window_end > aptimes)
            ax_ephys.plot(aptimes[ap_needed], apVs[ap_needed], 'ro')

        if show_stimulus:
            trace_stim = (ephys_patch.SweepMetadata()
                          * ephys_patch.SweepStimulus()
                          & sweep).fetch1('stimulus_trace')
            trace_stim = trace_stim * 10**12  # plotted as pA
            ax_ephys_stim.plot(trace_time[trace_idx], trace_stim[trace_idx],
                               'k-')
            ephys_traces_stim.append(trace_stim)

    ephysdata = {'ephys_traces': ephys_traces,
                 'ephys_trace_times': ephys_trace_times}
    if show_stimulus:
        ephysdata['ephys_traces_stimulus'] = ephys_traces_stim

    # -- ophys: movies overlapping the window (session time) --
    movie_nums, movie_start_times, movie_frame_rates, movie_frame_nums = gt_movies.fetch(
        'movie_number', 'movie_start_time', 'movie_frame_rate', 'movie_frame_num')
    movie_start_times = np.asarray(movie_start_times, float)
    movie_end_times = (movie_start_times
                       + np.asarray(movie_frame_nums, float)
                       / np.asarray(movie_frame_rates, float))
    needed_start_time = session_time_to_plot - trace_window / 2
    needed_end_time = session_time_to_plot + trace_window / 2
    movie_nums = movie_nums[
        ((movie_start_times >= needed_start_time) & (movie_start_times <= needed_end_time))
        | ((movie_end_times >= needed_start_time) & (movie_end_times <= needed_end_time))
        | ((movie_end_times >= needed_end_time) & (movie_start_times <= needed_start_time))]

    dffs = []
    frametimes = []
    for movie_num in movie_nums:
        key_movie = key_cell.copy()
        key_movie['movie_number'] = movie_num
        gt_roi = (imaging.ROI() * imaging_gt.GroundTruthROI()) & key_movie

        dff, gt_roi_number = gt_roi.fetch1('roi_dff', 'roi_number')
        dff_all, roi_number_all = (imaging.ROI() & key_movie).fetch('roi_dff', 'roi_number')
        dff_all = dff_all[roi_number_all != gt_roi_number]  # every other ROI
        frame_times = (((imaging.MovieFrameTimes() * imaging_gt.GroundTruthROI())
                        & key_movie).fetch1('frame_times')
                       + session_time.total_seconds()
                       - first_movie_start_time_real)  # modified to cell time
        frame_idx = (window_start < frame_times) & (window_end > frame_times)

        # ground-truth ROI first, then the other ROIs, each offset below the
        # previous trace's minimum and drawn with decreasing opacity
        dff_list = [dff]
        dff_list.extend(dff_all)
        alphas = np.arange(1, 1/(len(dff_list)+1), -1/(len(dff_list)+1))
        prevminval = 0
        for dff_now, alpha_now in zip(dff_list, alphas):
            dfftoplotnow = dff_now[frame_idx] + prevminval
            ax_ophys.plot(frame_times[frame_idx], dfftoplotnow, 'g-', alpha=alpha_now)
            prevminval = np.min(dfftoplotnow) - .005
        dffs.append(dff)
        frametimes.append(frame_times)

        if show_o_ap_peaks:
            apidxes = gt_roi.fetch1('roi_spike_indices')
            if 'raw' not in key_cell['roi_type']:
                # non-'raw' ROI types are shifted by one (presumably 1-based)
                apidxes = apidxes - 1
            oap_times = frame_times[apidxes]
            oap_vals = dff[apidxes]
            oap_needed = (window_start < oap_times) & (window_end > oap_times)
            ax_ophys.plot(oap_times[oap_needed], oap_vals[oap_needed], 'ro')

    ophysdata = {'ophys_traces': dffs, 'ophys_trace_times': frametimes}

    ax_ophys.autoscale(tight=True)
    ax_ophys.set_xlim([window_start, window_end])
    ax_ophys.set_ylabel('dF/F')
    ax_ophys.spines["top"].set_visible(False)
    ax_ophys.spines["right"].set_visible(False)
    ax_ophys.invert_yaxis()

    ax_ephys.autoscale(tight=True)
    ax_ephys.set_xlim([window_start, window_end])
    ax_ephys.set_ylabel('Vm (mV)')
    ax_ephys.spines["top"].set_visible(False)
    ax_ephys.spines["right"].set_visible(False)

    if show_stimulus:
        ax_ephys_stim.set_xlim([window_start, window_end])
        ax_ephys_stim.set_ylabel('Injected current (pA)')
        ax_ephys_stim.set_xlabel('time from first movie start (s)')
        ax_ephys_stim.spines["top"].set_visible(False)
        ax_ephys_stim.spines["right"].set_visible(False)
    else:
        ax_ephys.set_xlabel('time from first movie start (s)')

    return {'ephys': ephysdata, 'ophys': ophysdata, 'figure_handle': fig}
コード例 #13
0
    def make(self, key):
        """Match optical spikes of the best ROI to ephys action potentials.

        Runs only when ``ROIEphysCorrelation & key`` has rows. Among all ROIs
        that share ``key`` apart from ``roi_number``, the ROI with the highest
        mean ``apwave_snratio`` over its first 50 ``ROIAPWave`` rows is
        selected. The body continues only if the roi_number of
        ``ROIEphysCorrelation & key`` (its max) is that ROI.

        Ephys AP times (``ap_max_time``) and sweep bounds are shifted from
        cell-recording time to session time. Optical spike times are
        ``frame_times[roi_spike_indices]`` over all movies. Optical spikes
        lying in gaps between sweeps, and ephys APs lying in gaps between
        movies, are dropped. The remaining spikes are paired one-to-one with
        ``scipy.optimize.linear_sum_assignment`` on the differences
        (ophys - ephys) * 1000. A pair is only accepted when that difference
        is between -1 and 15. Matched and unmatched times are written into
        ``key``, which is mutated in place, and inserted.
        """
        #%
        # example keys used during development:
        #key = {'subject_id': 454597, 'session': 1, 'cell_number': 0, 'motion_correction_method': 'Matlab', 'roi_type': 'SpikePursuit', 'roi_number': 1}
        #key = {'subject_id': 456462, 'session': 1, 'cell_number': 5, 'movie_number': 3, 'motion_correction_method': 'VolPy', 'roi_type': 'VolPy', 'roi_number': 1}
        if len(
                ROIEphysCorrelation & key
        ) > 0:  #  and key['roi_type'] == 'SpikePursuit' #only spikepursuit for now..
            # same cell/movie/method, any ROI
            key_to_compare = key.copy()
            del key_to_compare['roi_number']
            #print(key)
            #%
            # per-ROI summary of AP-waveform signal-to-noise ratios
            roinums = np.unique(
                (ROIEphysCorrelation() & key_to_compare).fetch('roi_number'))
            snratios_mean = list()
            snratios_median = list()
            snratios_first50 = list()
            for roinum_now in roinums:

                snratios = (ROIAPWave() & key_to_compare
                            & 'roi_number = {}'.format(roinum_now)
                            ).fetch('apwave_snratio')
                snratios_mean.append(np.mean(snratios))
                snratios_median.append(np.median(snratios))
                snratios_first50.append(np.mean(snratios[:50]))

            #%%
            # proceed only for the ROI with the best mean SNR of its first 50 waveforms
            if np.max(
                (ROIEphysCorrelation() & key).fetch('roi_number')
            ) == roinums[np.argmax(
                    snratios_first50
            )]:  #np.max((imaging_gt.ROIEphysCorrelation()&key).fetch('roi_number')) == np.min((imaging_gt.ROIEphysCorrelation&key_to_compare).fetch('roi_number')):#np.max(np.abs((ROIEphysCorrelation&key).fetch('corr_coeff'))) == np.max(np.abs((ROIEphysCorrelation&key_to_compare).fetch('corr_coeff'))):
                print('this is it')
                print(key['roi_type'])
                # ephys times are relative to cell recording start; the
                # (cell start - session start) offset moves them to session time
                cellstarttime = (ephys_patch.Cell()
                                 & key).fetch1('cell_recording_start')
                sessionstarttime = (experiment.Session()
                                    & key).fetch1('session_time')
                aptimes = np.asarray(
                    (ephysanal.ActionPotential() & key).fetch('ap_max_time'),
                    float) + (cellstarttime -
                              sessionstarttime).total_seconds()
                sweep_start_times, sweep_end_times = (ephys_patch.Sweep()
                                                      & key).fetch(
                                                          'sweep_start_time',
                                                          'sweep_end_time')
                sweep_start_times = np.asarray(sweep_start_times, float) + (
                    cellstarttime - sessionstarttime).total_seconds()
                sweep_end_times = np.asarray(sweep_end_times, float) + (
                    cellstarttime - sessionstarttime).total_seconds()
                # optical spike times = frame times at the ROI's spike indices
                frame_timess, roi_spike_indicess = (
                    imaging.MovieFrameTimes() * imaging.Movie() *
                    imaging.ROI() & key).fetch('frame_times',
                                               'roi_spike_indices')
                movie_start_times = list()
                movie_end_times = list()
                roi_ap_times = list()
                for frame_times, roi_spike_indices in zip(
                        frame_timess, roi_spike_indicess):
                    movie_start_times.append(frame_times[0])
                    movie_end_times.append(frame_times[-1])
                    roi_ap_times.append(frame_times[roi_spike_indices])
                # starts and ends are sorted independently, which assumes
                # movies do not overlap in time
                movie_start_times = np.sort(movie_start_times)
                movie_end_times = np.sort(movie_end_times)
                # NOTE(review): np.concatenate raises if no movie matched
                roi_ap_times = np.sort(np.concatenate(roi_ap_times))
                #%
                # Each gap pairs the i-th start with the (i-1)-th end, with
                # 0 before the first start and inf after the last end.
                # NOTE(review): the sweep arrays are fetched without order_by,
                # so this relies on the fetch returning sweeps chronologically.
                ##delete spikes in optical traces where there was no ephys recording
                for start_t, end_t in zip(
                        np.concatenate([sweep_start_times, [np.inf]]),
                        np.concatenate([[0], sweep_end_times])):
                    idxtodel = np.where((roi_ap_times > end_t)
                                        & (roi_ap_times < start_t))[0]
                    if len(idxtodel) > 0:
                        roi_ap_times = np.delete(roi_ap_times, idxtodel)
                ##delete spikes in ephys traces where there was no imaging
                for start_t, end_t in zip(
                        np.concatenate([movie_start_times, [np.inf]]),
                        np.concatenate([[0], movie_end_times])):
                    idxtodel = np.where((aptimes > end_t)
                                        & (aptimes < start_t))[0]
                    if len(idxtodel) > 0:
                        #print(idxtodel)
                        aptimes = np.delete(aptimes, idxtodel)
                        #%
                # D[i, j] = (optical spike j - ephys AP i) * 1000
                D = np.zeros([len(aptimes), len(roi_ap_times)])
                for idx, apt in enumerate(aptimes):
                    D[idx, :] = (roi_ap_times - apt) * 1000
                # cost |D|; pairs outside -1 <= D <= 15 get the sentinel cost 1000
                D_test = np.abs(D)
                D_test[D_test > 15] = 1000
                D_test[D < -1] = 1000
                X = scipy.optimize.linear_sum_assignment(D_test)
                #%
                # assigned pairs carrying the sentinel cost are not real matches
                cost = D_test[X[0], X[1]]
                unmatched = np.where(cost == 1000)[0]
                X0_final = np.delete(X[0], unmatched)
                X1_final = np.delete(X[1], unmatched)
                ephys_ap_times = aptimes[X0_final]
                ophys_ap_times = roi_ap_times[X1_final]
                # membership is exact float equality against the matched times
                false_positive_time_imaging = list()
                for roi_ap_time in roi_ap_times:
                    if roi_ap_time not in ophys_ap_times:
                        false_positive_time_imaging.append(roi_ap_time)
                false_negative_time_ephys = list()
                for aptime in aptimes:
                    if aptime not in ephys_ap_times:
                        false_negative_time_ephys.append(aptime)

                key['ephys_matched_ap_times'] = ephys_ap_times
                key['ophys_matched_ap_times'] = ophys_ap_times
                key['ephys_unmatched_ap_times'] = false_negative_time_ephys
                key['ophys_unmatched_ap_times'] = false_positive_time_imaging
                #print(imaging.ROI()&key)
                #print([len(aptimes),'vs',len(roi_ap_times)])
                #%%
                self.insert1(key, skip_duplicates=True)
コード例 #14
0
    def make(self, key):
        """Render and store the licking raster / PSTH figure of one session.

        Three vertically stacked panels, all aligned to the go cue (t = 0)
        and shown over the same -3..3 s window:

        1. Lick raster, one row per trial: left licks red, right licks blue,
           trial starts black.
        2. Histogram (10-ms bins) of all left/right licks plus trial starts.
        3. Histogram (10-ms bins) of ``reaction_time`` from
           ``foraging_analysis.TrialStats``, split by chosen water port.

        The figure is saved with ``save_figs`` into
        ``store_stage / <water_restriction_number> / <session_date>`` and the
        returned entries are inserted together with ``key``.

        Args:
            key: primary key of the session being populated;
                ``key['session']`` is used for the title and the query.
        """
        water_res_num, sess_date = get_wr_sessdate(key)
        sess_dir = store_stage / water_res_num / sess_date
        sess_dir.mkdir(parents=True, exist_ok=True)

        # -- Figure layout --
        fig = plt.figure(figsize=(20, 12))
        fig.suptitle(f'{water_res_num}, session {key["session"]}')
        gs = GridSpec(5,
                      1,
                      wspace=0.4,
                      hspace=0.4,
                      bottom=0.07,
                      top=0.95,
                      left=0.1,
                      right=0.9)

        ax1 = fig.add_subplot(gs[0:3, :])
        ax2 = fig.add_subplot(gs[3, :])
        ax3 = fig.add_subplot(gs[4, :])

        # Time window (s, relative to go cue) shared by every panel.
        time_window = (-3, 3)

        lick_colors = {'left lick': 'red', 'right lick': 'blue'}

        # -- Get event times --
        key_subject_id_session = (
            experiment.Session() &
            (lab.WaterRestriction()
             & 'water_restriction_number="{}"'.format(water_res_num))
            & 'session="{}"'.format(key['session'])).fetch1("KEY")
        go_cue_times = (experiment.TrialEvent() & key_subject_id_session
                        & 'trial_event_type="go"').fetch(
                            'trial_event_time', order_by='trial').astype(float)
        lick_times = pd.DataFrame(
            (experiment.ActionEvent()
             & key_subject_id_session).fetch(order_by='trial'))

        trial_num = len(go_cue_times)
        all_trial_num = np.arange(1, trial_num + 1).tolist()
        # Trial start expressed relative to the go cue; adding this offset to
        # a within-trial time aligns it to the go cue (assumes event times are
        # stored relative to trial start — TODO confirm).
        all_trial_start = [[-x] for x in go_cue_times]

        # all_lick[event_type][k] = go-cue-aligned lick times of trial k + 1
        all_lick = dict()
        for event_type in lick_colors:
            # Filter by event type once, not once per trial.
            this_type = lick_times[lick_times['action_event_type'] == event_type]
            all_lick[event_type] = [
                (this_type.loc[this_type['trial'] == trial_idx,
                               'action_event_time'].values.astype(float)
                 + trial_start).tolist()
                for trial_idx, trial_start in enumerate(all_trial_start,
                                                        start=1)
            ]

        # -- Panel 1: all licking events, ordered by trial --
        ax1.plot([0, 0], [0, trial_num], 'k', lw=0.5)  # go cue at t = 0
        ax1.set(ylabel='Trial number', xlim=time_window, xticks=[])
        # The lower panels get the same fixed limits. The former
        # `ax1.get_shared_x_axes().join(...)` was removed in matplotlib 3.8
        # (deprecated since 3.6). `sharex=` is avoided because it would also
        # share tick locators, so ax1's empty ticks would hide ax3's ticks.
        ax2.set_xlim(time_window)
        ax3.set_xlim(time_window)

        # One eventplot call per series (batched to speed up plotting)
        ax1.eventplot(lineoffsets=all_trial_num,
                      positions=all_trial_start,
                      color='k')  # trial starts
        for event_type, color in lick_colors.items():
            ax1.eventplot(lineoffsets=all_trial_num,
                          positions=all_lick[event_type],
                          color=color,
                          linewidth=2)  # licks

        # -- Panel 2: histogram of all licks (10-ms bins) --
        for event_type, color in lick_colors.items():
            sns.histplot(np.hstack(all_lick[event_type]),
                         binwidth=0.01,
                         alpha=0.5,
                         ax=ax2,
                         color=color,
                         label=event_type)

        # Remember the y-range of the lick histograms so that the
        # trial-start histogram added next does not rescale the panel.
        ymax_tmp = max(ax2.get_ylim())
        sns.histplot(-go_cue_times,
                     binwidth=0.01,
                     color='k',
                     ax=ax2,
                     label='trial start')
        ax2.axvline(x=0, color='k', lw=0.5)
        ax2.set(ylim=(0, ymax_tmp), xticks=[], title='All events')
        ax2.legend()

        # -- Panel 3: reaction time (first lick after go cue), per port --
        port_colors = {'LEFT': 'red', 'RIGHT': 'blue'}
        for water_port, color in port_colors.items():
            this_RT = (foraging_analysis.TrialStats() & key_subject_id_session
                       & (experiment.WaterPortChoice()
                          & 'water_port="{}"'.format(water_port))
                       ).fetch('reaction_time').astype(float)
            sns.histplot(this_RT,
                         binwidth=0.01,
                         alpha=0.5,
                         ax=ax3,
                         color=color,
                         label=water_port)
        ax3.axvline(x=0, color='k', lw=0.5)
        ax3.set(xlabel='Time to Go Cue (s)',
                title='First lick (reaction time)')
        ax3.legend()

        # ---- Save fig and insert ----
        fn_prefix = f'{water_res_num}_{sess_date}_'
        fig_dict = save_figs((fig, ), ('session_foraging_licking_psth', ),
                             sess_dir, fn_prefix)
        plt.close('all')
        self.insert1({**key, **fig_dict})
コード例 #15
0
                    'sweep number from imaging start')
                axs_coeff_sweep[-1].set_ylabel('correlation coefficient')
#%% photocurrent
# Script cell: for every ground-truth ROI of type `roi_type`, load the first
# sweep listed in imaging_gt.ROIEphysCorrelation and build a boolean mask of
# the samples recorded during the `window` seconds before the imaging movie
# starts.
# NOTE(review): this cell appears truncated — `neededidx` and `fig` are
# created but not used in the visible code, and one figure is opened per ROI
# without being closed.
window = 3  #seconds
roi_type = 'Spikepursuit'
key = {'roi_type': roi_type}
# One dict per ground-truth ROI, holding the keys needed to restrict the
# per-ROI queries below.
gtrois = (imaging_gt.GroundTruthROI() & key).fetch('subject_id',
                                                   'session',
                                                   'cell_number',
                                                   'movie_number',
                                                   'motion_correction_method',
                                                   'roi_type',
                                                   'roi_number',
                                                   as_dict=True)
for roi in gtrois:
    session_time = (experiment.Session() & roi).fetch('session_time')[0]
    cell_time = (ephys_patch.Cell() & roi).fetch('cell_recording_start')[0]
    # Re-express the movie start as seconds from the start of the cell
    # recording: session_time + movie_start_time - cell_recording_start.
    # NOTE(review): presumably movie_start_time is stored in seconds relative
    # to the session start — confirm against the imaging.Movie definition.
    movie_start_time = float(
        (imaging.Movie() & roi).fetch1('movie_start_time'))
    movie_start_time = session_time.total_seconds(
    ) + movie_start_time - cell_time.total_seconds()

    # Only the first correlated sweep is used.
    sweeps = (imaging_gt.ROIEphysCorrelation() & roi).fetch('sweep_number')
    sweep_now = ephys_patch.Sweep() & roi & 'sweep_number = ' + str(sweeps[0])
    trace, sr = ((ephys_patch.SweepResponse() * ephys_patch.SweepMetadata())
                 & sweep_now).fetch1('response_trace', 'sample_rate')
    sweep_start_time = float(sweep_now.fetch1('sweep_start_time'))
    # Timestamp of every sample: sample index / sample_rate + sweep start.
    trace_time = np.arange(len(trace)) / sr + sweep_start_time
    # True for samples in the open interval
    # (movie_start_time - window, movie_start_time).
    neededidx = (trace_time > movie_start_time - window) & (trace_time <
                                                            movie_start_time)
    fig = plt.figure()
コード例 #16
0
     print('no groundtruth.. skipped')
 else:
     movie_dict = dict((ephys_patch.Cell()*imaging_gt.CellMovieCorrespondance()*imaging.Movie()&movie).fetch1())
     for keynow in movie_dict.keys():
         if type(movie_dict[keynow]) == decimal.Decimal:
             movie_dict[keynow] = float(movie_dict[keynow])
         elif type(movie_dict[keynow]) == datetime.timedelta:
             movie_dict[keynow] = str(movie_dict[keynow])
         elif type(movie_dict[keynow]) == np.ndarray:
             movie_dict[keynow] = movie_dict[keynow].tolist()
             
     save_dir_base = os.path.join(gt_package_directory,str(key['subject_id']),'Cell_{}'.format(key['cell_number']),movie_dict['movie_name'])
     if ((os.path.isdir(save_dir_base) and (len(os.listdir(save_dir_base))>3 and not overwrite))):
         print('already exported, skipped')
     else:
         session_time, cell_recording_start = (experiment.Session()*ephys_patch.Cell()&key).fetch1('session_time','cell_recording_start')
         first_movie_start_time =  np.min(np.asarray(((imaging.Movie()*imaging_gt.GroundTruthROI())&key).fetch('movie_start_time'),float))
         first_movie_start_time_real = first_movie_start_time + session_time.total_seconds()
         frame_times = (imaging.MovieFrameTimes()&movie).fetch1('frame_times') - cell_recording_start.total_seconds() + session_time.total_seconds()
         motion_corr_vectors = np.asarray((imaging.MotionCorrection()*imaging.RegisteredMovie()&movie&'motion_correction_method  = "VolPy"'&'motion_corr_description= "rigid motion correction done with VolPy"').fetch1('motion_corr_vectors'))
         movie_dict['movie_start_time'] = frame_times[0]        
         movie_files = list()
         repositories , directories , fnames = (imaging.MovieFile() & movie).fetch('movie_file_repository','movie_file_directory','movie_file_name')
         for repository,directory,fname in zip(repositories,directories,fnames):
             movie_files.append(os.path.join(dj.config['locations.{}'.format(repository)],directory,fname))
         sweepdata_out = list()
         sweepmetadata_out = list()
         for sweep_number in sweep_numbers:
             #%
             key_sweep = key.copy()
             key_sweep['sweep_number'] = sweep_number
コード例 #17
0
def plot_precision_recall(key_cell, binwidth=30, frbinwidth=0.01,
                          firing_rate_window=3, save_figures=True,
                          xlimits=None):
    """Plot spike-detection precision/recall of one ground-truth cell over time.

    Matched and unmatched ephys/ophys action-potential times are fetched from
    ``imaging_gt.GroundTruthROI`` and expressed in seconds from the first
    movie start. The summary figure contains:

    * smoothed ephys and ophys firing rates,
    * a spike raster marking false negatives and false positives,
    * binned precision, recall and F1 score,
    * AP-wave signal-to-noise ratio over time,
    * ephys-ophys latency of matched spikes (over time and as a histogram).

    Example traces around the first matched, first false-negative and first
    false-positive spike are drawn with ``plot_ephys_ophys_trace``. A
    category without any spike is skipped.

    Args:
        key_cell: restriction selecting the cell. ``subject_id`` and
            ``cell_number`` are used in the title; ``roi_type`` is also used
            in saved file names.
        binwidth: bin width (s) for precision / recall / F1.
        frbinwidth: bin width (s) of the firing-rate histograms.
        firing_rate_window: length (s) of the boxcar kernel used to smooth
            the firing rates.
        save_figures: if True, write PNG files into ``./figures``, creating
            the directory if needed.
        xlimits: x-range (s from first movie start) of the time axes.
            Defaults to the span of the precision/recall bins.
    """
    import os  # only needed to create the output directory

    # -- Timing references --
    session_time, cell_recording_start = (experiment.Session()*ephys_patch.Cell()&key_cell).fetch1('session_time','cell_recording_start')
    first_movie_start_time = np.min(np.asarray(((imaging.Movie()*imaging_gt.GroundTruthROI())&key_cell).fetch('movie_start_time'), float))
    # Same instant shifted by session_time. NOTE(review): presumably movie
    # start times are stored relative to the session start — confirm.
    first_movie_start_time_real = first_movie_start_time + session_time.total_seconds()

    # -- AP times, relative to the first movie start --
    (ephys_matched_ap_times, ephys_unmatched_ap_times,
     ophys_matched_ap_times, ophys_unmatched_ap_times) = (
        imaging_gt.GroundTruthROI() & key_cell).fetch(
            'ephys_matched_ap_times', 'ephys_unmatched_ap_times',
            'ophys_matched_ap_times', 'ophys_unmatched_ap_times')
    # Each fetched column holds one array per matching row; merge them.
    ephys_matched_ap_times = np.concatenate(ephys_matched_ap_times) - first_movie_start_time
    ephys_unmatched_ap_times = np.concatenate(ephys_unmatched_ap_times) - first_movie_start_time
    ophys_matched_ap_times = np.concatenate(ophys_matched_ap_times) - first_movie_start_time
    ophys_unmatched_ap_times = np.concatenate(ophys_unmatched_ap_times) - first_movie_start_time
    all_ephys_ap_times = np.concatenate([ephys_matched_ap_times, ephys_unmatched_ap_times])
    all_ophys_ap_times = np.concatenate([ophys_matched_ap_times, ophys_unmatched_ap_times])
    maxtime = np.max(np.concatenate([all_ephys_ap_times, all_ophys_ap_times]))

    # -- Firing rates (Hz): spike counts / bin width, boxcar-smoothed --
    n_kernel_bins = firing_rate_window / frbinwidth
    fr_kernel = np.ones(int(n_kernel_bins)) / n_kernel_bins
    fr_bincenters = np.arange(frbinwidth/2, maxtime+frbinwidth, frbinwidth)
    fr_binedges = np.concatenate([fr_bincenters-frbinwidth/2, [fr_bincenters[-1]+frbinwidth/2]])
    fr_e = np.convolve(np.histogram(all_ephys_ap_times, fr_binedges)[0]/frbinwidth, fr_kernel, 'same')
    fr_o = np.convolve(np.histogram(all_ophys_ap_times, fr_binedges)[0]/frbinwidth, fr_kernel, 'same')

    # -- Binned precision / recall / F1 --
    bincenters = np.arange(binwidth/2, maxtime+binwidth, binwidth)
    binedges = np.concatenate([bincenters-binwidth/2, [bincenters[-1]+binwidth/2]])
    ephys_matched_ap_num_binned = np.histogram(ephys_matched_ap_times, binedges)[0]
    ephys_unmatched_ap_num_binned = np.histogram(ephys_unmatched_ap_times, binedges)[0]
    ophys_matched_ap_num_binned = np.histogram(ophys_matched_ap_times, binedges)[0]
    ophys_unmatched_ap_num_binned = np.histogram(ophys_unmatched_ap_times, binedges)[0]
    # Bins without spikes yield 0/0 -> NaN (drawn as gaps); silence warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        precision_binned = ophys_matched_ap_num_binned/(ophys_matched_ap_num_binned+ophys_unmatched_ap_num_binned)
        recall_binned = ephys_matched_ap_num_binned/(ephys_matched_ap_num_binned+ephys_unmatched_ap_num_binned)
        f1_binned = 2*precision_binned*recall_binned/(precision_binned+recall_binned)

    # `is None`: `== None` on an ndarray argument would be ambiguous in `if`.
    if xlimits is None:
        xlimits = [binedges[0], binedges[-1]]

    # -- Figure layout (axes stacked via explicit rectangles) --
    fig = plt.figure()
    ax_rates = fig.add_axes([0,1.4,2,.3])
    ax_spikes = fig.add_axes([0,1,2,.3])
    ax = fig.add_axes([0,0,2,.8])
    ax_snratio = fig.add_axes([0,-1,2,.8])
    ax_latency = fig.add_axes([0,-2,2,.8])
    ax_latency_hist = fig.add_axes([.5,-3,1,.8])

    ax.plot(bincenters,precision_binned,'go-',label = 'precision')
    ax.plot(bincenters,recall_binned,'ro-',label = 'recall')
    ax.plot(bincenters,f1_binned,'bo-',label = 'f1 score')
    ax.set_xlim([binedges[0],binedges[-1]])
    ax.set_ylim([0,1])
    ax.legend()

    # Raster rows: false negatives (0), all ephys (.33), all ophys (.67),
    # false positives (1).
    ax_spikes.plot(ephys_unmatched_ap_times,np.zeros(len(ephys_unmatched_ap_times)),'r|', ms = 10)
    ax_spikes.plot(all_ephys_ap_times,np.zeros(len(all_ephys_ap_times))+.33,'k|', ms = 10)
    ax_spikes.plot(ophys_unmatched_ap_times,np.ones(len(ophys_unmatched_ap_times)),'r|',ms = 10)
    ax_spikes.plot(all_ophys_ap_times,np.ones(len(all_ophys_ap_times))-.33,'g|', ms = 10)
    ax_spikes.set_yticks([0,.33,.67,1])
    ax_spikes.set_yticklabels(['false negative','ephys','ophys','false positive'])
    ax_spikes.set_ylim([-.2,1.2])
    ax_spikes.set_xlim(xlimits)

    # AP peak times shifted onto the "s from first movie start" axis.
    t,sn = (ephysanal.ActionPotential()*imaging_gt.GroundTruthROI()*imaging_gt.ROIAPWave()&key_cell).fetch('ap_max_time','apwave_snratio')
    t = np.asarray(t,float) + cell_recording_start.total_seconds() - first_movie_start_time_real
    ax_snratio.plot(t,sn,'ko')
    ax_snratio.set_ylabel('signal to noise ratio')
    ax_snratio.set_ylim([0,20])
    ax_snratio.set_xlim(xlimits)
    ax_rates.plot(fr_bincenters,fr_e,'k-',label = 'ephys')
    ax_rates.plot(fr_bincenters,fr_o,'g-',label = 'ophys')
    ax_rates.legend()
    ax_rates.set_xlim(xlimits)
    ax_rates.set_ylabel('Firing rate (Hz)')
    ax_rates.set_title('subject_id: %d, cell number: %s' %(key_cell['subject_id'],key_cell['cell_number']))
    ax_latency.plot(ephys_matched_ap_times,1000*(ophys_matched_ap_times-ephys_matched_ap_times),'ko')
    ax_latency.set_ylabel('ephys-ophys spike latency (ms)')
    ax_latency.set_ylim([0,10])
    ax_latency.set_xlabel('time from first movie start (s)')
    ax_latency.set_xlim(xlimits)

    ax_latency_hist.hist(1000*(ophys_matched_ap_times-ephys_matched_ap_times),np.arange(-5,15,.1))
    ax_latency_hist.set_xlabel('ephys-ophys spike latency (ms)')
    ax_latency_hist.set_ylabel('matched ap count')

    # -- Example traces around the first spike of each category --
    example_events = (('matched', ephys_matched_ap_times),
                      ('false_negative', ephys_unmatched_ap_times),
                      ('false_positive', ophys_unmatched_ap_times))
    data = list()
    for example_name, ap_times in example_events:
        if len(ap_times) == 0:
            continue  # e.g. a perfectly detected cell has no false negatives
        data.append((example_name,
                     plot_ephys_ophys_trace(key_cell, ap_times[0], trace_window=1,
                                            show_e_ap_peaks=True, show_o_ap_peaks=True)))

    if save_figures:
        os.makedirs('./figures', exist_ok=True)
        fig.savefig('./figures/{}_cell_{}_roi_type_{}.png'.format(key_cell['subject_id'],key_cell['cell_number'],key_cell['roi_type']), bbox_inches = 'tight')
        for fname, data_now in data:
            data_now['figure_handle'].savefig('./figures/{}_cell_{}_roi_type_{}_{}.png'.format(key_cell['subject_id'],key_cell['cell_number'],key_cell['roi_type'],fname), bbox_inches = 'tight')
コード例 #18
0
    def make(self, key):
        log.info('BehaviorIngest.make(): key: {key}'.format(key=key))

        subject_id = key['subject_id']
        h2o = (lab.WaterRestriction() & {
            'subject_id': subject_id
        }).fetch1('water_restriction_number')

        ymd = key['session_date']
        datestr = ymd.strftime('%Y%m%d')
        log.info('h2o: {h2o}, date: {d}'.format(h2o=h2o, d=datestr))

        # session record key
        skey = {}
        skey['subject_id'] = subject_id
        skey['session_date'] = ymd
        skey['username'] = self.get_session_user()
        skey['rig'] = key['rig']

        # File paths conform to the pattern:
        # dl7/TW_autoTrain/Session Data/dl7_TW_autoTrain_20180104_132813.mat
        # which is, more generally:
        # {h2o}/{training_protocol}/Session Data/{h2o}_{training protocol}_{YYYYMMDD}_{HHMMSS}.mat

        path = pathlib.Path(key['rig_data_path'], key['subpath'])

        if experiment.Session() & skey:
            log.info("note: session exists for {h2o} on {d}".format(h2o=h2o,
                                                                    d=ymd))

        trial = namedtuple(  # simple structure to track per-trial vars
            'trial',
            ('ttype', 'stim', 'free', 'settings', 'state_times', 'state_names',
             'state_data', 'event_data', 'event_times', 'trial_start'))

        if os.stat(path).st_size / 1024 < 1000:
            log.info('skipping file {} - too small'.format(path))
            return

        log.debug('loading file {}'.format(path))

        mat = spio.loadmat(path, squeeze_me=True)
        SessionData = mat['SessionData'].flatten()

        # parse session datetime
        session_datetime_str = str('').join(
            (str(SessionData['Info'][0]['SessionDate']), ' ',
             str(SessionData['Info'][0]['SessionStartTime_UTC'])))

        session_datetime = datetime.strptime(session_datetime_str,
                                             '%d-%b-%Y %H:%M:%S')

        AllTrialTypes = SessionData['TrialTypes'][0]
        AllTrialSettings = SessionData['TrialSettings'][0]
        AllTrialStarts = SessionData['TrialStartTimestamp'][0]
        AllTrialStarts = AllTrialStarts - AllTrialStarts[0]  # real 1st trial

        RawData = SessionData['RawData'][0].flatten()
        AllStateNames = RawData['OriginalStateNamesByNumber'][0]
        AllStateData = RawData['OriginalStateData'][0]
        AllEventData = RawData['OriginalEventData'][0]
        AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
        AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

        # verify trial-related data arrays are all same length
        assert (all((x.shape[0] == AllStateTimestamps.shape[0]
                     for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                               AllStateData, AllEventData, AllEventTimestamps,
                               AllTrialStarts, AllTrialStarts))))

        # AllStimTrials optional special case
        if 'StimTrials' in SessionData.dtype.fields:
            log.debug('StimTrials detected in session - will include')
            AllStimTrials = SessionData['StimTrials'][0]
            assert (AllStimTrials.shape[0] == AllStateTimestamps.shape[0])
        else:
            log.debug('StimTrials not detected in session - will skip')
            AllStimTrials = np.array(
                [None for _ in enumerate(range(AllStateTimestamps.shape[0]))])

        # AllFreeTrials optional special case
        if 'FreeTrials' in SessionData.dtype.fields:
            log.debug('FreeTrials detected in session - will include')
            AllFreeTrials = SessionData['FreeTrials'][0]
            assert (AllFreeTrials.shape[0] == AllStateTimestamps.shape[0])
        else:
            log.debug('FreeTrials not detected in session - synthesizing')
            AllFreeTrials = np.zeros(AllStateTimestamps.shape[0],
                                     dtype=np.uint8)

        trials = list(
            zip(AllTrialTypes, AllStimTrials, AllFreeTrials, AllTrialSettings,
                AllStateTimestamps, AllStateNames, AllStateData, AllEventData,
                AllEventTimestamps, AllTrialStarts))

        if not trials:
            log.warning('skipping date {d}, no valid files'.format(d=date))
            return

        #
        # Trial data seems valid; synthesize session id & add session record
        # XXX: note - later breaks can result in Sessions without valid trials
        #

        assert skey['session_date'] == session_datetime.date()

        skey['session_date'] = session_datetime.date()
        skey['session_time'] = session_datetime.time()

        log.debug('synthesizing session ID')
        session = (dj.U().aggr(experiment.Session()
                               & {
                                   'subject_id': subject_id
                               },
                               n='max(session)').fetch1('n') or 0) + 1

        log.info('generated session id: {session}'.format(session=session))
        skey['session'] = session
        key = dict(key, **skey)

        #
        # Actually load the per-trial data
        #
        log.info('BehaviorIngest.make(): trial parsing phase')

        # lists of various records for batch-insert
        rows = {
            k: list()
            for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                      'corrected_trial_event', 'action_event', 'photostim',
                      'photostim_location', 'photostim_trial',
                      'photostim_trial_event')
        }

        i = 0  # trial numbering starts at 1
        for t in trials:

            #
            # Misc
            #

            t = trial(*t)  # convert list of items to a 'trial' structure
            i += 1  # increment trial counter

            log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

            # covert state data names into a lookup dictionary
            #
            # names (seem to be? are?):
            #
            # Trigtrialstart, PreSamplePeriod, SamplePeriod, DelayPeriod
            # EarlyLickDelay, EarlyLickSample, ResponseCue, GiveLeftDrop
            # GiveRightDrop, GiveLeftDropShort, GiveRightDropShort
            # AnswerPeriod, Reward, RewardConsumption, NoResponse
            # TimeOut, StopLicking, StopLickingReturn, TrialEnd
            #

            states = {k: (v + 1) for v, k in enumerate(t.state_names)}
            required_states = ('PreSamplePeriod', 'SamplePeriod',
                               'DelayPeriod', 'ResponseCue', 'StopLicking',
                               'TrialEnd')

            missing = list(k for k in required_states if k not in states)

            if len(missing):
                log.warning('skipping trial {i}; missing {m}'.format(
                    i=i, m=missing))
                continue

            gui = t.settings['GUI'].flatten()

            # ProtocolType - only ingest protocol >= 3
            #
            # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
            # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
            #

            if 'ProtocolType' not in gui.dtype.names:
                log.warning(
                    'skipping trial {i}; protocol undefined'.format(i=i))
                continue

            protocol_type = gui['ProtocolType'][0]
            if gui['ProtocolType'][0] < 3:
                log.warning('skipping trial {i}; protocol {n} < 3'.format(
                    i=i, n=gui['ProtocolType'][0]))
                continue

            #
            # Top-level 'Trial' record
            #

            tkey = dict(skey)
            startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

            # should be only end of 1st StopLicking;
            # rest of data is irrelevant w/r/t separately ingested ephys
            endindex = np.where(t.state_data == states['StopLicking'])[0]

            log.debug('states\n' + str(states))
            log.debug('state_data\n' + str(t.state_data))
            log.debug('startindex\n' + str(startindex))
            log.debug('endindex\n' + str(endindex))

            if not (len(startindex) and len(endindex)):
                log.warning('skipping {}: start/end mismatch: {}/{}'.format(
                    i, str(startindex), str(endindex)))
                continue

            try:
                tkey['trial'] = i
                tkey['trial_uid'] = i
                tkey['start_time'] = t.trial_start
                tkey['stop_time'] = t.trial_start + t.state_times[endindex][0]
            except IndexError:
                log.warning('skipping {}: IndexError: {}/{} -> {}'.format(
                    i, str(startindex), str(endindex), str(t.state_times)))
                continue

            log.debug('tkey' + str(tkey))
            rows['trial'].append(tkey)

            #
            # Specific BehaviorTrial information for this trial
            #

            bkey = dict(tkey)
            bkey['task'] = 'audio delay'  # hard-coded here
            bkey['task_protocol'] = 1  # hard-coded here

            # determine trial instruction
            trial_instruction = 'left'  # hard-coded here

            if gui['Reversal'][0] == 1:
                if t.ttype == 1:
                    trial_instruction = 'left'
                elif t.ttype == 0:
                    trial_instruction = 'right'
            elif gui['Reversal'][0] == 2:
                if t.ttype == 1:
                    trial_instruction = 'right'
                elif t.ttype == 0:
                    trial_instruction = 'left'

            bkey['trial_instruction'] = trial_instruction

            # determine early lick
            early_lick = 'no early'

            if (protocol_type >= 5 and 'EarlyLickDelay' in states
                    and np.any(t.state_data == states['EarlyLickDelay'])):
                early_lick = 'early'
            if (protocol_type >= 5 and
                ('EarlyLickSample' in states
                 and np.any(t.state_data == states['EarlyLickSample']))):
                early_lick = 'early'

            bkey['early_lick'] = early_lick

            # determine outcome
            outcome = 'ignore'

            if ('Reward' in states
                    and np.any(t.state_data == states['Reward'])):
                outcome = 'hit'
            elif ('TimeOut' in states
                  and np.any(t.state_data == states['TimeOut'])):
                outcome = 'miss'
            elif ('NoResponse' in states
                  and np.any(t.state_data == states['NoResponse'])):
                outcome = 'ignore'

            bkey['outcome'] = outcome

            # Determine free/autowater (Autowater 1 == enabled, 2 == disabled)
            bkey['auto_water'] = True if gui['Autowater'][0] == 1 else False
            bkey['free_water'] = t.free

            rows['behavior_trial'].append(bkey)

            #
            # Add 'protocol' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'protocol #'
            nkey['trial_note'] = str(protocol_type)
            rows['trial_note'].append(nkey)

            #
            # Add 'autolearn' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'autolearn'
            nkey['trial_note'] = str(gui['Autolearn'][0])
            rows['trial_note'].append(nkey)

            #
            # Add 'bitcode' note
            #
            if 'randomID' in gui.dtype.names:
                nkey = dict(tkey)
                nkey['trial_note_type'] = 'bitcode'
                nkey['trial_note'] = str(gui['randomID'][0])
                rows['trial_note'].append(nkey)

            #
            # Add presample event
            #
            log.debug('BehaviorIngest.make(): presample')

            ekey = dict(tkey)
            sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'presample'
            ekey['trial_event_time'] = t.state_times[startindex][0]
            ekey['duration'] = (t.state_times[sampleindex[0]] -
                                t.state_times[startindex])[0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing presample duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial

            rows['trial_event'].append(ekey)

            #
            # Add other 'sample' events
            #

            log.debug('BehaviorIngest.make(): sample events')

            last_dur = None

            for s in sampleindex:  # in protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'sample'
                ekey['trial_event_time'] = t.state_times[s]
                ekey['duration'] = gui['SamplePeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning(
                        '... trial {} bad duration, no last_edur'.format(
                            i, last_dur))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning(
                        '... trial {} duration using last_edur {}'.format(
                            i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                rows['trial_event'].append(ekey)

            #
            # Add 'delay' events
            #

            log.debug('BehaviorIngest.make(): delay events')

            last_dur = None
            delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

            for d in delayindex:  # protocol > 6 ~-> n>1
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'delay'
                ekey['trial_event_time'] = t.state_times[d]
                ekey['duration'] = gui['DelayPeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning('... {} bad duration, no last_edur'.format(
                        i, last_dur))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning('... {} duration using last_edur {}'.format(
                        i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                log.debug('delay event duration: {}'.format(ekey['duration']))
                rows['trial_event'].append(ekey)

            #
            # Add 'go' event
            #
            log.debug('BehaviorIngest.make(): go')

            ekey = dict(tkey)
            responseindex = np.where(t.state_data == states['ResponseCue'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'go'
            ekey['trial_event_time'] = t.state_times[responseindex][0]
            ekey['duration'] = gui['AnswerPeriod'][0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing go duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
                rows['corrected_trial_event'].append(ekey)

            rows['trial_event'].append(ekey)

            #
            # Add 'trialEnd' events
            #

            log.debug('BehaviorIngest.make(): trialend events')

            last_dur = None
            trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'trialend'
            ekey['trial_event_time'] = t.state_times[trialendindex][0]
            ekey['duration'] = 0.0

            rows['trial_event'].append(ekey)

            #
            # Add lick events
            #

            lickleft = np.where(t.event_data == 69)[0]
            log.debug('... lickleft: {r}'.format(r=str(lickleft)))

            action_event_count = len(rows['action_event'])
            if len(lickleft):
                [
                    rows['action_event'].append(
                        dict(tkey,
                             action_event_id=action_event_count + idx,
                             action_event_type='left lick',
                             action_event_time=t.event_times[l]))
                    for idx, l in enumerate(lickleft)
                ]

            lickright = np.where(t.event_data == 71)[0]
            log.debug('... lickright: {r}'.format(r=str(lickright)))

            action_event_count = len(rows['action_event'])
            if len(lickright):
                [
                    rows['action_event'].append(
                        dict(tkey,
                             action_event_id=action_event_count + idx,
                             action_event_type='right lick',
                             action_event_time=t.event_times[r]))
                    for idx, r in enumerate(lickright)
                ]

            #
            # Photostim Events
            #

            if t.stim:
                log.debug('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
                rows['photostim_trial'].append(tkey)
                delay_period_idx = np.where(
                    t.state_data == states['DelayPeriod'])[0][0]
                rows['photostim_trial_event'].append(
                    dict(tkey,
                         photo_stim=t.stim,
                         photostim_event_id=len(rows['photostim_trial_event']),
                         photostim_event_time=t.state_times[delay_period_idx],
                         power=5.5))

            # end of trial loop.

        # Session Insertion

        log.info('BehaviorIngest.make(): adding session record')
        experiment.Session().insert1(skey)

        # Behavior Insertion

        log.info('BehaviorIngest.make(): bulk insert phase')

        log.info('BehaviorIngest.make(): saving ingest {d}'.format(d=key))
        self.insert1(key, ignore_extra_fields=True, allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
        experiment.SessionTrial().insert(rows['trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
        experiment.BehaviorTrial().insert(rows['behavior_trial'],
                                          ignore_extra_fields=True,
                                          allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialNote')
        experiment.TrialNote().insert(rows['trial_note'],
                                      ignore_extra_fields=True,
                                      allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
        experiment.TrialEvent().insert(rows['trial_event'],
                                       ignore_extra_fields=True,
                                       allow_direct_insert=True,
                                       skip_duplicates=True)

        log.info('BehaviorIngest.make(): ... CorrectedTrialEvents')
        BehaviorIngest().CorrectedTrialEvents().insert(
            rows['corrected_trial_event'],
            ignore_extra_fields=True,
            allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
        experiment.ActionEvent().insert(rows['action_event'],
                                        ignore_extra_fields=True,
                                        allow_direct_insert=True)

        # Photostim Insertion

        photostim_ids = np.unique(
            [r['photo_stim'] for r in rows['photostim_trial_event']])

        unknown_photostims = np.setdiff1d(photostim_ids,
                                          list(photostims.keys()))

        if unknown_photostims:
            raise ValueError(
                'Unknown photostim protocol: {}'.format(unknown_photostims))

        if photostim_ids.size > 0:
            log.info('BehaviorIngest.make(): ... experiment.Photostim')
            for stim in photostim_ids:
                experiment.Photostim.insert1(dict(skey, **photostims[stim]),
                                             ignore_extra_fields=True)

                experiment.Photostim.PhotostimLocation.insert(
                    (dict(
                        skey, **loc, photo_stim=photostims[stim]['photo_stim'])
                     for loc in photostims[stim]['locations']),
                    ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrial')
        experiment.PhotostimTrial.insert(rows['photostim_trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrialEvent')
        experiment.PhotostimEvent.insert(rows['photostim_trial_event'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        # Behavior Ingest Insertion

        log.info('BehaviorIngest.make(): ... BehaviorIngest.BehaviorFile')
        BehaviorIngest.BehaviorFile().insert1(dict(
            key, behavior_file=os.path.basename(key['subpath'])),
                                              ignore_extra_fields=True,
                                              allow_direct_insert=True)
コード例 #19
0
    def make(self, key):
        """Ingest one behavior session from Bpod ``SessionData`` .mat file(s).

        Locates the .mat file(s) for the subject/date in ``key`` and parses
        every trial into row dictionaries. It then bulk-inserts the session,
        trials, behavior outcomes, trial notes, trial/action events and
        photostim records, plus the list of matched source files.

        Parameters
        ----------
        key : dict
            Populate key. Fields read here: ``subject_id``, ``session_date``
            (``strftime`` is called on it), ``rig_data_path``, ``subpath``
            and ``rig``.

        Returns
        -------
        None
            Returns early without inserting anything in three cases: no
            matching file is found; a session already matches the session
            key; or no file yields trial data.
        """
        log.info('BehaviorIngest.make(): key: {key}'.format(key=key))

        subject_id = key['subject_id']
        h2o = (lab.WaterRestriction() & {
            'subject_id': subject_id
        }).fetch1('water_restriction_number')

        date = key['session_date']
        datestr = date.strftime('%Y%m%d')
        log.info('h2o: {h2o}, date: {d}'.format(h2o=h2o, d=datestr))

        # session record key
        skey = {
            'subject_id': subject_id,
            'session_date': date,
            'username': self.get_session_user(),
        }

        # File paths conform to the pattern:
        # dl7/TW_autoTrain/Session Data/dl7_TW_autoTrain_20180104_132813.mat
        # which is, more generally:
        # {h2o}/{training_protocol}/Session Data/{h2o}_{training protocol}_{YYYYMMDD}_{HHMMSS}.mat
        root = pathlib.Path(key['rig_data_path'],
                            os.path.dirname(key['subpath']))
        pattern = '{h2o}_*_{d}*.mat'.format(h2o=h2o, d=datestr)
        path = root / pattern  # only used in log messages

        log.info('rigpath {p}'.format(p=path))

        # sorted() gives a deterministic file (and therefore trial) order
        matches = sorted(root.glob(pattern))
        if not matches:
            log.info('no file matches found in {p}'.format(p=path))
            log.warning('no file matches found for {h2o} / {d}'.format(
                h2o=h2o, d=datestr))
            return

        log.info('found files: {}, this is the rig'.format(matches))
        skey['rig'] = key['rig']

        #
        # Find files & Check for split files
        # XXX: not checking rig.. 2+ sessions on 2+ rigs possible for date?
        #

        if len(matches) > 1:
            log.warning(
                'split session case detected for {h2o} on {date}'.format(
                    h2o=h2o, date=date))

        # session:date relationship is 1:1; skip if we have a session.
        # NOTE: skey also carries 'username' and 'rig' at this point, so the
        # existence check is restricted by those fields as well.
        if experiment.Session() & skey:
            log.warning("Warning! session exists for {h2o} on {d}".format(
                h2o=h2o, d=date))
            return

        #
        # Prepare PhotoStim
        #
        # Lookup from the per-trial stim code (t.stim) to the Photostim
        # record fields. A stim code not listed here raises KeyError in the
        # trial loop below.
        photosti_duration = 0.5  # (s) Hard-coded here
        photostims = {
            4: {
                'photo_stim': 4,
                'photostim_device': 'OBIS470',
                'brain_location_name': 'left_alm',
                'duration': photosti_duration
            },
            5: {
                'photo_stim': 5,
                'photostim_device': 'OBIS470',
                'brain_location_name': 'right_alm',
                'duration': photosti_duration
            },
            6: {
                'photo_stim': 6,
                'photostim_device': 'OBIS470',
                'brain_location_name': 'both_alm',
                'duration': photosti_duration
            }
        }

        #
        # Extract trial data from file(s) & prepare trial loop
        #

        trials = []  # per-trial tuples, concatenated across all files

        trial = namedtuple(  # simple structure to track per-trial vars
            'trial',
            ('ttype', 'stim', 'settings', 'state_times', 'state_names',
             'state_data', 'event_data', 'event_times'))

        for f in matches:

            # files smaller than 1000 KiB (~1 MB) are skipped
            if os.stat(f).st_size / 1024 < 1000:
                log.info('skipping file {f} - too small'.format(f=f))
                continue

            log.debug('loading file {}'.format(f))

            mat = spio.loadmat(f, squeeze_me=True)
            SessionData = mat['SessionData'].flatten()

            AllTrialTypes = SessionData['TrialTypes'][0]
            AllTrialSettings = SessionData['TrialSettings'][0]

            RawData = SessionData['RawData'][0].flatten()
            AllStateNames = RawData['OriginalStateNamesByNumber'][0]
            AllStateData = RawData['OriginalStateData'][0]
            AllEventData = RawData['OriginalEventData'][0]
            AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
            AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

            # verify trial-related data arrays are all same length
            assert (all(
                (x.shape[0] == AllStateTimestamps.shape[0]
                 for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                           AllStateData, AllEventData, AllEventTimestamps))))

            if 'StimTrials' in SessionData.dtype.fields:
                log.debug('StimTrials detected in session - will include')
                AllStimTrials = SessionData['StimTrials'][0]
                assert (AllStimTrials.shape[0] == AllStateTimestamps.shape[0])
            else:
                log.debug('StimTrials not detected in session - will skip')
                # one None per trial so the zip below stays aligned
                AllStimTrials = np.array(
                    [None] * AllStateTimestamps.shape[0])

            # field order must match the 'trial' namedtuple above
            trials.extend(
                zip(AllTrialTypes, AllStimTrials, AllTrialSettings,
                    AllStateTimestamps, AllStateNames, AllStateData,
                    AllEventData, AllEventTimestamps))

        # all files were too small (< 1000 KiB) or contained no trials
        if not trials:
            log.warning('skipping date {d}, no valid files'.format(d=date))
            return

        #
        # Trial data seems valid; synthesize session id & add session record
        # XXX: note - later breaks can result in Sessions without valid trials
        #

        log.debug('synthesizing session ID')
        # next session number for this subject: max(existing) + 1, or 1
        session = (dj.U().aggr(experiment.Session() & {
            'subject_id': subject_id
        },
                               n='max(session)').fetch1('n') or 0) + 1
        log.info('generated session id: {session}'.format(session=session))
        skey['session'] = session
        key = dict(key, **skey)

        #
        # Actually load the per-trial data
        #
        log.info('BehaviorIngest.make(): trial parsing phase')

        # lists of various records for batch-insert
        rows = {
            k: list()
            for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                      'corrected_trial_event', 'action_event', 'photostim',
                      'photostim_location', 'photostim_trial',
                      'photostim_trial_event')
        }

        # i counts every raw trial, including skipped ones, so a trial's
        # number is its position in the concatenated file data.
        for i, raw_trial in enumerate(trials):

            #
            # Misc
            #

            t = trial(*raw_trial)  # convert list of items to a 'trial' structure

            log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

            # convert state data names into a lookup dictionary
            #
            # names (seem to be? are?):
            #
            # Trigtrialstart
            # PreSamplePeriod
            # SamplePeriod
            # DelayPeriod
            # EarlyLickDelay
            # EarlyLickSample
            # ResponseCue
            # GiveLeftDrop
            # GiveRightDrop
            # GiveLeftDropShort
            # GiveRightDropShort
            # AnswerPeriod
            # Reward
            # RewardConsumption
            # NoResponse
            # TimeOut
            # StopLicking
            # StopLickingReturn
            # TrialEnd

            # name -> 1-based position in state_names; state_data is compared
            # against these numbers (presumably MATLAB 1-based state ids)
            states = {k: (v + 1) for v, k in enumerate(t.state_names)}
            required_states = ('PreSamplePeriod', 'SamplePeriod',
                               'DelayPeriod', 'ResponseCue', 'StopLicking',
                               'TrialEnd')

            missing = [k for k in required_states if k not in states]

            if missing:
                log.warning('skipping trial {i}; missing {m}'.format(
                    i=i, m=missing))
                continue

            gui = t.settings['GUI'].flatten()

            # ProtocolType - only ingest protocol >= 3
            #
            # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
            # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
            #

            if 'ProtocolType' not in gui.dtype.names:
                log.warning(
                    'skipping trial {i}; protocol undefined'.format(i=i))
                continue

            protocol_type = gui['ProtocolType'][0]
            if protocol_type < 3:
                log.warning('skipping trial {i}; protocol {n} < 3'.format(
                    i=i, n=protocol_type))
                continue

            #
            # Top-level 'Trial' record
            #

            tkey = dict(skey)
            startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

            # should be only end of 1st StopLicking;
            # rest of data is irrelevant w/r/t separately ingested ephys
            endindex = np.where(t.state_data == states['StopLicking'])[0]

            log.debug('states\n' + str(states))
            log.debug('state_data\n' + str(t.state_data))
            log.debug('startindex\n' + str(startindex))
            log.debug('endindex\n' + str(endindex))

            if not (len(startindex) and len(endindex)):
                log.warning(
                    'skipping trial {i}: start/end index error: {s}/{e}'.
                    format(i=i, s=str(startindex), e=str(endindex)))
                continue

            try:
                tkey['trial'] = i
                tkey['trial_uid'] = i  # Arseny has unique id to identify some trials
                tkey['start_time'] = t.state_times[startindex][0]
                tkey['stop_time'] = t.state_times[endindex][0]
            except IndexError:
                log.warning(
                    'skipping trial {i}: error indexing {s}/{e} into {t}'.
                    format(i=i,
                           s=str(startindex),
                           e=str(endindex),
                           t=str(t.state_times)))
                continue

            log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
            log.debug('tkey' + str(tkey))
            rows['trial'].append(tkey)

            #
            # Specific BehaviorTrial information for this trial
            #

            bkey = dict(tkey)
            bkey['task'] = 'audio delay'  # hard-coded here
            bkey['task_protocol'] = 1  # hard-coded here

            # determine trial instruction; the Reversal setting flips the
            # ttype -> side mapping. Any other combination stays 'left'.
            trial_instruction = 'left'  # hard-coded here

            if gui['Reversal'][0] == 1:
                if t.ttype == 1:
                    trial_instruction = 'left'
                elif t.ttype == 0:
                    trial_instruction = 'right'
            elif gui['Reversal'][0] == 2:
                if t.ttype == 1:
                    trial_instruction = 'right'
                elif t.ttype == 0:
                    trial_instruction = 'left'

            bkey['trial_instruction'] = trial_instruction

            # determine early lick: delay-period early licks count from
            # protocol 5, sample-period early licks only from protocol 6
            early_lick = 'no early'

            if (protocol_type >= 5 and 'EarlyLickDelay' in states
                    and np.any(t.state_data == states['EarlyLickDelay'])):
                early_lick = 'early'
            if (protocol_type > 5 and
                ('EarlyLickSample' in states
                 and np.any(t.state_data == states['EarlyLickSample']))):
                early_lick = 'early'

            bkey['early_lick'] = early_lick

            # determine outcome (first visited state wins:
            # Reward > TimeOut > NoResponse; default 'ignore')
            outcome = 'ignore'

            if ('Reward' in states
                    and np.any(t.state_data == states['Reward'])):
                outcome = 'hit'
            elif ('TimeOut' in states
                  and np.any(t.state_data == states['TimeOut'])):
                outcome = 'miss'
            elif ('NoResponse' in states
                  and np.any(t.state_data == states['NoResponse'])):
                outcome = 'ignore'

            bkey['outcome'] = outcome
            rows['behavior_trial'].append(bkey)

            #
            # Add 'protocol' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'protocol #'
            nkey['trial_note'] = str(protocol_type)
            rows['trial_note'].append(nkey)

            #
            # Add 'autolearn' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'autolearn'
            nkey['trial_note'] = str(gui['Autolearn'][0])
            rows['trial_note'].append(nkey)

            #
            # Add 'bitcode' note
            #
            if 'randomID' in gui.dtype.names:
                nkey = dict(tkey)
                nkey['trial_note_type'] = 'bitcode'
                nkey['trial_note'] = str(gui['randomID'][0])
                rows['trial_note'].append(nkey)

            #
            # Add presample event
            #
            # trial_event_id is the running count of all events in this
            # session, so event-appending order matters.
            log.debug('BehaviorIngest.make(): presample')

            ekey = dict(tkey)
            sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'presample'
            ekey['trial_event_time'] = t.state_times[startindex][0]
            ekey['duration'] = (t.state_times[sampleindex[0]] -
                                t.state_times[startindex])[0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing presample duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial
                # record the correction, as done for sample/delay/go events
                rows['corrected_trial_event'].append(ekey)

            rows['trial_event'].append(ekey)

            #
            # Add other 'sample' events
            #

            log.debug('BehaviorIngest.make(): sample events')

            # last non-NaN duration within this trial, used to patch NaNs
            last_dur = None

            for s in sampleindex:  # in protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'sample'
                ekey['trial_event_time'] = t.state_times[s]
                ekey['duration'] = gui['SamplePeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning(
                        '... trial {} bad duration, no last_edur'.format(
                            i, last_dur))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning(
                        '... trial {} duration using last_edur {}'.format(
                            i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                rows['trial_event'].append(ekey)

            #
            # Add 'delay' events
            #

            log.debug('BehaviorIngest.make(): delay events')

            last_dur = None
            delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

            for d in delayindex:  # protocol > 6 ~-> n>1
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'delay'
                ekey['trial_event_time'] = t.state_times[d]
                ekey['duration'] = gui['DelayPeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning('... {} bad duration, no last_edur'.format(
                        i, last_dur))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning('... {} duration using last_edur {}'.format(
                        i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                log.debug('delay event duration: {}'.format(ekey['duration']))
                rows['trial_event'].append(ekey)

            #
            # Add 'go' event
            #
            log.debug('BehaviorIngest.make(): go')

            ekey = dict(tkey)
            responseindex = np.where(t.state_data == states['ResponseCue'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'go'
            ekey['trial_event_time'] = t.state_times[responseindex][0]
            ekey['duration'] = gui['AnswerPeriod'][0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing go duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
                rows['corrected_trial_event'].append(ekey)

            rows['trial_event'].append(ekey)

            #
            # Add 'trialEnd' events
            #

            log.debug('BehaviorIngest.make(): trialend events')

            trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'trialend'
            ekey['trial_event_time'] = t.state_times[trialendindex][0]
            ekey['duration'] = 0.0

            rows['trial_event'].append(ekey)

            #
            # Add lick events (event code 69 -> left lick, 71 -> right lick)
            #
            # action_event_id is the running count of action events in
            # this session.

            lickleft = np.where(t.event_data == 69)[0]
            log.debug('... lickleft: {r}'.format(r=str(lickleft)))

            action_event_count = len(rows['action_event'])
            rows['action_event'].extend(
                dict(tkey,
                     action_event_id=action_event_count + idx,
                     action_event_type='left lick',
                     action_event_time=t.event_times[ev])
                for idx, ev in enumerate(lickleft))

            lickright = np.where(t.event_data == 71)[0]
            log.debug('... lickright: {r}'.format(r=str(lickright)))

            action_event_count = len(rows['action_event'])
            rows['action_event'].extend(
                dict(tkey,
                     action_event_id=action_event_count + idx,
                     action_event_type='right lick',
                     action_event_time=t.event_times[ev])
                for idx, ev in enumerate(lickright))

            # Photostim Events
            #
            # TODO:
            #
            # - base stimulation parameters:
            #
            #   - should be loaded elsewhere - where
            #   - actual ccf locations - cannot be known apriori apparently?
            #   - Photostim.Profile: what is? fix/add
            #
            # - stim data
            #
            #   - how retrieve power from file (didn't see) or should
            #     be statically coded here?
            #   - how encode stim type 6?
            #     - we have hemisphere as boolean or
            #     - but adding an event 4 and event 5 means querying
            #       is less straightforward (e.g. sessions with 5 & 6)

            if t.stim:
                log.info('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
                rows['photostim_trial'].append(tkey)
                # stim event is timestamped at the first DelayPeriod entry
                delay_period_idx = np.where(
                    t.state_data == states['DelayPeriod'])[0][0]
                rows['photostim_trial_event'].append(
                    dict(tkey,
                         **photostims[t.stim],  # KeyError if stim code unknown
                         photostim_event_id=len(rows['photostim_trial_event']),
                         photostim_event_time=t.state_times[delay_period_idx],
                         power=5.5))  # hard-coded power; units not given here

            # end of trial loop.

        # Session Insertion
        # (parents are inserted before the rows that reference them)

        log.info('BehaviorIngest.make(): adding session record')
        experiment.Session().insert1(skey)

        # Behavior Insertion

        log.info('BehaviorIngest.make(): bulk insert phase')

        log.info('BehaviorIngest.make(): saving ingest {d}'.format(d=key))
        self.insert1(key, ignore_extra_fields=True, allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
        experiment.SessionTrial().insert(rows['trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
        experiment.BehaviorTrial().insert(rows['behavior_trial'],
                                          ignore_extra_fields=True,
                                          allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialNote')
        experiment.TrialNote().insert(rows['trial_note'],
                                      ignore_extra_fields=True,
                                      allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
        experiment.TrialEvent().insert(rows['trial_event'],
                                       ignore_extra_fields=True,
                                       allow_direct_insert=True,
                                       skip_duplicates=True)

        log.info('BehaviorIngest.make(): ... CorrectedTrialEvents')
        BehaviorIngest().CorrectedTrialEvents().insert(
            rows['corrected_trial_event'],
            ignore_extra_fields=True,
            allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
        experiment.ActionEvent().insert(rows['action_event'],
                                        ignore_extra_fields=True,
                                        allow_direct_insert=True)

        # every matched file is recorded, including ones skipped as too small
        BehaviorIngest.BehaviorFile().insert(
            (dict(key, behavior_file=f.name) for f in matches),
            ignore_extra_fields=True,
            allow_direct_insert=True)

        # Photostim Insertion: one Photostim row per distinct stim protocol

        photostim_ids = {r['photo_stim'] for r in rows['photostim_trial_event']}
        if photostim_ids:
            log.info('BehaviorIngest.make(): ... experiment.Photostim')
            experiment.Photostim.insert(
                (dict(skey, **photostims[stim]) for stim in photostim_ids),
                ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrial')
        experiment.PhotostimTrial.insert(rows['photostim_trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrialEvent')
        experiment.PhotostimEvent.insert(rows['photostim_trial_event'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)
コード例 #20
0
    def make(self, key):
        log.info('BehaviorIngest.make(): key: {key}'.format(key=key))
        rigpaths = [
            p for p in RigDataPath().fetch(order_by='rig_data_path')
            if 'RRig' in p['rig']
        ]  # change between TRig and RRig

        subject_id = key['subject_id']
        h2o = (lab.WaterRestriction() & {
            'subject_id': subject_id
        }).fetch1('water_restriction_number')
        date = key['session_date']
        datestr = date.strftime('%Y%m%d')
        log.debug('h2o: {h2o}, date: {d}'.format(h2o=h2o, d=datestr))

        # session record key
        skey = {}
        skey['subject_id'] = subject_id
        skey['session_date'] = date
        skey['username'] = '******'  # username has to be changed

        # e.g: dl7/TW_autoTrain/Session Data/dl7_TW_autoTrain_20180104_132813.mat
        #         # p.split('/foo/bar')[1]
        for rp in rigpaths:
            root = rp['rig_data_path']
            path = root
            path = os.path.join(path, h2o)
            #            path = os.path.join(path, 'TW_autoTrain')
            path = os.path.join(path, 'tw2')
            path = os.path.join(path, 'Session Data')
            path = os.path.join(
                #                path, '{h2o}_TW_autoTrain_{d}*.mat'.format(h2o=h2o, d=datestr)) # earlier program protocol
                path,
                '{h2o}_tw2_{d}*.mat'.format(
                    h2o=h2o, d=datestr))  # later program protocol

            log.debug('rigpath {p}'.format(p=path))

            matches = glob.glob(path)
            if len(matches):
                log.debug('found files, this is the rig')
                skey['rig'] = rp['rig']
                break
            else:
                log.info('no file matches found in {p}'.format(p=path))

        if not len(matches):
            log.warning('no file matches found for {h2o} / {d}'.format(
                h2o=h2o, d=datestr))
            return

        #
        # Find files & Check for split files
        # XXX: not checking rig.. 2+ sessions on 2+ rigs possible for date?
        #

        if len(matches) > 1:
            log.warning(
                'split session case detected for {h2o} on {date}'.format(
                    h2o=h2o, date=date))

        # session:date relationship is 1:1; skip if we have a session
        if experiment.Session() & skey:
            log.warning("Warning! session exists for {h2o} on {d}".format(
                h2o=h2o, d=date))
            return

        #
        # Extract trial data from file(s) & prepare trial loop
        #

        trials = zip()

        trial = namedtuple(  # simple structure to track per-trial vars
            'trial', ('ttype', 'settings', 'state_times', 'state_names',
                      'state_data', 'event_data', 'event_times'))

        for f in matches:

            if os.stat(f).st_size / 1024 < 100:
                log.info('skipping file {f} - too small'.format(f=f))
                continue

            mat = spio.loadmat(f, squeeze_me=True)
            SessionData = mat['SessionData'].flatten()

            AllTrialTypes = SessionData['TrialTypes'][0]
            AllTrialSettings = SessionData['TrialSettings'][0]

            RawData = SessionData['RawData'][0].flatten()
            AllStateNames = RawData['OriginalStateNamesByNumber'][0]
            AllStateData = RawData['OriginalStateData'][0]
            AllEventData = RawData['OriginalEventData'][0]
            AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
            AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

            # verify trial-related data arrays are all same length
            assert (all(
                (x.shape[0] == AllStateTimestamps.shape[0]
                 for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                           AllStateData, AllEventData, AllEventTimestamps))))

            z = zip(AllTrialTypes, AllTrialSettings, AllStateTimestamps,
                    AllStateNames, AllStateData, AllEventData,
                    AllEventTimestamps)

            trials = chain(trials, z)  # concatenate the files

        trials = list(trials)

        # all files were internally invalid or size < 100k
        if not trials:
            log.warning('skipping date {d}, no valid files'.format(d=date))

        #
        # Trial data seems valid; synthesize session id & add session record
        # XXX: note - later breaks can result in Sessions without valid trials
        #

        log.debug('synthesizing session ID')
        session = (dj.U().aggr(experiment.Session() & {
            'subject_id': subject_id
        },
                               n='max(session)').fetch1('n') or 0) + 1
        log.info('generated session id: {session}'.format(session=session))
        skey['session'] = session
        key = dict(key, **skey)

        log.debug('BehaviorIngest.make(): adding session record')
        experiment.Session().insert1(skey)

        #
        # Actually load the per-trial data
        #
        log.info('BehaviorIngest.make(): trial parsing phase')

        # lists of various records for batch-insert
        rows = {
            k: list()
            for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                      'action_event')
        }

        i = -1
        for t in trials:

            #
            # Misc
            #

            t = trial(*t)  # convert list of items to a 'trial' structure
            i += 1  # increment trial counter

            log.info('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

            # covert state data names into a lookup dictionary
            #
            # names (seem to be? are?):
            #
            # Trigtrialstart
            # PreSamplePeriod
            # SamplePeriod
            # DelayPeriod
            # EarlyLickDelay
            # EarlyLickSample
            # ResponseCue
            # GiveLeftDrop
            # GiveRightDrop
            # GiveLeftDropShort
            # GiveRightDropShort
            # AnswerPeriod
            # Reward
            # RewardConsumption
            # NoResponse
            # TimeOut
            # StopLicking
            # StopLickingReturn
            # TrialEnd

            states = {k: (v + 1) for v, k in enumerate(t.state_names)}
            required_states = ('PreSamplePeriod', 'SamplePeriod',
                               'DelayPeriod', 'ResponseCue', 'StopLicking',
                               'TrialEnd')

            missing = list(k for k in required_states if k not in states)

            if len(missing):
                log.info('skipping trial {i}; missing {m}'.format(i=i,
                                                                  m=missing))
                continue

            gui = t.settings['GUI'].flatten()

            # ProtocolType - only ingest protocol >= 3
            #
            # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
            # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
            #

            if 'ProtocolType' not in gui.dtype.names:
                log.info('skipping trial {i}; protocol undefined'.format(i=i))
                continue

            protocol_type = gui['ProtocolType'][0]
            if gui['ProtocolType'][0] < 3:
                log.info('skipping trial {i}; protocol {n} < 3'.format(
                    i=i, n=gui['ProtocolType'][0]))
                continue

            #
            # Top-level 'Trial' record
            #

            tkey = dict(skey)
            startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

            # should be only end of 1st StopLicking;
            # rest of data is irrelevant w/r/t separately ingested ephys
            endindex = np.where(t.state_data == states['StopLicking'])[0]

            log.debug('states\n' + str(states))
            log.debug('state_data\n' + str(t.state_data))
            log.debug('startindex\n' + str(startindex))
            log.debug('endendex\n' + str(endindex))

            if not (len(startindex) and len(endindex)):
                log.info('skipping trial {i}: start/end index error: {s}/{e}'.
                         format(i=i, s=str(startindex), e=str(endindex)))
                continue

            try:
                tkey['trial'] = i
                tkey['trial_uid'] = i
                tkey['start_time'] = t.state_times[startindex][0]
            except IndexError:
                log.info('skipping trial {i}: error indexing {s}/{e} into {t}'.
                         format(i=i,
                                s=str(startindex),
                                e=str(endindex),
                                t=str(t.state_times)))
                continue

            log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
            log.debug('tkey' + str(tkey))
            rows['trial'].append(tkey)

            #
            # Specific BehaviorTrial information for this trial
            #

            bkey = dict(tkey)
            bkey['task'] = 'audio delay'
            bkey['task_protocol'] = 1

            # determine trial instruction
            trial_instruction = 'left'

            if gui['Reversal'][0] == 1:
                if t.ttype == 1:
                    trial_instruction = 'left'
                elif t.ttype == 0:
                    trial_instruction = 'right'
            elif gui['Reversal'][0] == 2:
                if t.ttype == 1:
                    trial_instruction = 'right'
                elif t.ttype == 0:
                    trial_instruction = 'left'

            bkey['trial_instruction'] = trial_instruction

            # determine early lick
            early_lick = 'no early'

            if (protocol_type >= 5 and 'EarlyLickDelay' in states
                    and np.any(t.state_data == states['EarlyLickDelay'])):
                early_lick = 'early'
            if (protocol_type > 5 and
                ('EarlyLickSample' in states
                 and np.any(t.state_data == states['EarlyLickSample']))):
                early_lick = 'early'

            bkey['early_lick'] = early_lick

            # determine outcome
            outcome = 'ignore'

            if ('Reward' in states
                    and np.any(t.state_data == states['Reward'])):
                outcome = 'hit'
            elif ('TimeOut' in states
                  and np.any(t.state_data == states['TimeOut'])):
                outcome = 'miss'
            elif ('NoResponse' in states
                  and np.any(t.state_data == states['NoResponse'])):
                outcome = 'ignore'

            bkey['outcome'] = outcome

            # add behavior record
            log.debug('BehaviorIngest.make(): BehaviorTrial()')
            rows['behavior_trial'].append(bkey)

            #
            # Add 'protocol' note
            #

            nkey = dict(tkey)
            nkey['trial_note_type'] = 'protocol #'
            nkey['trial_note'] = str(protocol_type)

            log.debug('BehaviorIngest.make(): TrialNote().insert1')
            rows['trial_note'].append(nkey)

            #
            # Add 'autolearn' note
            #

            nkey = dict(tkey)
            nkey['trial_note_type'] = 'autolearn'
            nkey['trial_note'] = str(gui['Autolearn'][0])
            rows['trial_note'].append(nkey)

            #pdb.set_trace()
            #
            # Add 'bitcode' note
            #
            if 'randomID' in gui.dtype.names:
                nkey = dict(tkey)
                nkey['trial_note_type'] = 'bitcode'
                nkey['trial_note'] = str(gui['randomID'][0])
                rows['trial_note'].append(nkey)

            #
            # Add presample event
            #

            ekey = dict(tkey)
            sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

            ekey['trial_event_type'] = 'presample'
            ekey['trial_event_time'] = t.state_times[startindex][0]
            ekey['duration'] = (t.state_times[sampleindex[0]] -
                                t.state_times[startindex])[0]

            log.debug('BehaviorIngest.make(): presample')
            rows['trial_event'].append(ekey)

            #
            # Add 'go' event
            #

            ekey = dict(tkey)
            responseindex = np.where(t.state_data == states['ResponseCue'])[0]

            ekey['trial_event_type'] = 'go'
            ekey['trial_event_time'] = t.state_times[responseindex][0]
            ekey['duration'] = gui['AnswerPeriod'][0]

            log.debug('BehaviorIngest.make(): go')
            rows['trial_event'].append(ekey)

            #
            # Add other 'sample' events
            #

            log.debug('BehaviorIngest.make(): sample events')
            for s in sampleindex:  # in protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_type'] = 'sample'
                ekey['trial_event_time'] = t.state_times[s]
                ekey['duration'] = gui['SamplePeriod'][0]
                rows['trial_event'].append(ekey)

            #
            # Add 'delay' events
            #

            delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

            log.debug('BehaviorIngest.make(): delay events')
            for d in delayindex:  # protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_type'] = 'delay'
                ekey['trial_event_time'] = t.state_times[d]
                ekey['duration'] = gui['DelayPeriod'][0]
                rows['trial_event'].append(ekey)

            #
            # Add lick events
            #

            lickleft = np.where(t.event_data == 69)[0]
            log.debug('... lickleft: {r}'.format(r=str(lickleft)))

            if len(lickleft):
                [
                    rows['action_event'].append(
                        dict(**tkey,
                             action_event_type='left lick',
                             action_event_time=t.event_times[l]))
                    for l in lickleft
                ]

            lickright = np.where(t.event_data == 70)[0]
            log.debug('... lickright: {r}'.format(r=str(lickright)))

            if len(lickright):
                [
                    rows['action_event'].append(
                        dict(**tkey,
                             action_event_type='right lick',
                             action_event_time=t.event_times[r]))
                    for r in lickright
                ]

            # end of trial loop.

        log.info('BehaviorIngest.make(): bulk insert phase')

        log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
        experiment.SessionTrial().insert(rows['trial'],
                                         ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
        experiment.BehaviorTrial().insert(rows['behavior_trial'],
                                          ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialNote')
        experiment.TrialNote().insert(rows['trial_note'],
                                      ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
        experiment.TrialEvent().insert(rows['trial_event'],
                                       ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
        experiment.ActionEvent().insert(rows['action_event'],
                                        ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): saving ingest {d}'.format(d=key))
        self.insert1(key, ignore_extra_fields=True)

        BehaviorIngest.BehaviorFile().insert(
            (dict(key, behavior_file=f.split(root)[1]) for f in matches),
            ignore_extra_fields=True)
コード例 #21
0
def populatebehavior_core(IDs = None):
    if IDs:
        print('subject started:')
        print(IDs.keys())
        print(IDs.values())
        
    rigpath_1 = 'E:/Projects/Ablation/datajoint/Behavior'
    
    #df_surgery = pd.read_csv(dj.config['locations.metadata']+'Surgery.csv')
    if IDs == None:
        IDs = {k: v for k, v in zip(*lab.WaterRestriction().fetch('water_restriction_number', 'subject_id'))}   

    for subject_now,subject_id_now in zip(IDs.keys(),IDs.values()): # iterating over subjects
        print('subject: ',subject_now)
    # =============================================================================
    #         if drop_last_session_for_mice_in_training:
    #             delete_last_session_before_upload = True
    #         else:
    #             delete_last_session_before_upload = False
    #         #df_wr = online_notebook.fetch_water_restriction_metadata(subject_now)
    # =============================================================================
        try:
            df_wr = pd.read_csv(dj.config['locations.metadata_behavior']+subject_now+'.csv')
        except:
            print(subject_now + ' has no metadata available')
            df_wr = pd.DataFrame()
        for df_wr_row in df_wr.iterrows():
            date_now = df_wr_row[1].Date.replace('-','')
            print('subject: ',subject_now,'  date: ',date_now)
            session_date = datetime(int(date_now[0:4]),int(date_now[4:6]),int(date_now[6:8]))
            if len(experiment.Session() & 'subject_id = "'+str(subject_id_now)+'"' & 'session_date > "'+str(session_date)+'"') != 0: # if it is not the last
                print('session already imported, skipping: ' + str(session_date))
                dotheupload = False
            elif len(experiment.Session() & 'subject_id = "'+str(subject_id_now)+'"' & 'session_date = "'+str(session_date)+'"') != 0: # if it is the last
                dotheupload = False
            else: # reuploading new session that is not present on the server
                dotheupload = True
                
            # if dotheupload is True, meaning that there are new mat file hasn't been uploaded
            # => needs to find which mat file hasn't been uploaded
            
            if dotheupload:
                found = set()
                rigpath_2 = subject_now
                rigpath_3 = rigpath_1 + '/' + rigpath_2
                rigpath = pathlib.Path(rigpath_3)
                
                def buildrec(rigpath, root, f):
                    try:
                        fullpath = pathlib.Path(root, f)
                        subpath = fullpath.relative_to(rigpath)
                        fsplit = subpath.stem.split('_')
                        h2o = fsplit[0]
                        ymd = fsplit[-2:-1][0]
                        animal = IDs[h2o]
                        if ymd == date_now:
                            return {
                                    'subject_id': animal,
                                    'session_date': date(int(ymd[0:4]), int(ymd[4:6]), int(ymd[6:8])),
                                    'rig_data_path': rigpath.as_posix(),
                                    'subpath': subpath.as_posix(),
                                    }
                    except:
                        pass
                for root, dirs, files in os.walk(rigpath):
                    for f in files:
                        r = buildrec(rigpath, root, f)
                        if r:
                            found.add(r['subpath'])
                            file = r
                
                # now start insert data
            
                path = pathlib.Path(file['rig_data_path'], file['subpath'])
                mat = spio.loadmat(path, squeeze_me=True)
                SessionData = mat['SessionData'].flatten()
                            
                # session record key
                skey = {}
                skey['subject_id'] = file['subject_id']
                skey['session_date'] = file['session_date']
                skey['username'] = '******'
                #skey['rig'] = key['rig']
            
                trial = namedtuple(  # simple structure to track per-trial vars
                        'trial', ('ttype', 'settings', 'state_times',
                                  'state_names', 'state_data', 'event_data',
                                  'event_times', 'trial_start'))
            
                # parse session datetime
                session_datetime_str = str('').join((str(SessionData['Info'][0]['SessionDate']),' ', str(SessionData['Info'][0]['SessionStartTime_UTC'])))
                session_datetime = datetime.strptime(session_datetime_str, '%d-%b-%Y %H:%M:%S')
            
                AllTrialTypes = SessionData['TrialTypes'][0]
                AllTrialSettings = SessionData['TrialSettings'][0]
                AllTrialStarts = SessionData['TrialStartTimestamp'][0]
                AllTrialStarts = AllTrialStarts - AllTrialStarts[0]
            
                RawData = SessionData['RawData'][0].flatten()
                AllStateNames = RawData['OriginalStateNamesByNumber'][0]
                AllStateData = RawData['OriginalStateData'][0]
                AllEventData = RawData['OriginalEventData'][0]
                AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
                AllEventTimestamps = RawData['OriginalEventTimestamps'][0]
            
                trials = list(zip(AllTrialTypes, AllTrialSettings,
                                  AllStateTimestamps, AllStateNames, AllStateData,
                                  AllEventData, AllEventTimestamps, AllTrialStarts))
                
                if not trials:
                    log.warning('skipping date {d}, no valid files'.format(d=date))
                    return    
                #
                # Trial data seems valid; synthesize session id & add session record
                # XXX: note - later breaks can result in Sessions without valid trials
                #
            
                assert skey['session_date'] == session_datetime.date()
                
                skey['session_date'] = session_datetime.date()
                #skey['session_time'] = session_datetime.time()
            
                if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"' & 'session_date = "'+str(file['session_date'])+'"') == 0:
                    if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"') == 0:
                        skey['session'] = 1
                    else:
                        skey['session'] = len((experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"').fetch()['session']) + 1
            
                #
                # Actually load the per-trial data
                #
                log.info('BehaviorIngest.make(): trial parsing phase')

                # lists of various records for batch-insert
                rows = {k: list() for k in ('trial', 'behavior_trial', 'trial_note',
                                        'trial_event', 'corrected_trial_event',
                                        'action_event')} #, 'photostim',
                                    #'photostim_location', 'photostim_trial',
                                    #'photostim_trial_event')}

                i = 0  # trial numbering starts at 1
                for t in trials:
                    t = trial(*t)  # convert list of items to a 'trial' structure
                    i += 1  # increment trial counter

                    log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

                    states = {k: (v+1) for v, k in enumerate(t.state_names)}
                    required_states = ('PreSamplePeriod', 'SamplePeriod',
                                       'DelayPeriod', 'ResponseCue', 'StopLicking',
                                       'TrialEnd')
                
                    missing = list(k for k in required_states if k not in states)
                    if len(missing) and missing =='PreSamplePeriod':
                        log.warning('skipping trial {i}; missing {m}'.format(i=i, m=missing))
                        continue

                    gui = t.settings['GUI'].flatten()
                    if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"' & 'session_date = "'+str(file['session_date'])+'"') == 0:
                        if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"') == 0:
                            skey['session'] = 1
                        else:
                            skey['session'] = len((experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"').fetch()['session']) + 1
                
                    #
                    # Top-level 'Trial' record
                    #
                    protocol_type = gui['ProtocolType'][0]
                    tkey = dict(skey)
                    has_presample = 1
                    try:
                        startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]
                        has_presample = 1
                    except:
                        startindex = np.where(t.state_data == states['SamplePeriod'])[0]
                        has_presample = 0
                
                    # should be only end of 1st StopLicking;
                    # rest of data is irrelevant w/r/t separately ingested ephys
                    endindex = np.where(t.state_data == states['StopLicking'])[0]
                    log.debug('states\n' + str(states))
                    log.debug('state_data\n' + str(t.state_data))
                    log.debug('startindex\n' + str(startindex))
                    log.debug('endindex\n' + str(endindex))
                
                    if not(len(startindex) and len(endindex)):
                        log.warning('skipping {}: start/end mismatch: {}/{}'.format(i, str(startindex), str(endindex)))
                        continue
                    
                    try:
                        tkey['trial'] = i
                        tkey['trial_uid'] = i
                        tkey['trial_start_time'] = t.trial_start
                        tkey['trial_stop_time'] = t.trial_start + t.state_times[endindex][0]
                    except IndexError:
                        log.warning('skipping {}: IndexError: {}/{} -> {}'.format(i, str(startindex), str(endindex), str(t.state_times)))
                        continue
                    
                    log.debug('tkey' + str(tkey))
                    rows['trial'].append(tkey)
                
                    #
                    # Specific BehaviorTrial information for this trial
                    #                              
                    
                    bkey = dict(tkey)
                    bkey['task'] = 'audio delay'  # hard-coded here
                    bkey['task_protocol'] = 1     # hard-coded here
                
                    # determine trial instruction
                    trial_instruction = 'left'    # hard-coded here

                    if gui['Reversal'][0] == 1:
                        if t.ttype == 1:
                            trial_instruction = 'left'
                        elif t.ttype == 0:
                            trial_instruction = 'right'
                        elif t.ttype == 2:
                            trial_instruction = 'catch_right_autowater'
                        elif t.ttype == 3:
                            trial_instruction = 'catch_left_autowater'
                        elif t.ttype == 4:
                            trial_instruction = 'catch_right_noDelay'
                        elif t.ttype == 5:
                            trial_instruction = 'catch_left_noDelay'    
                    elif gui['Reversal'][0] == 2:
                        if t.ttype == 1:
                            trial_instruction = 'right'
                        elif t.ttype == 0:
                            trial_instruction = 'left'
                        elif t.ttype == 2:
                            trial_instruction = 'catch_left_autowater'
                        elif t.ttype == 3:
                            trial_instruction = 'catch_right_autowater'
                        elif t.ttype == 4:
                            trial_instruction = 'catch_left_noDelay'
                        elif t.ttype == 5:
                            trial_instruction = 'catch_right_noDelay'
                
                    bkey['trial_instruction'] = trial_instruction
                    # determine early lick
                    early_lick = 'no early'
                    
                    if (protocol_type >= 5 and 'EarlyLickDelay' in states and np.any(t.state_data == states['EarlyLickDelay'])):
                        early_lick = 'early'
                    if (protocol_type >= 5 and ('EarlyLickSample' in states and np.any(t.state_data == states['EarlyLickSample']))):
                        early_lick = 'early'
                        
                    bkey['early_lick'] = early_lick
                
                    # determine outcome
                    outcome = 'ignore'
                    if ('Reward' in states and np.any(t.state_data == states['Reward'])):
                        outcome = 'hit'
                    elif ('TimeOut' in states and np.any(t.state_data == states['TimeOut'])):
                        outcome = 'miss'
                    elif ('NoResponse' in states and np.any(t.state_data == states['NoResponse'])):
                        outcome = 'ignore'    
                    bkey['outcome'] = outcome
                    rows['behavior_trial'].append(bkey)
                    
                    #
                    # Add 'protocol' note
                    #
                    nkey = dict(tkey)
                    nkey['trial_note_type'] = 'protocol #'
                    nkey['trial_note'] = str(protocol_type)
                    rows['trial_note'].append(nkey)

                    #
                    # Add 'autolearn' note
                    #
                    nkey = dict(tkey)
                    nkey['trial_note_type'] = 'autolearn'
                    nkey['trial_note'] = str(gui['Autolearn'][0])
                    rows['trial_note'].append(nkey)
                    
                    #
                    # Add 'bitcode' note
                    #
                    if 'randomID' in gui.dtype.names:
                        nkey = dict(tkey)
                        nkey['trial_note_type'] = 'bitcode'
                        nkey['trial_note'] = str(gui['randomID'][0])
                        rows['trial_note'].append(nkey)
               
                
                    #
                    # Add presample event
                    #
                    sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]
                    
                    if has_presample == 1:
                        log.debug('BehaviorIngest.make(): presample')
                        ekey = dict(tkey)                    
    
                        ekey['trial_event_id'] = len(rows['trial_event'])
                        ekey['trial_event_type'] = 'presample'
                        ekey['trial_event_time'] = t.state_times[startindex][0]
                        ekey['duration'] = (t.state_times[sampleindex[0]]- t.state_times[startindex])[0]
    
                        if math.isnan(ekey['duration']):
                            log.debug('BehaviorIngest.make(): fixing presample duration')
                            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial
    
                        rows['trial_event'].append(ekey)
                
                    #
                    # Add other 'sample' events
                    #
    
                    log.debug('BehaviorIngest.make(): sample events')
    
                    last_dur = None
    
                    for s in sampleindex:  # in protocol > 6 ~-> n>1
                        # todo: batch events
                        ekey = dict(tkey)
                        ekey['trial_event_id'] = len(rows['trial_event'])
                        ekey['trial_event_type'] = 'sample'
                        ekey['trial_event_time'] = t.state_times[s]
                        ekey['duration'] = gui['SamplePeriod'][0]
    
                        if math.isnan(ekey['duration']) and last_dur is None:
                            log.warning('... trial {} bad duration, no last_edur'.format(i, last_dur))
                            ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                            rows['corrected_trial_event'].append(ekey)
    
                        elif math.isnan(ekey['duration']) and last_dur is not None:
                            log.warning('... trial {} duration using last_edur {}'.format(i, last_dur))
                            ekey['duration'] = last_dur
                            rows['corrected_trial_event'].append(ekey)
    
                        else:
                            last_dur = ekey['duration']  # only track 'good' values.
    
                        rows['trial_event'].append(ekey)
                
                    #
                    # Add 'delay' events
                    #
    
                    log.debug('BehaviorIngest.make(): delay events')
    
                    last_dur = None
                    delayindex = np.where(t.state_data == states['DelayPeriod'])[0]
    
                    for d in delayindex:  # protocol > 6 ~-> n>1
                        ekey = dict(tkey)
                        ekey['trial_event_id'] = len(rows['trial_event'])
                        ekey['trial_event_type'] = 'delay'
                        ekey['trial_event_time'] = t.state_times[d]
                        ekey['duration'] = gui['DelayPeriod'][0]
    
                        if math.isnan(ekey['duration']) and last_dur is None:
                            log.warning('... {} bad duration, no last_edur'.format(i, last_dur))
                            ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                            rows['corrected_trial_event'].append(ekey)
    
                        elif math.isnan(ekey['duration']) and last_dur is not None:
                            log.warning('... {} duration using last_edur {}'.format(i, last_dur))
                            ekey['duration'] = last_dur
                            rows['corrected_trial_event'].append(ekey)
    
                        else:
                            last_dur = ekey['duration']  # only track 'good' values.
    
                        log.debug('delay event duration: {}'.format(ekey['duration']))
                        rows['trial_event'].append(ekey)
                         
                    #
                    # Add 'go' event
                    #
                    log.debug('BehaviorIngest.make(): go')
    
                    ekey = dict(tkey)
                    responseindex = np.where(t.state_data == states['ResponseCue'])[0]
    
                    ekey['trial_event_id'] = len(rows['trial_event'])
                    ekey['trial_event_type'] = 'go'
                    ekey['trial_event_time'] = t.state_times[responseindex][0]
                    ekey['duration'] = gui['AnswerPeriod'][0]
    
                    if math.isnan(ekey['duration']):
                        log.debug('BehaviorIngest.make(): fixing go duration')
                        ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
                        rows['corrected_trial_event'].append(ekey)
    
                    rows['trial_event'].append(ekey)
                
                    #
                    # Add 'trialEnd' events
                    #

                    log.debug('BehaviorIngest.make(): trialend events')

                    last_dur = None
                    trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

                    ekey = dict(tkey)
                    ekey['trial_event_id'] = len(rows['trial_event'])
                    ekey['trial_event_type'] = 'trialend'
                    ekey['trial_event_time'] = t.state_times[trialendindex][0]
                    ekey['duration'] = 0.0
    
                    rows['trial_event'].append(ekey)
                    
                    #
                    # Add lick events
                    #
                       
                    lickleft = np.where(t.event_data == 69)[0]
                    log.debug('... lickleft: {r}'.format(r=str(lickleft)))
    
                    action_event_count = len(rows['action_event'])
                    if len(lickleft):
                        [rows['action_event'].append(
                                dict(tkey, action_event_id=action_event_count+idx,
                                     action_event_type='left lick',
                                     action_event_time=t.event_times[l]))
                        for idx, l in enumerate(lickleft)]
    
                    lickright = np.where(t.event_data == 71)[0]
                    log.debug('... lickright: {r}'.format(r=str(lickright)))
    
                    action_event_count = len(rows['action_event'])
                    if len(lickright):
                        [rows['action_event'].append(
                                dict(tkey, action_event_id=action_event_count+idx,
                                     action_event_type='right lick',
                                     action_event_time=t.event_times[r]))
                        for idx, r in enumerate(lickright)]
                    
                    # end of trial loop..    
        
                    # Session Insertion                     
                    log.info('BehaviorIngest.make(): adding session record')
                    skey['session_date'] = df_wr_row[1].Date 
                    skey['rig'] = 'Old Recording rig'
                    skey['username']  = '******'
                    experiment.Session().insert1(skey,skip_duplicates=True)

                # Behavior Insertion                

                log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
                experiment.SessionTrial().insert(
                        rows['trial'], ignore_extra_fields=True, allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
                experiment.BehaviorTrial().insert(
                        rows['behavior_trial'], ignore_extra_fields=True,
                        allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.TrialNote')
                experiment.TrialNote().insert(
                        rows['trial_note'], ignore_extra_fields=True,
                        allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
                experiment.TrialEvent().insert(
                        rows['trial_event'], ignore_extra_fields=True,
                        allow_direct_insert=True, skip_duplicates=True)
        
#        log.info('BehaviorIngest.make(): ... CorrectedTrialEvents')
#        BehaviorIngest().CorrectedTrialEvents().insert(
#            rows['corrected_trial_event'], ignore_extra_fields=True,
#            allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
                experiment.ActionEvent().insert(
                        rows['action_event'], ignore_extra_fields=True,
                        allow_direct_insert=True)
                            
#%% for ingest tracking                
                if IDs:
                    print('subject started:')
                    print(IDs.keys())
                    print(IDs.values())
                    
                rigpath_tracking_1 = 'E:/Projects/Ablation/datajoint/video/'
                rigpath_tracking_2 = subject_now
                VideoDate1 = str(df_wr_row[1].VideoDate)
                if len(VideoDate1)==5:
                    VideoDate = '0'+ VideoDate1
                elif len(VideoDate1)==7:
                    VideoDate = '0'+ VideoDate1
                rigpath_tracking_3 = rigpath_tracking_1 + rigpath_tracking_2 + '/' + rigpath_tracking_2 + '_'+ VideoDate + '_front'
                
                rigpath_tracking = pathlib.Path(rigpath_tracking_3)
                
                #df_surgery = pd.read_csv(dj.config['locations.metadata']+'Surgery.csv')
                if IDs == None:
                    IDs = {k: v for k, v in zip(*lab.WaterRestriction().fetch('water_restriction_number', 'subject_id'))}   
                
                h2o = subject_now
                session = df_wr_row[1].Date
                trials = (experiment.SessionTrial() & session).fetch('trial')
                
                log.info('got session: {} ({} trials)'.format(session, len(trials)))
                
                #sdate = session['session_date']
                #sdate_sml = date_now #"{}{:02d}{:02d}".format(sdate.year, sdate.month, sdate.day)

                paths = rigpath_tracking
                devices = tracking.TrackingDevice().fetch(as_dict=True)
                
                # paths like: <root>/<h2o>/YYYY-MM-DD/tracking
                tracking_files = []
                for d in (d for d in devices):
                    tdev = d['tracking_device']
                    tpos = d['tracking_position']
                    tdat = paths
                    log.info('checking {} for tracking data'.format(tdat))               
                    
                    
#                    if not tpath.exists():
#                        log.warning('tracking path {} n/a - skipping'.format(tpath))
#                        continue
#                    
#                    camtrial = '{}_{}_{}.txt'.format(h2o, sdate_sml, tpos)
#                    campath = tpath / camtrial
#                    
#                    log.info('trying camera position trial map: {}'.format(campath))
#                    
#                    if not campath.exists():
#                        log.info('skipping {} - does not exist'.format(campath))
#                        continue
#                    
#                    tmap = load_campath(campath)  # file:trial
#                    n_tmap = len(tmap)
#                    log.info('loading tracking data for {} trials'.format(n_tmap))

                    i = 0                    
                    VideoTrialNum = df_wr_row[1].VideoTrialNum
                    
                    #tpath = pathlib.Path(tdat, h2o, VideoDate, 'tracking')
                    ppp = list(range(0,VideoTrialNum))
                    for tt in reversed(range(VideoTrialNum)):  # load tracking for trial
                        
                        i += 1
#                        if i % 50 == 0:
#                            log.info('item {}/{}, trial #{} ({:.2f}%)'
#                                     .format(i, n_tmap, t, (i/n_tmap)*100))
#                        else:
#                            log.debug('item {}/{}, trial #{} ({:.2f}%)'
#                                      .format(i, n_tmap, t, (i/n_tmap)*100))
        
                        # ex: dl59_side_1-0000.csv / h2o_position_tn-0000.csv
                        tfile = '{}_{}_{}_{}-*.csv'.format(h2o, VideoDate ,tpos, tt)
                        tfull = list(tdat.glob(tfile))
                        if not tfull or len(tfull) > 1:
                            log.info('file mismatch: file: {} trial: ({})'.format(
                                tt, tfull))
                            continue
        
                        tfull = tfull[-1]
                        trk = load_tracking(tfull)
                        
        
                        recs = {}
                        
                        #key_source = experiment.Session - tracking.Tracking                        
                        rec_base = dict(trial=ppp[tt], tracking_device=tdev)
                        #print(rec_base)
                        for k in trk:
                            if k == 'samples':
                                recs['tracking'] = {
                                    'subject_id' : skey['subject_id'], 
                                    'session' : skey['session'],
                                    **rec_base,
                                    'tracking_samples': len(trk['samples']['ts']),
                                }
                                
                            else:
                                rec = dict(rec_base)
        
                                for attr in trk[k]:
                                    rec_key = '{}_{}'.format(k, attr)
                                    rec[rec_key] = np.array(trk[k][attr])
        
                                recs[k] = rec
                        
                        
                        tracking.Tracking.insert1(
                            recs['tracking'], allow_direct_insert=True)
                        
                        #if len(recs['nose']) > 3000:
                            #continue
                            
                        recs['nose'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['nose'],
                                }
                        
                        #print(recs['nose']['nose_x'])
                        if 'nose' in recs:
                            tracking.Tracking.NoseTracking.insert1(
                                recs['nose'], allow_direct_insert=True)
                            
                        recs['tongue_mid'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['tongue_mid'],
                                }
        
                        if 'tongue_mid' in recs:
                            tracking.Tracking.TongueTracking.insert1(
                                recs['tongue_mid'], allow_direct_insert=True)
                            
                        recs['jaw'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['jaw'],
                                }
        
                        if 'jaw' in recs:
                            tracking.Tracking.JawTracking.insert1(
                                recs['jaw'], allow_direct_insert=True)
                        
                        recs['tongue_left'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['tongue_left'],
                                }
        
                        if 'tongue_left' in recs:
                            tracking.Tracking.LeftTongueTracking.insert1(
                                recs['tongue_left'], allow_direct_insert=True)
                            
                        recs['tongue_right'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['tongue_right'],
                                }
        
                        if 'tongue_right' in recs:
                            tracking.Tracking.RightTongueTracking.insert1(
                                recs['tongue_right'], allow_direct_insert=True)
#                            fmap = {'paw_left_x': 'left_paw_x',  # remap field names
#                                    'paw_left_y': 'left_paw_y',
#                                    'paw_left_likelihood': 'left_paw_likelihood'}
        
#                            tracking.Tracking.LeftPawTracking.insert1({
#                                **{k: v for k, v in recs['paw_left'].items()
#                                   if k not in fmap},
#                                **{fmap[k]: v for k, v in recs['paw_left'].items()
#                                   if k in fmap}}, allow_direct_insert=True)
                        
                        recs['right_lickport'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['right_lickport'],
                                }
                        
                        if 'right_lickport' in recs:
                            tracking.Tracking.RightLickPortTracking.insert1(
                                recs['right_lickport'], allow_direct_insert=True)
#                            fmap = {'paw_right_x': 'right_paw_x',  # remap field names
#                                    'paw_right_y': 'right_paw_y',
#                                    'paw_right_likelihood': 'right_paw_likelihood'}
#        
#                            tracking.Tracking.RightPawTracking.insert1({
#                                **{k: v for k, v in recs['paw_right'].items()
#                                   if k not in fmap},
#                                **{fmap[k]: v for k, v in recs['paw_right'].items()
#                                   if k in fmap}}, allow_direct_insert=True)
                        
                        recs['left_lickport'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['left_lickport'],
                                }
                        
                        if 'left_lickport' in recs:
                            tracking.Tracking.LeftLickPortTracking.insert1(
                                recs['left_lickport'], allow_direct_insert=True)
        
#                        tracking_files.append({**key, 'trial': tmap[t], 'tracking_device': tdev,
#                             'tracking_file': str(tfull.relative_to(tdat))})
#        
#                    log.info('... completed {}/{} items.'.format(i, n_tmap))
#        
#                self.insert1(key)
#                self.TrackingFile.insert(tracking_files)
#                   
                            
                        tracking.VideoFiducialsTrial.populate()
                        bottom_tongue.Camera_pixels.populate()
                        print('start!')               
                        bottom_tongue.VideoTongueTrial.populate()
                        sessiontrialdata={              'subject_id':skey['subject_id'],
                                                        'session':skey['session'],
                                                        'trial': tt
                                                        }
                        if len(bottom_tongue.VideoTongueTrial* experiment.Session & experiment.BehaviorTrial  & 'session_date = "'+str(file['session_date'])+'"' &{'trial':tt})==0:
                            print('trial couldn''t be exported, deleting trial')
                            print(tt)
                            dj.config['safemode'] = False
                            (experiment.SessionTrial()&sessiontrialdata).delete()
                            dj.config['safemode'] = True  
                        
                        
                log.info('... done.')
コード例 #22
0
ファイル: tracking.py プロジェクト: susuchen66/map-ephys
    def make(self, key):
        '''
        TrackingIngest .make() function

        Load the per-trial video tracking data of the session identified by
        ``key`` and insert it into ``tracking.Tracking`` and its nose /
        tongue / jaw part tables.

        For every (tracking device, tracking data path) combination, data is
        looked up under ``<path>/<h2o>/<YYYY-MM-DD>/tracking``:

        - ``<h2o>_<YYYYMMDD>_<position>.txt`` is the camera/trial map, read
          with ``self.load_campath``. Its keys are the numbers used in the
          tracking file names. Its values are checked against, and stored
          as, session trial numbers.
        - ``<h2o>_<position>_<n>-*.csv`` holds the tracking data for map
          key ``n``, read with ``self.load_tracking``.

        The following are skipped (with a log message):

        - combinations whose directory or camera/trial map does not exist;
        - map entries whose trial is not part of the session;
        - map entries matching zero or several CSV files;
        - features missing from a file.

        The ingest record for ``key`` is inserted exactly once, after all
        combinations have been processed, and only if at least one
        camera/trial map was loaded. Inserting it once per combination
        would fail with a duplicate-entry error whenever more than one
        device or path has data for the same session.

        :param key: primary key of the session to ingest.
        '''
        log.info('TrackingIngest().make(): key: {k}'.format(k=key))

        h2o = (lab.WaterRestriction() & key).fetch1('water_restriction_number')
        session = (experiment.Session() & key).fetch1()
        trials = (experiment.SessionTrial() & session).fetch('trial')
        # Built once so the per-file membership test below is O(1).
        trial_set = set(trials)

        log.info('got session: {} ({} trials)'.format(session, len(trials)))

        sdate = session['session_date']
        sdate_iso = sdate.isoformat()  # YYYY-MM-DD
        sdate_sml = "{}{:02d}{:02d}".format(sdate.year, sdate.month, sdate.day)

        paths = TrackingDataPath.fetch(as_dict=True)
        devices = tracking.TrackingDevice().fetch(as_dict=True)

        # Per-feature part tables, keyed by the feature name produced by
        # load_tracking(); a feature absent from a file is skipped.
        part_tables = (
            ('nose', tracking.Tracking.NoseTracking),
            ('tongue', tracking.Tracking.TongueTracking),
            ('jaw', tracking.Tracking.JawTracking),
        )

        # Set once any camera/trial map was loaded; controls the single
        # ingest-record insert after the loop.
        loaded_any = False

        # paths like: <root>/<h2o>/YYYY-MM-DD/tracking
        for p, d in ((p, d) for d in devices for p in paths):

            tdev = d['tracking_device']
            tpos = d['tracking_position']
            tdat = p['tracking_data_path']

            log.info('checking {} for tracking data'.format(tdat))

            tpath = pathlib.Path(tdat, h2o, sdate_iso, 'tracking')

            if not tpath.exists():
                log.warning('tracking path {} n/a - skipping'.format(tpath))
                continue

            camtrial = '{}_{}_{}.txt'.format(h2o, sdate_sml, tpos)
            campath = tpath / camtrial

            log.info('trying camera position trial map: {}'.format(campath))

            if not campath.exists():
                log.info('skipping {} - does not exist'.format(campath))
                continue

            tmap = self.load_campath(campath)

            n_tmap = len(tmap)
            log.info('loading tracking data for {} trials'.format(n_tmap))

            i = 0
            for t in tmap:  # load tracking for trial

                if tmap[t] not in trial_set:
                    log.warning('nonexistant trial {}.. skipping'.format(t))
                    continue

                i += 1
                if i % 50 == 0:
                    log.info('item {}/{}, trial #{} ({:.2f}%)'.format(
                        i, n_tmap, t, (i / n_tmap) * 100))
                else:
                    log.debug('item {}/{}, trial #{} ({:.2f}%)'.format(
                        i, n_tmap, t, (i / n_tmap) * 100))

                # ex: dl59_side_1-0000.csv / h2o_position_tn-0000.csv
                tfile = '{}_{}_{}-*.csv'.format(h2o, tpos, t)
                tfull = list(tpath.glob(tfile))

                # exactly one file must match; ambiguity is not guessed at
                if not tfull or len(tfull) > 1:
                    log.info('tracking file {} mismatch'.format(tfull))
                    continue

                tfull = tfull[-1]
                trk = self.load_tracking(tfull)

                recs = {}
                rec_base = dict(key, trial=tmap[t], tracking_device=tdev)

                for k in trk:
                    if k == 'samples':
                        recs['tracking'] = {
                            **rec_base,
                            'tracking_samples':
                            len(trk['samples']['ts']),
                        }
                    else:
                        # flatten feature attributes into '<feature>_<attr>'
                        rec = dict(rec_base)

                        for attr in trk[k]:
                            rec_key = '{}_{}'.format(k, attr)
                            rec[rec_key] = np.array(trk[k][attr])

                        recs[k] = rec

                tracking.Tracking.insert1(recs['tracking'],
                                          allow_direct_insert=True)

                for feature, part_table in part_tables:
                    if feature not in recs:
                        log.warning('no {} data in {} - skipping'.format(
                            feature, tfull))
                        continue
                    part_table.insert1(recs[feature],
                                       allow_direct_insert=True)

            log.info('... completed {}/{} items.'.format(i, n_tmap))
            loaded_any = True

        if loaded_any:
            log.info('... saving load record')
            self.insert1(key)
        else:
            log.warning('no tracking data found for {}'.format(key))

        log.info('... done.')
コード例 #23
0
def populateelphys():
    #%%
    df_subject_wr_sessions = pd.DataFrame(lab.WaterRestriction() *
                                          experiment.Session() *
                                          experiment.SessionDetails)
    df_subject_ids = pd.DataFrame(lab.Subject())
    if len(df_subject_wr_sessions) > 0:
        subject_names = df_subject_wr_sessions[
            'water_restriction_number'].unique()
        subject_names.sort()
    else:
        subject_names = list()
    subject_ids = df_subject_ids['subject_id'].unique()
    #%
    sumdata = list()
    basedir = Path(dj.config['locations.elphysdata_acq4'])
    for setup_dir in basedir.iterdir():
        setup_name = setup_dir.name
        sessions = np.sort(
            os.listdir(setup_dir)
        )  #configfile.readConfigFile(setup_dir.joinpath('.index'))
        for session_acq in sessions[::-1]:  #.keys():
            if session_acq != '.' and session_acq != 'log.txt':
                session_dir = setup_dir.joinpath(session_acq)
                try:
                    cells = configfile.readConfigFile(
                        session_dir.joinpath('.index'))
                except:  # if there is no file
                    cells = None
                if cells and 'WR_name/ID' in cells['.'].keys(
                ):  # it needs to have WRname
                    wrname_ephys = cells['.']['WR_name/ID']
                    wrname = None
                    for wrname_potential in subject_names:  # look for water restriction number
                        if wrname_potential.lower() in wrname_ephys.lower():
                            wrname = wrname_potential
                            subject_id = (df_subject_wr_sessions.loc[
                                df_subject_wr_sessions[
                                    'water_restriction_number'] == wrname,
                                'subject_id']).unique()[0]
                    if wrname == None:  # look for animal identifier:
                        for wrname_potential in subject_ids:  # look for water restriction number
                            if str(wrname_potential) in wrname_ephys.lower():
                                subject_id = wrname_potential
                                if len(df_subject_wr_sessions) > 0 and len(
                                    (df_subject_wr_sessions.loc[
                                        df_subject_wr_sessions['subject_id'] ==
                                        subject_id, 'water_restriction_number']
                                     ).unique()) > 0:
                                    wrname = (df_subject_wr_sessions.loc[
                                        df_subject_wr_sessions['subject_id'] ==
                                        subject_id, 'water_restriction_number']
                                              ).unique()[0]
                                else:
                                    wrname = 'no water restriction number for this mouse'

                    if wrname:
                        session_date = (
                            session_acq[0:session_acq.find('_')]).replace(
                                '.', '-')

                        print('animal: ' + str(subject_id) + '  -  ' +
                              wrname)  ##
                        if setup_name == 'Voltage_rig_1P':
                            setupname = 'Voltage-Imaging-1p'
                        else:
                            print('unkwnown setup, please add')
                            timer.wait(1000)
                        if 'experimenter' in cells['.'].keys():
                            username = cells['.']['experimenter']
                        else:
                            username = '******'
                            print(
                                'username not specified in acq4 file, assuming rozsam'
                            )
                        ### check if session already exists
                        sessiondata = {
                            'subject_id':
                            subject_id,  #(lab.WaterRestriction() & 'water_restriction_number = "'+df_behavior_session['subject'][0]+'"').fetch()[0]['subject_id'],
                            'session': np.nan,
                            'session_date': session_date,
                            'session_time':
                            np.nan,  #session_time.strftime('%H:%M:%S'),
                            'username': username,
                            'rig': setupname
                        }
                        for cell in cells.keys():
                            if cell != '.' and cell != 'log.txt':
                                ephisdata_cell = list()
                                sweepstarttimes = list()
                                cell_dir = session_dir.joinpath(cell)
                                serieses = configfile.readConfigFile(
                                    cell_dir.joinpath('.index'))
                                cellstarttime = datetime.datetime.fromtimestamp(
                                    serieses['.']['__timestamp__'])
                                for series in serieses.keys():
                                    if series != '.' and series != 'log.txt':
                                        series_dir = cell_dir.joinpath(series)
                                        sweeps = configfile.readConfigFile(
                                            series_dir.joinpath('.index'))
                                        if 'Clamp1.ma' in sweeps.keys():
                                            protocoltype = 'single sweep'
                                            sweepkeys = ['']
                                        else:
                                            protocoltype = 'multiple sweeps'
                                            sweepkeys = sweeps.keys()
                                        for sweep in sweepkeys:
                                            if sweep != '.' and '.txt' not in sweep and '.ma' not in sweep:
                                                sweep_dir = series_dir.joinpath(
                                                    sweep)
                                                sweepinfo = configfile.readConfigFile(
                                                    sweep_dir.joinpath(
                                                        '.index'))
                                                if sweep == '':
                                                    sweep = '0'
                                                for file in sweepinfo.keys():
                                                    if '.ma' in file:
                                                        try:  # old file version

                                                            #print('new file version')
                                                            #%
                                                            ephysfile = h5.File(
                                                                sweep_dir.
                                                                joinpath(file),
                                                                "r")
                                                            data = ephysfile[
                                                                'data'][()]
                                                            metadata_h5 = ephysfile[
                                                                'info']
                                                            metadata = read_h5f_metadata(
                                                                metadata_h5)
                                                            daqchannels = list(
                                                                metadata[2]
                                                                ['DAQ'].keys())
                                                            sweepstarttime = datetime.datetime.fromtimestamp(
                                                                metadata[2]
                                                                ['DAQ']
                                                                [daqchannels[
                                                                    0]]
                                                                ['startTime'])
                                                            relativetime = (
                                                                sweepstarttime
                                                                - cellstarttime
                                                            ).total_seconds()
                                                            if len(
                                                                    ephisdata_cell
                                                            ) > 0 and ephisdata_cell[
                                                                    -1]['sweepstarttime'] == sweepstarttime:
                                                                ephisdata = ephisdata_cell.pop(
                                                                )
                                                            else:
                                                                ephisdata = dict(
                                                                )
                                                            if 'primary' in daqchannels:  # ephys data
                                                                ephisdata[
                                                                    'V'] = data[
                                                                        1]
                                                                ephisdata[
                                                                    'stim'] = data[
                                                                        0]
                                                                ephisdata[
                                                                    'data'] = data
                                                                ephisdata[
                                                                    'metadata'] = metadata
                                                                ephisdata[
                                                                    'time'] = metadata[
                                                                        1]['values']
                                                                ephisdata[
                                                                    'relativetime'] = relativetime
                                                                ephisdata[
                                                                    'sweepstarttime'] = sweepstarttime
                                                                ephisdata[
                                                                    'series'] = series
                                                                ephisdata[
                                                                    'sweep'] = sweep
                                                                sweepstarttimes.append(
                                                                    sweepstarttime
                                                                )
                                                            else:  # other daq stuff
                                                                #%
                                                                for idx, channel in enumerate(
                                                                        metadata[
                                                                            0]
                                                                    ['cols']):
                                                                    channelname = channel[
                                                                        'name'].decode(
                                                                        )
                                                                    if channelname[
                                                                            0] == 'u':
                                                                        channelname = channelname[
                                                                            2:
                                                                            -1]
                                                                        if channelname in [
                                                                                'OrcaFlashExposure',
                                                                                'Temperature',
                                                                                'LED525',
                                                                                'FrameCommand',
                                                                                'NextFileTrigger'
                                                                        ]:
                                                                            ephisdata[
                                                                                channelname] = data[
                                                                                    idx]
                                                                            #print('{} added'.format(channelname))
                                                                        else:
                                                                            print(
                                                                                'waiting in the other daq'
                                                                            )
                                                                            timer.sleep(
                                                                                1000
                                                                            )
                                                            ephisdata_cell.append(
                                                                ephisdata)
                                                            #%

                                                        except:  # new file version
                                                            print(
                                                                'old version')
                                                            ephysfile = MetaArray(
                                                            )
                                                            ephysfile.readFile(
                                                                sweep_dir.
                                                                joinpath(file))
                                                            data = ephysfile.asarray(
                                                            )
                                                            metadata = ephysfile.infoCopy(
                                                            )
                                                            sweepstarttime = datetime.datetime.fromtimestamp(
                                                                metadata[2]
                                                                ['startTime'])
                                                            relativetime = (
                                                                sweepstarttime
                                                                - cellstarttime
                                                            ).total_seconds()
                                                            ephisdata = dict()
                                                            ephisdata[
                                                                'V'] = data[1]
                                                            ephisdata[
                                                                'stim'] = data[
                                                                    0]
                                                            ephisdata[
                                                                'data'] = data
                                                            ephisdata[
                                                                'metadata'] = metadata
                                                            ephisdata[
                                                                'time'] = metadata[
                                                                    1]['values']
                                                            ephisdata[
                                                                'relativetime'] = relativetime
                                                            ephisdata[
                                                                'sweepstarttime'] = sweepstarttime
                                                            ephisdata[
                                                                'series'] = series
                                                            ephisdata[
                                                                'sweep'] = sweep
                                                            sweepstarttimes.append(
                                                                sweepstarttime)
                                                            ephisdata_cell.append(
                                                                ephisdata)

    # ============================================================================
    #                             if wrname == 'FOR04':
    # =============================================================================
    # add session to DJ if not present
                                if len(ephisdata_cell) > 0:

                                    # =============================================================================
                                    #                                     print('waiting')
                                    #                                     timer.sleep(1000)
                                    # =============================================================================
                                    #%
                                    if len(experiment.Session()
                                           & 'subject_id = "' +
                                           str(sessiondata['subject_id']) + '"'
                                           & 'session_date = "' +
                                           str(sessiondata['session_date']) +
                                           '"') == 0:
                                        if len(experiment.Session()
                                               & 'subject_id = "' +
                                               str(sessiondata['subject_id']) +
                                               '"') == 0:
                                            sessiondata['session'] = 1
                                        else:
                                            sessiondata['session'] = len(
                                                (experiment.Session()
                                                 & 'subject_id = "' +
                                                 str(sessiondata['subject_id'])
                                                 + '"').fetch()['session']) + 1
                                        sessiondata['session_time'] = (
                                            sweepstarttimes[0]
                                        ).strftime(
                                            '%H:%M:%S'
                                        )  # the time of the first sweep will be the session time
                                        experiment.Session().insert1(
                                            sessiondata)
                                    #%
                                    session = (
                                        experiment.Session()
                                        & 'subject_id = "' +
                                        str(sessiondata['subject_id']) + '"'
                                        & 'session_date = "' +
                                        str(sessiondata['session_date']) +
                                        '"').fetch('session')[0]
                                    cell_number = int(cell[cell.find('_') +
                                                           1:])
                                    #add cell if not added already
                                    celldata = {
                                        'subject_id': subject_id,
                                        'session': session,
                                        'cell_number': cell_number,
                                    }
                                    #%
                                    if len(ephys_patch.Cell() & celldata
                                           ) == 0 or len(ephys_patch.Cell() *
                                                         ephys_patch.Sweep()
                                                         & celldata) < len(
                                                             ephisdata_cell):
                                        if len(ephys_patch.Cell() *
                                               ephys_patch.Sweep() & celldata
                                               ) < len(ephisdata_cell):
                                            print('finishing a recording:')
                                        else:
                                            print('adding new recording:')
                                        print(celldata)
                                        if 'type' in serieses['.'].keys():
                                            if serieses['.'][
                                                    'type'] == 'interneuron':
                                                celldata['cell_type'] = 'int'
                                            elif serieses['.'][
                                                    'type'] == 'unknown' or serieses[
                                                        '.']['type'] == 'fail':
                                                celldata[
                                                    'cell_type'] = 'unidentified'
                                            else:
                                                print('unhandled cell type!!')
                                                timer.sleep(1000)
                                        else:
                                            celldata[
                                                'cell_type'] = 'unidentified'
                                        celldata['cell_recording_start'] = (
                                            sweepstarttimes[0]
                                        ).strftime('%H:%M:%S')
                                        if 'depth' in serieses['.'].keys(
                                        ) and len(serieses['.']['depth']) > 0:
                                            celldata['depth'] = int(
                                                serieses['.']['depth'])
                                        else:
                                            celldata['depth'] = -1
                                        try:
                                            ephys_patch.Cell().insert1(
                                                celldata,
                                                allow_direct_insert=True)
                                        except dj.errors.DuplicateError:
                                            pass  #already uploaded
                                        if 'notes' in serieses['.'].keys():
                                            cellnotes = serieses['.']['notes']
                                        else:
                                            cellnotes = ''
                                        cellnotesdata = {
                                            'subject_id': subject_id,
                                            'session': session,
                                            'cell_number': cell_number,
                                            'notes': cellnotes
                                        }
                                        try:
                                            ephys_patch.CellNotes().insert1(
                                                cellnotesdata,
                                                allow_direct_insert=True)
                                        except dj.errors.DuplicateError:
                                            pass  #already uploaded

                                        #%
                                        for i, ephisdata in enumerate(
                                                ephisdata_cell):

                                            #%
                                            sweep_number = i
                                            print('sweep {}'.format(
                                                sweep_number))
                                            sweep_data = {
                                                'subject_id':
                                                subject_id,
                                                'session':
                                                session,
                                                'cell_number':
                                                cell_number,
                                                'sweep_number':
                                                sweep_number,
                                                'sweep_start_time':
                                                (ephisdata['sweepstarttime'] -
                                                 sweepstarttimes[0]
                                                 ).total_seconds(),
                                                'sweep_end_time':
                                                (ephisdata['sweepstarttime'] -
                                                 sweepstarttimes[0]
                                                 ).total_seconds() +
                                                ephisdata['time'][-1],
                                                'protocol_name':
                                                ephisdata[
                                                    'series'],  #[:ephisdata['series'].find('_')],
                                                'protocol_sweep_number':
                                                int(ephisdata['sweep'])
                                            }

                                            if 'mode' in ephisdata['metadata'][
                                                    2]['ClampState']:  # old file version
                                                recmode = ephisdata[
                                                    'metadata'][2][
                                                        'ClampState']['mode']
                                            else:
                                                recmode = ephisdata[
                                                    'metadata'][2]['Protocol'][
                                                        'mode']

                                            if 'IC' in str(recmode):
                                                recording_mode = 'current clamp'
                                            else:
                                                print(
                                                    'unhandled recording mode, please act..'
                                                )
                                                timer.sleep(10000)

                                            channelnames = list()
                                            channelunits = list()
                                            for line_now in ephisdata[
                                                    'metadata'][0]['cols']:
                                                if type(line_now['name']
                                                        ) == bytes:
                                                    channelnames.append(
                                                        line_now['name'].
                                                        decode().strip("'"))
                                                    channelunits.append(
                                                        line_now['units'].
                                                        decode().strip("'"))
                                                else:
                                                    channelnames.append(
                                                        line_now['name'])
                                                    channelunits.append(
                                                        line_now['units'])
                                            commandidx = np.where(
                                                np.array(channelnames) ==
                                                'command')[0][0]
                                            dataidx = np.where(
                                                np.array(channelnames) ==
                                                'primary')[0][0]
                                            #%
                                            clampparams_data = ephisdata[
                                                'metadata'][2]['ClampState'][
                                                    'ClampParams'].copy()
                                            clampparams_data_new = dict()
                                            for clampparamkey in clampparams_data.keys(
                                            ):  #6004 is true for some reason.. changing it back to 1
                                                if type(clampparams_data[
                                                        clampparamkey]
                                                        ) == np.int32:
                                                    if clampparams_data[
                                                            clampparamkey] > 0:
                                                        clampparams_data[
                                                            clampparamkey] = int(
                                                                1)
                                                    else:
                                                        clampparams_data[
                                                            clampparamkey] = int(
                                                                0)
                                                else:
                                                    clampparams_data[
                                                        clampparamkey] = float(
                                                            clampparams_data[
                                                                clampparamkey])
                                                clampparams_data_new[
                                                    clampparamkey.lower(
                                                    )] = clampparams_data[
                                                        clampparamkey]
                                                #%
                                            sweepmetadata_data = {
                                                'subject_id':
                                                subject_id,
                                                'session':
                                                session,
                                                'cell_number':
                                                cell_number,
                                                'sweep_number':
                                                sweep_number,
                                                'recording_mode':
                                                recording_mode,
                                                'sample_rate':
                                                np.round(1 / np.median(
                                                    np.diff(
                                                        ephisdata['metadata']
                                                        [1]['values'])))
                                            }
                                            sweepmetadata_data.update(
                                                clampparams_data_new)
                                            sweepdata_data = {
                                                'subject_id':
                                                subject_id,
                                                'session':
                                                session,
                                                'cell_number':
                                                cell_number,
                                                'sweep_number':
                                                sweep_number,
                                                'response_trace':
                                                ephisdata['data'][dataidx, :],
                                                'response_units':
                                                ephisdata['metadata'][0]
                                                ['cols'][dataidx]['units']
                                            }

                                            sweepstimulus_data = {
                                                'subject_id':
                                                subject_id,
                                                'session':
                                                session,
                                                'cell_number':
                                                cell_number,
                                                'sweep_number':
                                                sweep_number,
                                                'stimulus_trace':
                                                ephisdata['data'][
                                                    commandidx, :],
                                                'stimulus_units':
                                                ephisdata['metadata'][0]
                                                ['cols'][commandidx]['units']
                                            }
                                            #print('waiting')
                                            #timer.sleep(10000)
                                            try:
                                                ephys_patch.Sweep().insert1(
                                                    sweep_data,
                                                    allow_direct_insert=True)
                                            except dj.errors.DuplicateError:
                                                pass  #already uploaded
                                            try:  # maybe it's a duplicate..
                                                ephys_patch.ClampParams(
                                                ).insert1(
                                                    clampparams_data_new,
                                                    allow_direct_insert=True)
                                            except dj.errors.DuplicateError:
                                                pass  #already uploaded
                                            try:
                                                ephys_patch.SweepMetadata(
                                                ).insert1(
                                                    sweepmetadata_data,
                                                    allow_direct_insert=True)
                                            except dj.errors.DuplicateError:
                                                pass  #already uploaded
                                            try:
                                                ephys_patch.SweepResponse(
                                                ).insert1(
                                                    sweepdata_data,
                                                    allow_direct_insert=True)
                                            except dj.errors.DuplicateError:
                                                pass  #already uploaded
                                            try:
                                                ephys_patch.SweepStimulus(
                                                ).insert1(
                                                    sweepstimulus_data,
                                                    allow_direct_insert=True)
                                            except dj.errors.DuplicateError:
                                                pass  #already uploaded
                                            #%
                                            if 'OrcaFlashExposure' in ephisdata.keys(
                                            ):
                                                sweepimagingexposuredata = {
                                                    'subject_id':
                                                    subject_id,
                                                    'session':
                                                    session,
                                                    'cell_number':
                                                    cell_number,
                                                    'sweep_number':
                                                    sweep_number,
                                                    'imaging_exposure_trace':
                                                    ephisdata[
                                                        'OrcaFlashExposure']
                                                }
                                                try:
                                                    ephys_patch.SweepImagingExposure(
                                                    ).insert1(
                                                        sweepimagingexposuredata,
                                                        allow_direct_insert=True
                                                    )
                                                except dj.errors.DuplicateError:
                                                    pass  #already uploaded
                                            if 'Temperature' in ephisdata.keys(
                                            ):
                                                sweeptemperaturedata = {
                                                    'subject_id':
                                                    subject_id,
                                                    'session':
                                                    session,
                                                    'cell_number':
                                                    cell_number,
                                                    'sweep_number':
                                                    sweep_number,
                                                    'temperature_trace':
                                                    ephisdata['Temperature'] *
                                                    10,
                                                    'temperature_units':
                                                    'degC'
                                                }
                                                try:
                                                    ephys_patch.SweepTemperature(
                                                    ).insert1(
                                                        sweeptemperaturedata,
                                                        allow_direct_insert=True
                                                    )
                                                except dj.errors.DuplicateError:
                                                    pass  #already uploaded
                                            if 'LED525' in ephisdata.keys():
                                                sweepLEDdata = {
                                                    'subject_id':
                                                    subject_id,
                                                    'session':
                                                    session,
                                                    'cell_number':
                                                    cell_number,
                                                    'sweep_number':
                                                    sweep_number,
                                                    'imaging_led_trace':
                                                    ephisdata['LED525']
                                                }
                                                try:
                                                    ephys_patch.SweepLED(
                                                    ).insert1(
                                                        sweepLEDdata,
                                                        allow_direct_insert=True
                                                    )
                                                except dj.errors.DuplicateError:
                                                    pass  #already uploaded
コード例 #24
0
def model_and_populate_a_session(subject_now, subject_id, session):
    """Simulate one foraging session with a model agent and insert it.

    A block structure (per-block left/right reward probabilities and trial
    counts) is generated with ``foraging_model.generate_block_structure``;
    the agent selected by ``subject_now`` is run on it with
    ``foraging_model.run_task``; the simulated session is then written to
    ``experiment.Session``, ``experiment.SessionBlock``,
    ``experiment.SessionTrial`` and ``experiment.BehaviorTrial``.

    Parameters
    ----------
    subject_now : str
        Name of the simulated model. One of 'leaky3t5it30h',
        'leaky3t3it30h', 'W-St-L-Sw', 'W-St-L-Rnd', 'leaky3t.05c15h',
        'leaky3t.05c5h', 'leaky3t.05c30h', 'cheater'.
    subject_id :
        Value stored in ``subject_id`` of every inserted row.
    session :
        Value stored in ``session`` of every inserted row.

    Raises
    ------
    ValueError
        If ``subject_now`` is not a known model. This is checked before
        anything is simulated or inserted, so no partial session is written.
    """
    # run_task keyword arguments that differ between models; the arguments
    # shared by every model are passed explicitly in the call below.
    model_kwargs = {
        'leaky3t5it30h': dict(subject='clever', min_rewardnum=30,
                              filter_tau_slow_amplitude=0.0,
                              softmax_temperature=5),
        'leaky3t3it30h': dict(subject='clever', min_rewardnum=30,
                              filter_tau_slow_amplitude=0,
                              softmax_temperature=3),
        'W-St-L-Sw': dict(subject='win_stay-loose_switch', min_rewardnum=3,
                          filter_tau_slow_amplitude=0.01),
        'W-St-L-Rnd': dict(subject='win_stay-loose_random', min_rewardnum=3,
                           filter_tau_slow_amplitude=0.01,
                           filter_constant=.05),
        'leaky3t.05c15h': dict(subject='clever', min_rewardnum=15,
                               filter_tau_slow_amplitude=0.0,
                               filter_constant=.05),
        'leaky3t.05c5h': dict(subject='clever', min_rewardnum=5,
                              filter_tau_slow_amplitude=0.0,
                              filter_constant=.05),
        'leaky3t.05c30h': dict(subject='clever', min_rewardnum=30,
                               filter_tau_slow_amplitude=0.0,
                               filter_constant=.05),
        'cheater': dict(subject='perfect', min_rewardnum=30,
                        filter_tau_slow_amplitude=0.0,
                        filter_constant=.05),
    }

    print('session ' + str(session))
    if subject_now not in model_kwargs:
        # Fail before any database insert: continuing would write Session and
        # SessionBlock rows and then crash on the undefined simulation result.
        print('unknown model')
        raise ValueError('unknown model: {}'.format(subject_now))

    p_reward_L, p_reward_R, n_trials = foraging_model.generate_block_structure(
        n_trials_base=80,
        n_trials_sd=10,
        blocknum=8,
        reward_ratio_pairs=np.array([[.4, .05], [.3857, .0643], [.3375, .1125],
                                     [.225, .225]]))
    rewards, choices = foraging_model.run_task(
        p_reward_L,
        p_reward_R,
        n_trials,
        unchosen_rewards_to_keep=1,
        filter_tau_fast=3,
        filter_tau_slow=100,
        plot=False,
        **model_kwargs[subject_now])

    # One timestamp so session_date and session_time cannot straddle midnight.
    now = datetime.now()
    sessiondata = {
        'subject_id': subject_id,
        'session': session,
        'session_date': now.strftime('%Y-%m-%d'),
        'session_time': now.strftime('%H:%M:%S'),
        'username': experimenter,
        'rig': setupname
    }
    experiment.Session().insert1(sessiondata)

    # --- blocks: block_start_time is the cumulative trial count before it
    trialssofar = 0
    columns = [
        'subject_id', 'session', 'block', 'block_uid', 'block_start_time',
        'p_reward_left', 'p_reward_right'
    ]
    df_sessionblockdata = pd.DataFrame(data=np.zeros(
        (len(p_reward_L), len(columns))),
                                       columns=columns)
    for blocknum, (p_L, p_R, trialnum) in enumerate(
            zip(p_reward_L, p_reward_R, n_trials), 1):
        df_sessionblockdata.loc[blocknum - 1, 'subject_id'] = subject_id
        df_sessionblockdata.loc[blocknum - 1, 'session'] = session
        df_sessionblockdata.loc[blocknum - 1, 'block'] = blocknum
        df_sessionblockdata.loc[blocknum - 1, 'block_uid'] = blocknum
        df_sessionblockdata.loc[blocknum - 1, 'block_start_time'] = trialssofar
        df_sessionblockdata.loc[blocknum - 1, 'p_reward_left'] = p_L
        df_sessionblockdata.loc[blocknum - 1, 'p_reward_right'] = p_R
        trialssofar += trialnum
    experiment.SessionBlock().insert(
        df_sessionblockdata.to_records(index=False), allow_direct_insert=True)

    # --- trials
    columns_sessiontrial = [
        'subject_id', 'session', 'trial', 'trial_uid', 'trial_start_time',
        'trial_stop_time'
    ]
    df_sessiontrialdata = pd.DataFrame(data=np.zeros(
        (len(rewards), len(columns_sessiontrial))),
                                       columns=columns_sessiontrial)
    columns_behaviortrial = [
        'subject_id', 'session', 'trial', 'task', 'task_protocol',
        'trial_choice', 'early_lick', 'outcome', 'block'
    ]
    df_behaviortrialdata = pd.DataFrame(data=np.zeros(
        (len(rewards), len(columns_behaviortrial))),
                                        columns=columns_behaviortrial)
    task = 'foraging'
    task_protocol = 10
    # Last trial number (1-based) of each block; hoisted out of the loop.
    block_ends = np.cumsum(n_trials)
    for trialnum, (reward, choice) in enumerate(zip(rewards, choices), 1):
        df_sessiontrialdata.loc[trialnum - 1, 'subject_id'] = subject_id
        df_sessiontrialdata.loc[trialnum - 1, 'session'] = session
        df_sessiontrialdata.loc[trialnum - 1, 'trial'] = trialnum
        df_sessiontrialdata.loc[trialnum - 1, 'trial_uid'] = trialnum
        # synthetic timing: each trial spans (n - .9, n - .1)
        df_sessiontrialdata.loc[trialnum - 1,
                                'trial_start_time'] = trialnum - .9
        df_sessiontrialdata.loc[trialnum - 1,
                                'trial_stop_time'] = trialnum - .1

        outcome = 'hit' if reward else 'miss'
        trial_choice = 'right' if choice == 1 else 'left'
        df_behaviortrialdata.loc[trialnum - 1, 'subject_id'] = subject_id
        df_behaviortrialdata.loc[trialnum - 1, 'session'] = session
        df_behaviortrialdata.loc[trialnum - 1, 'trial'] = trialnum
        df_behaviortrialdata.loc[trialnum - 1, 'task'] = task
        df_behaviortrialdata.loc[trialnum - 1, 'task_protocol'] = task_protocol
        df_behaviortrialdata.loc[trialnum - 1, 'trial_choice'] = trial_choice
        df_behaviortrialdata.loc[trialnum - 1, 'early_lick'] = 'no early'
        df_behaviortrialdata.loc[trialnum - 1, 'outcome'] = outcome
        # first block whose cumulative end reaches this trial (1-based)
        df_behaviortrialdata.loc[trialnum - 1, 'block'] = np.argmax(
            block_ends >= trialnum) + 1
    experiment.SessionTrial().insert(
        df_sessiontrialdata.to_records(index=False), allow_direct_insert=True)
    experiment.BehaviorTrial().insert(
        df_behaviortrialdata.to_records(index=False), allow_direct_insert=True)
コード例 #25
0
    def make(self, key):
        '''
        TrackingIngest .make() function

        Ingest per-trial tracking data for the session identified by ``key``.

        For every (tracking device, tracking path) pair this looks for a
        ``<root>/<h2o>/YYYYMMDD/tracking`` directory holding a camera
        file->trial map ``<h2o>_YYYYMMDD_<tracking_position>.txt`` (read with
        ``self.load_campath``) and per-file CSVs
        ``<h2o>_<tracking_position>_<file>-*.csv`` (read with
        ``self.load_tracking``). Each matched file is inserted into
        ``tracking.Tracking`` plus whichever feature part tables (nose,
        tongue, jaw, left/right paw) it has data for. Finally ``key`` is
        inserted into this table and the ingested files into
        ``self.TrackingFile``.

        Missing directories/maps, files mapped to trials not in the session,
        and missing or ambiguous CSVs are logged and skipped.
        '''
        log.info('TrackingIngest().make(): key: {k}'.format(k=key))

        h2o = (lab.WaterRestriction() & key).fetch1('water_restriction_number')
        session = (experiment.Session() & key).fetch1()
        trials = (experiment.SessionTrial() & session).fetch('trial')

        log.info('got session: {} ({} trials)'.format(session, len(trials)))

        # Set for O(1) membership tests; ``x in ndarray`` scans the array.
        known_trials = set(trials)

        sdate = session['session_date']
        sdate_sml = "{}{:02d}{:02d}".format(sdate.year, sdate.month, sdate.day)

        paths = get_tracking_paths()
        devices = tracking.TrackingDevice().fetch(as_dict=True)

        # tracking-file attribute names -> paw part-table attribute names
        left_paw_fmap = {
            'paw_left_x': 'left_paw_x',
            'paw_left_y': 'left_paw_y',
            'paw_left_likelihood': 'left_paw_likelihood'
        }
        right_paw_fmap = {
            'paw_right_x': 'right_paw_x',
            'paw_right_y': 'right_paw_y',
            'paw_right_likelihood': 'right_paw_likelihood'
        }

        def remap(rec, fmap):
            # copy of rec with keys renamed per fmap (unmapped keys first)
            return {
                **{k: v for k, v in rec.items() if k not in fmap},
                **{fmap[k]: v for k, v in rec.items() if k in fmap}
            }

        # paths like: <root>/<h2o>/YYYYMMDD/tracking
        tracking_files = []
        for p, d in ((p, d) for d in devices for p in paths):

            tdev = d['tracking_device']
            tpos = d['tracking_position']
            tdat = p[-1]  # root dir; also the base for stored relative paths

            log.info('checking {} for tracking data'.format(tdat))

            tpath = pathlib.Path(tdat, h2o, sdate_sml, 'tracking')

            if not tpath.exists():
                log.warning('tracking path {} n/a - skipping'.format(tpath))
                continue

            camtrial = '{}_{}_{}.txt'.format(h2o, sdate_sml, tpos)
            campath = tpath / camtrial

            log.info('trying camera position trial map: {}'.format(campath))

            if not campath.exists():
                log.info('skipping {} - does not exist'.format(campath))
                continue

            tmap = self.load_campath(campath)  # file:trial

            n_tmap = len(tmap)
            log.info('loading tracking data for {} trials'.format(n_tmap))

            i = 0
            for t in tmap:  # load tracking for trial
                if tmap[t] not in known_trials:
                    log.warning('nonexistant trial {}.. skipping'.format(t))
                    continue

                i += 1
                if i % 50 == 0:
                    log.info('item {}/{}, trial #{} ({:.2f}%)'.format(
                        i, n_tmap, t, (i / n_tmap) * 100))
                else:
                    log.debug('item {}/{}, trial #{} ({:.2f}%)'.format(
                        i, n_tmap, t, (i / n_tmap) * 100))

                # ex: dl59_side_1-0000.csv / h2o_position_tn-0000.csv
                tfile = '{}_{}_{}-*.csv'.format(h2o, tpos, t)
                tfull = list(tpath.glob(tfile))

                # exactly one CSV must match; otherwise skip this file
                if not tfull or len(tfull) > 1:
                    log.info('file mismatch: file: {} trial: {} ({})'.format(
                        t, tmap[t], tfull))
                    continue

                tfull = tfull[-1]
                trk = self.load_tracking(tfull)

                # one record per feature; 'samples' only yields a sample count
                recs = {}
                rec_base = dict(key, trial=tmap[t], tracking_device=tdev)

                for k in trk:
                    if k == 'samples':
                        recs['tracking'] = {
                            **rec_base,
                            'tracking_samples':
                            len(trk['samples']['ts']),
                        }
                    else:
                        rec = dict(rec_base)

                        for attr in trk[k]:
                            rec_key = '{}_{}'.format(k, attr)
                            rec[rec_key] = np.array(trk[k][attr])

                        recs[k] = rec

                tracking.Tracking.insert1(recs['tracking'],
                                          allow_direct_insert=True)

                if 'nose' in recs:
                    tracking.Tracking.NoseTracking.insert1(
                        recs['nose'], allow_direct_insert=True)

                if 'tongue' in recs:
                    tracking.Tracking.TongueTracking.insert1(
                        recs['tongue'], allow_direct_insert=True)

                if 'jaw' in recs:
                    tracking.Tracking.JawTracking.insert1(
                        recs['jaw'], allow_direct_insert=True)

                if 'paw_left' in recs:
                    tracking.Tracking.LeftPawTracking.insert1(
                        remap(recs['paw_left'], left_paw_fmap),
                        allow_direct_insert=True)

                if 'paw_right' in recs:
                    tracking.Tracking.RightPawTracking.insert1(
                        remap(recs['paw_right'], right_paw_fmap),
                        allow_direct_insert=True)

                tracking_files.append({
                    **key,
                    'trial': tmap[t],
                    'tracking_device': tdev,
                    'tracking_file': str(tfull.relative_to(tdat))
                })

            log.info('... completed {}/{} items.'.format(i, n_tmap))

        self.insert1(key)
        self.TrackingFile.insert(tracking_files)

        log.info('... done.')
コード例 #26
0
# Re-fit the F0 baseline of a ground-truth VolPy ROI with a double
# exponential and recompute dF/F against that fitted baseline.

# Cell / movie to analyse. Alternative kept for reference (it was assigned
# and immediately overwritten, so it never had an effect):
# key = {'subject_id': 466771, 'movie_number': 1}
key = {'subject_id': 456462, 'movie_number': 3}

# Expand to the full CellMovieCorrespondance row and restrict to VolPy ROIs;
# sweep_numbers is kept aside and dropped from the key before the key is
# used as a restriction below.
key = (imaging_gt.CellMovieCorrespondance() & key).fetch(as_dict=True)[0]
key['roi_type'] = "VolPy"
sweep_numbers = key['sweep_numbers']
del key['sweep_numbers']

#% motion correction
motion_corr_vectors = np.asarray(
    (imaging.MotionCorrection() * imaging.RegisteredMovie() & key
     & 'motion_correction_method  = "VolPy"'
     & 'motion_corr_description= "rigid motion correction done with VolPy"'
     ).fetch1('motion_corr_vectors'))

#% ophys related stuff
session_time, cell_recording_start = (
    experiment.Session() * ephys_patch.Cell() & key
).fetch1('session_time', 'cell_recording_start')
first_movie_start_time = np.min(np.asarray(
    ((imaging.Movie() * imaging_gt.GroundTruthROI()) & key).fetch('movie_start_time'),
    float))
first_movie_start_time_real = first_movie_start_time + session_time.total_seconds()
# frame_times - cell_recording_start + session_time
# (presumably cell-relative -> session clock; confirm)
frame_times = ((imaging.MovieFrameTimes() & key).fetch1('frame_times')
               - cell_recording_start.total_seconds()
               + session_time.total_seconds())
roi_dff, roi_f0, roi_spike_indices, framerate = (
    imaging.Movie() * imaging.ROI * imaging_gt.GroundTruthROI() & key
).fetch1('roi_dff', 'roi_f0', 'roi_spike_indices', 'movie_frame_rate')
roi_spike_indices = roi_spike_indices - 1  # presumably 1- to 0-based; confirm
roi_f = (roi_dff * roi_f0) + roi_f0  # F0 * (dF/F + 1), i.e. raw fluorescence


def _f0_double_exponential(t, a, b, c, d, e):
    """Baseline model a*exp(-t/b) + c + d*exp(-t/e)."""
    return a * np.exp(-t / b) + c + d * np.exp(-t / e)


# Fit the baseline against time since the first frame. The same function is
# used for fitting and for evaluating the fit, so the two cannot diverge.
xvals = frame_times - frame_times[0]
yvals = roi_f
out = scipy.optimize.curve_fit(
    _f0_double_exponential, xvals, yvals,
    bounds=((0, 0, -np.inf, 0, 0),
            (np.inf, np.inf, np.inf, np.inf, np.inf)))
f0_fit_f = _f0_double_exponential(xvals, *out[0])
dff_fit = (roi_f - f0_fit_f) / f0_fit_f


コード例 #27
0
ファイル: publication.py プロジェクト: nwtien/map-ephys
    def make(self, key):
        '''
        determine available files from local endpoint and publish
        (create database records and transfer to globus)
        '''

        # >>> list(key.keys())
        # ['subject_id', 'session', 'trial', 'electrode_group', 'globus_alias']

        log.debug(key)
        lep, lep_sub, lep_dir = GlobusStorageLocation().local_endpoint
        log.info('local_endpoint: {}:{} -> {}'.format(lep, lep_sub, lep_dir))

        # get session related information needed for filenames/records
        sinfo = ((lab.WaterRestriction
                  * lab.Subject.proj()
                  * experiment.Session()
                  * experiment.SessionTrial) & key).fetch1()

        h2o = sinfo['water_restriction_number']
        sdate = sinfo['session_date']
        eg = key['electrode_group']
        trial = key['trial']

        # build file locations:
        # subdir - common subdirectory for globus/native filesystem
        # fpat: base file pattern for this sessions files
        # fbase: filesystem base path for this sessions files
        # gbase: globus-url base path for this sessions files

        subdir = os.path.join(h2o, str(sdate), str(eg))
        fpat = '{}_{}_{}_g0_t{}'.format(h2o, sdate, eg, trial)
        fbase = os.path.join(lep_dir, subdir, fpat)
        gbase = '/'.join((h2o, str(sdate), str(eg), fpat))

        # check for existence of actual files & use to build xfer list
        log.debug('checking {}'.format(fbase))

        ffound = []
        ftypes = RawEphysFileTypes.contents
        for ft in ftypes:
            fname = '{}{}'.format(fbase, ft[1])
            gname = '{}{}'.format(gbase, ft[1])
            if not os.path.exists(fname):
                log.debug('... {}: not found'.format(fname))
                continue

            log.debug('... {}: found'.format(fname))
            ffound.append((ft, gname,))

        # if files are found, transfer and create publication schema records

        if not len(ffound):
            log.info('no files found for key')
            return

        log.info('found files for key: {}'.format([f[1] for f in ffound]))

        repname, rep, rep_sub = (GlobusStorageLocation() & key).fetch()[0]

        gsm = self.get_gsm()
        gsm.activate_endpoint(lep)  # XXX: cache this / prevent duplicate RPC?
        gsm.activate_endpoint(rep)  # XXX: cache this / prevent duplicate RPC?

        if not ArchivedRawEphysTrial & key:
            log.info('ArchivedRawEphysTrial.insert1()')
            ArchivedRawEphysTrial.insert1(key)

        ftmap = {'ap-30kHz': ArchivedRawEphysTrial.ArchivedApChannel,
                 'ap-30kHz-meta': ArchivedRawEphysTrial.ArchivedApMeta,
                 'lf-2.5kHz': ArchivedRawEphysTrial.ArchivedLfChannel,
                 'lf-2.5kHz-meta': ArchivedRawEphysTrial.ArchivedLfMeta}

        for ft, gname in ffound:  # XXX: transfer/insert could be batched
            ft_class = ftmap[ft[0]]
            if not ft_class & key:
                srcp = '{}:/{}/{}'.format(lep, lep_sub, gname)
                dstp = '{}:/{}/{}'.format(rep, rep_sub, gname)

                log.info('transferring {} to {}'.format(srcp, dstp))

                # XXX: check if exists 1st? (manually or via API copy-checksum)
                if not gsm.cp(srcp, dstp):
                    emsg = "couldn't transfer {} to {}".format(srcp, dstp)
                    log.error(emsg)
                    raise dj.DataJointError(emsg)

                log.info('ArchivedRawEphysTrial.{}.insert1()'
                         .format(ft_class.__name__))

                ft_class.insert1(key)