コード例 #1
0
    def make(self, key):
        """Plot lick rasters and lick / reaction-time histograms for a session.

        Builds a three-panel figure for the session identified by ``key``:

        1. (ax1) a raster of left/right licks for every trial, aligned to the
           go cue (x = 0); the trial start of each trial is drawn in black.
        2. (ax2) histograms (10-ms bins) of all left/right lick times and of
           the trial start times, relative to the go cue.
        3. (ax3) histograms (10-ms bins) of the reaction time (first lick
           after the go cue) fetched from ``foraging_analysis.TrialStats``,
           split by water port ('LEFT' / 'RIGHT').

        All three panels share the x axis. The figure is written under
        ``store_stage / <water_res_num> / <sess_date>`` via ``save_figs`` and
        the returned dict is inserted together with ``key``.

        Args:
            key: primary key being populated; must contain ``'session'``
                (used in the title and to look up the session).
        """
        water_res_num, sess_date = get_wr_sessdate(key)
        sess_dir = store_stage / water_res_num / sess_date
        sess_dir.mkdir(parents=True, exist_ok=True)

        # -- Plotting --
        fig = plt.figure(figsize=(20, 12))
        fig.suptitle(f'{water_res_num}, session {key["session"]}')
        gs = GridSpec(5,
                      1,
                      wspace=0.4,
                      hspace=0.4,
                      bottom=0.07,
                      top=0.95,
                      left=0.1,
                      right=0.9)

        # Share the x axis via ``add_subplot(sharex=...)``:
        # ``get_shared_x_axes().join(...)`` is deprecated since Matplotlib 3.6
        # and removed in 3.8.
        ax1 = fig.add_subplot(gs[0:3, :])
        ax2 = fig.add_subplot(gs[3, :], sharex=ax1)
        ax3 = fig.add_subplot(gs[4, :], sharex=ax1)

        # Plot settings: colour per lick type
        plot_setting = {'left lick': 'red', 'right lick': 'blue'}

        # -- Get event times --
        key_subject_id_session = (
            experiment.Session() &
            (lab.WaterRestriction()
             & 'water_restriction_number="{}"'.format(water_res_num))
            & 'session="{}"'.format(key['session'])).fetch1("KEY")
        go_cue_times = (experiment.TrialEvent() & key_subject_id_session
                        & 'trial_event_type="go"').fetch(
                            'trial_event_time', order_by='trial').astype(float)
        lick_times = pd.DataFrame(
            (experiment.ActionEvent()
             & key_subject_id_session).fetch(order_by='trial'))

        trial_num = len(go_cue_times)
        all_trial_num = np.arange(1, trial_num + 1).tolist()
        # One single-element list per trial: trial start relative to go cue.
        all_trial_start = [[-x] for x in go_cue_times]
        all_lick = dict()
        for event_type in plot_setting:
            # Filter by lick type once instead of once per trial.
            this_type = lick_times[lick_times['action_event_type'] == event_type]
            # Lick times (relative to trial start) shifted to go-cue alignment;
            # trials are numbered from 1 in the same order as go_cue_times.
            all_lick[event_type] = [
                (this_type[this_type['trial'] == trial_idx]
                 ['action_event_time'].values.astype(float)
                 + trial_start).tolist()
                for trial_idx, trial_start in enumerate(all_trial_start,
                                                        start=1)
            ]

        # -- All licking events (Ordered by trials) --
        ax1.plot([0, 0], [0, trial_num], 'k', lw=0.5)  # Go cue at x = 0
        ax1.set(ylabel='Trial number', xlim=(-3, 3))
        # Hide x ticks with tick_params rather than ``xticks=[]``: shared axes
        # share their tick locator, so ``xticks=[]`` would also remove the
        # ticks from the bottom panel (ax3).
        ax1.tick_params(axis='x', bottom=False, labelbottom=False)

        # Batch plotting to speed up
        ax1.eventplot(lineoffsets=all_trial_num,
                      positions=all_trial_start,
                      color='k')  # Trial start, aligned to go cue
        for event_type in plot_setting:
            ax1.eventplot(lineoffsets=all_trial_num,
                          positions=all_lick[event_type],
                          color=plot_setting[event_type],
                          linewidth=2)  # Licks, aligned to go cue

        # -- Histogram of all licks --
        for event_type in plot_setting:
            sns.histplot(np.hstack(all_lick[event_type]),
                         binwidth=0.01,  # 10-ms bins
                         alpha=0.5,
                         ax=ax2,
                         color=plot_setting[event_type],
                         label=event_type)

        # Remember the y range of the lick histograms so that adding the
        # trial-start histogram does not rescale them.
        ymax_tmp = max(ax2.get_ylim())
        sns.histplot(-go_cue_times,
                     binwidth=0.01,  # 10-ms bins
                     color='k',
                     ax=ax2,
                     label='trial start')
        ax2.axvline(x=0, color='k', lw=0.5)
        ax2.set(ylim=(0, ymax_tmp), title='All events')
        ax2.tick_params(axis='x', bottom=False, labelbottom=False)
        ax2.legend()

        # -- Histogram of reaction time (first lick after go cue) --
        port_colors = {'LEFT': 'red', 'RIGHT': 'blue'}
        for water_port, color in port_colors.items():
            this_RT = (foraging_analysis.TrialStats() & key_subject_id_session
                       & (experiment.WaterPortChoice()
                          & 'water_port="{}"'.format(water_port))
                       ).fetch('reaction_time').astype(float)
            sns.histplot(this_RT,
                         binwidth=0.01,  # 10-ms bins
                         alpha=0.5,
                         ax=ax3,
                         color=color,
                         label=water_port)
        ax3.axvline(x=0, color='k', lw=0.5)
        ax3.set(xlabel='Time to Go Cue (s)',
                title='First lick (reaction time)')
        ax3.legend()

        # ---- Save fig and insert ----
        fn_prefix = f'{water_res_num}_{sess_date}_'
        fig_dict = save_figs((fig, ), ('session_foraging_licking_psth', ),
                             sess_dir, fn_prefix)
        plt.close('all')
        self.insert1({**key, **fig_dict})
コード例 #2
0
    def make(self, key):
        """Ingest one behavior ``.mat`` file as a new experiment session.

        Loads ``<rig_data_path>/<subpath>`` (a MATLAB file holding a
        ``SessionData`` struct), parses every trial and inserts:

        - an ``experiment.Session`` record whose ``session`` number is
          synthesized as (max existing session of the subject, or 0) + 1;
        - this table's own record and its ``BehaviorFile`` part;
        - ``SessionTrial``, ``BehaviorTrial``, ``TrialNote``, ``TrialEvent``,
          ``ActionEvent`` (licks) and ``CorrectedTrialEvents`` rows;
        - ``Photostim`` / ``PhotostimLocation`` / ``PhotostimTrial`` /
          ``PhotostimEvent`` rows for photostimulation trials.

        Files smaller than 1000 KiB are skipped without inserting anything.
        Individual trials are skipped (with a warning) when a required state
        is missing, the protocol is undefined or < 3, or the PreSamplePeriod /
        StopLicking states never occur.

        Args:
            key: must contain ``subject_id``, ``session_date`` (a date
                object), ``rig``, ``rig_data_path`` and ``subpath``.

        Raises:
            ValueError: if a trial references a photostim protocol that is
                not a key of the module-level ``photostims`` mapping.
            AssertionError: if the per-trial arrays in the file differ in
                length, or the file's session date differs from
                ``key['session_date']``.
        """
        log.info('BehaviorIngest.make(): key: {key}'.format(key=key))

        subject_id = key['subject_id']
        h2o = (lab.WaterRestriction() & {
            'subject_id': subject_id
        }).fetch1('water_restriction_number')

        ymd = key['session_date']
        datestr = ymd.strftime('%Y%m%d')
        log.info('h2o: {h2o}, date: {d}'.format(h2o=h2o, d=datestr))

        # session record key
        skey = {}
        skey['subject_id'] = subject_id
        skey['session_date'] = ymd
        skey['username'] = self.get_session_user()
        skey['rig'] = key['rig']

        # File paths conform to the pattern:
        # dl7/TW_autoTrain/Session Data/dl7_TW_autoTrain_20180104_132813.mat
        # which is, more generally:
        # {h2o}/{training_protocol}/Session Data/{h2o}_{training protocol}_{YYYYMMDD}_{HHMMSS}.mat

        path = pathlib.Path(key['rig_data_path'], key['subpath'])

        # Informational only: a new session number is synthesized below, so
        # an existing session on the same date does not stop ingestion.
        if experiment.Session() & skey:
            log.info("note: session exists for {h2o} on {d}".format(h2o=h2o,
                                                                    d=ymd))

        trial = namedtuple(  # simple structure to track per-trial vars
            'trial',
            ('ttype', 'stim', 'free', 'settings', 'state_times', 'state_names',
             'state_data', 'event_data', 'event_times', 'trial_start'))

        # Skip files smaller than 1000 KiB.
        if os.stat(path).st_size / 1024 < 1000:
            log.info('skipping file {} - too small'.format(path))
            return

        log.debug('loading file {}'.format(path))

        mat = spio.loadmat(path, squeeze_me=True)
        SessionData = mat['SessionData'].flatten()

        # parse session datetime, e.g. '04-Jan-2018 13:28:13'
        session_datetime_str = ' '.join(
            (str(SessionData['Info'][0]['SessionDate']),
             str(SessionData['Info'][0]['SessionStartTime_UTC'])))

        session_datetime = datetime.strptime(session_datetime_str,
                                             '%d-%b-%Y %H:%M:%S')

        AllTrialTypes = SessionData['TrialTypes'][0]
        AllTrialSettings = SessionData['TrialSettings'][0]
        AllTrialStarts = SessionData['TrialStartTimestamp'][0]
        # Make trial start times relative to the first trial.
        AllTrialStarts = AllTrialStarts - AllTrialStarts[0]

        RawData = SessionData['RawData'][0].flatten()
        AllStateNames = RawData['OriginalStateNamesByNumber'][0]
        AllStateData = RawData['OriginalStateData'][0]
        AllEventData = RawData['OriginalEventData'][0]
        AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
        AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

        n_trials = AllStateTimestamps.shape[0]

        # verify trial-related data arrays are all same length
        assert all(x.shape[0] == n_trials
                   for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                             AllStateData, AllEventData, AllEventTimestamps,
                             AllTrialStarts))

        # AllStimTrials optional special case
        if 'StimTrials' in SessionData.dtype.fields:
            log.debug('StimTrials detected in session - will include')
            AllStimTrials = SessionData['StimTrials'][0]
            assert AllStimTrials.shape[0] == n_trials
        else:
            log.debug('StimTrials not detected in session - will skip')
            AllStimTrials = np.array([None] * n_trials)

        # AllFreeTrials optional special case
        if 'FreeTrials' in SessionData.dtype.fields:
            log.debug('FreeTrials detected in session - will include')
            AllFreeTrials = SessionData['FreeTrials'][0]
            assert AllFreeTrials.shape[0] == n_trials
        else:
            log.debug('FreeTrials not detected in session - synthesizing')
            AllFreeTrials = np.zeros(n_trials, dtype=np.uint8)

        trials = list(
            zip(AllTrialTypes, AllStimTrials, AllFreeTrials, AllTrialSettings,
                AllStateTimestamps, AllStateNames, AllStateData, AllEventData,
                AllEventTimestamps, AllTrialStarts))

        if not trials:
            log.warning('skipping date {d}, no valid files'.format(d=datestr))
            return

        #
        # Trial data seems valid; synthesize session id & add session record
        # XXX: note - later breaks can result in Sessions without valid trials
        #

        assert skey['session_date'] == session_datetime.date()

        skey['session_date'] = session_datetime.date()
        skey['session_time'] = session_datetime.time()

        log.debug('synthesizing session ID')
        # Next session number for this subject (1 when none exist yet).
        session = (dj.U().aggr(experiment.Session()
                               & {'subject_id': subject_id},
                               n='max(session)').fetch1('n') or 0) + 1

        log.info('generated session id: {session}'.format(session=session))
        skey['session'] = session
        key = dict(key, **skey)

        #
        # Actually load the per-trial data
        #
        log.info('BehaviorIngest.make(): trial parsing phase')

        # lists of various records for batch-insert
        rows = {
            k: list()
            for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                      'corrected_trial_event', 'action_event', 'photostim',
                      'photostim_location', 'photostim_trial',
                      'photostim_trial_event')
        }

        # Trial numbering starts at 1; skipped trials still consume a number.
        for i, t in enumerate(trials, start=1):

            #
            # Misc
            #

            t = trial(*t)  # convert list of items to a 'trial' structure

            log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

            # convert state names into a lookup dictionary name -> number;
            # numbers are 1-based and state_data is compared against them.
            #
            # names (seem to be? are?):
            #
            # Trigtrialstart, PreSamplePeriod, SamplePeriod, DelayPeriod
            # EarlyLickDelay, EarlyLickSample, ResponseCue, GiveLeftDrop
            # GiveRightDrop, GiveLeftDropShort, GiveRightDropShort
            # AnswerPeriod, Reward, RewardConsumption, NoResponse
            # TimeOut, StopLicking, StopLickingReturn, TrialEnd
            #

            states = {k: (v + 1) for v, k in enumerate(t.state_names)}
            required_states = ('PreSamplePeriod', 'SamplePeriod',
                               'DelayPeriod', 'ResponseCue', 'StopLicking',
                               'TrialEnd')

            missing = [k for k in required_states if k not in states]

            if missing:
                log.warning('skipping trial {i}; missing {m}'.format(
                    i=i, m=missing))
                continue

            gui = t.settings['GUI'].flatten()

            # ProtocolType - only ingest protocol >= 3
            #
            # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
            # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
            #

            if 'ProtocolType' not in gui.dtype.names:
                log.warning(
                    'skipping trial {i}; protocol undefined'.format(i=i))
                continue

            protocol_type = gui['ProtocolType'][0]
            if protocol_type < 3:
                log.warning('skipping trial {i}; protocol {n} < 3'.format(
                    i=i, n=protocol_type))
                continue

            #
            # Top-level 'Trial' record
            #

            tkey = dict(skey)
            startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

            # should be only end of 1st StopLicking;
            # rest of data is irrelevant w/r/t separately ingested ephys
            endindex = np.where(t.state_data == states['StopLicking'])[0]

            log.debug('states\n' + str(states))
            log.debug('state_data\n' + str(t.state_data))
            log.debug('startindex\n' + str(startindex))
            log.debug('endindex\n' + str(endindex))

            if not (len(startindex) and len(endindex)):
                log.warning('skipping {}: start/end mismatch: {}/{}'.format(
                    i, str(startindex), str(endindex)))
                continue

            try:
                tkey['trial'] = i
                tkey['trial_uid'] = i
                tkey['start_time'] = t.trial_start
                tkey['stop_time'] = t.trial_start + t.state_times[endindex][0]
            except IndexError:
                log.warning('skipping {}: IndexError: {}/{} -> {}'.format(
                    i, str(startindex), str(endindex), str(t.state_times)))
                continue

            log.debug('tkey' + str(tkey))
            rows['trial'].append(tkey)

            #
            # Specific BehaviorTrial information for this trial
            #

            bkey = dict(tkey)
            bkey['task'] = 'audio delay'  # hard-coded here
            bkey['task_protocol'] = 1  # hard-coded here

            # determine trial instruction; 'Reversal' (1 or 2) decides which
            # trial type maps to which side. Defaults to 'left'.
            trial_instruction = 'left'  # hard-coded here

            if gui['Reversal'][0] == 1:
                if t.ttype == 1:
                    trial_instruction = 'left'
                elif t.ttype == 0:
                    trial_instruction = 'right'
            elif gui['Reversal'][0] == 2:
                if t.ttype == 1:
                    trial_instruction = 'right'
                elif t.ttype == 0:
                    trial_instruction = 'left'

            bkey['trial_instruction'] = trial_instruction

            # determine early lick (only for protocols >= 5)
            early_lick = 'no early'

            if (protocol_type >= 5 and 'EarlyLickDelay' in states
                    and np.any(t.state_data == states['EarlyLickDelay'])):
                early_lick = 'early'
            if (protocol_type >= 5 and 'EarlyLickSample' in states
                    and np.any(t.state_data == states['EarlyLickSample'])):
                early_lick = 'early'

            bkey['early_lick'] = early_lick

            # determine outcome: first matching state wins, default 'ignore'
            outcome = 'ignore'

            if ('Reward' in states
                    and np.any(t.state_data == states['Reward'])):
                outcome = 'hit'
            elif ('TimeOut' in states
                  and np.any(t.state_data == states['TimeOut'])):
                outcome = 'miss'
            elif ('NoResponse' in states
                  and np.any(t.state_data == states['NoResponse'])):
                outcome = 'ignore'

            bkey['outcome'] = outcome

            # Determine free/autowater (Autowater 1 == enabled, 2 == disabled)
            bkey['auto_water'] = bool(gui['Autowater'][0] == 1)
            bkey['free_water'] = t.free

            rows['behavior_trial'].append(bkey)

            #
            # Add 'protocol' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'protocol #'
            nkey['trial_note'] = str(protocol_type)
            rows['trial_note'].append(nkey)

            #
            # Add 'autolearn' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'autolearn'
            nkey['trial_note'] = str(gui['Autolearn'][0])
            rows['trial_note'].append(nkey)

            #
            # Add 'bitcode' note
            #
            if 'randomID' in gui.dtype.names:
                nkey = dict(tkey)
                nkey['trial_note_type'] = 'bitcode'
                nkey['trial_note'] = str(gui['randomID'][0])
                rows['trial_note'].append(nkey)

            #
            # Add presample event
            #
            log.debug('BehaviorIngest.make(): presample')

            ekey = dict(tkey)
            sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

            # trial_event_id is a running index over the whole session
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'presample'
            ekey['trial_event_time'] = t.state_times[startindex][0]
            ekey['duration'] = (t.state_times[sampleindex[0]] -
                                t.state_times[startindex])[0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing presample duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial

            rows['trial_event'].append(ekey)

            #
            # Add other 'sample' events
            #

            log.debug('BehaviorIngest.make(): sample events')

            last_dur = None

            for s in sampleindex:  # in protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'sample'
                ekey['trial_event_time'] = t.state_times[s]
                ekey['duration'] = gui['SamplePeriod'][0]

                # NaN durations are replaced (0.0 or last good value) and
                # also recorded as corrected events.
                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning(
                        '... trial {} bad duration, no last_edur'.format(i))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning(
                        '... trial {} duration using last_edur {}'.format(
                            i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                rows['trial_event'].append(ekey)

            #
            # Add 'delay' events
            #

            log.debug('BehaviorIngest.make(): delay events')

            last_dur = None
            delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

            for d in delayindex:  # protocol > 6 ~-> n>1
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'delay'
                ekey['trial_event_time'] = t.state_times[d]
                ekey['duration'] = gui['DelayPeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning('... {} bad duration, no last_edur'.format(i))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning('... {} duration using last_edur {}'.format(
                        i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                log.debug('delay event duration: {}'.format(ekey['duration']))
                rows['trial_event'].append(ekey)

            #
            # Add 'go' event
            #
            log.debug('BehaviorIngest.make(): go')

            ekey = dict(tkey)
            responseindex = np.where(t.state_data == states['ResponseCue'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'go'
            ekey['trial_event_time'] = t.state_times[responseindex][0]
            ekey['duration'] = gui['AnswerPeriod'][0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing go duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
                rows['corrected_trial_event'].append(ekey)

            rows['trial_event'].append(ekey)

            #
            # Add 'trialEnd' events
            #

            log.debug('BehaviorIngest.make(): trialend events')

            trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'trialend'
            ekey['trial_event_time'] = t.state_times[trialendindex][0]
            ekey['duration'] = 0.0

            rows['trial_event'].append(ekey)

            #
            # Add lick events (event code 69 -> 'left lick',
            # 71 -> 'right lick'); action_event_id is a running index over
            # the whole session.
            #

            lickleft = np.where(t.event_data == 69)[0]
            log.debug('... lickleft: {r}'.format(r=str(lickleft)))

            action_event_count = len(rows['action_event'])
            rows['action_event'].extend(
                dict(tkey,
                     action_event_id=action_event_count + idx,
                     action_event_type='left lick',
                     action_event_time=t.event_times[ev])
                for idx, ev in enumerate(lickleft))

            lickright = np.where(t.event_data == 71)[0]
            log.debug('... lickright: {r}'.format(r=str(lickright)))

            action_event_count = len(rows['action_event'])
            rows['action_event'].extend(
                dict(tkey,
                     action_event_id=action_event_count + idx,
                     action_event_type='right lick',
                     action_event_time=t.event_times[ev])
                for idx, ev in enumerate(lickright))

            #
            # Photostim Events: one event at the first DelayPeriod state
            #

            if t.stim:
                log.debug('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
                rows['photostim_trial'].append(tkey)
                delay_period_idx = np.where(
                    t.state_data == states['DelayPeriod'])[0][0]
                rows['photostim_trial_event'].append(
                    dict(tkey,
                         photo_stim=t.stim,
                         photostim_event_id=len(rows['photostim_trial_event']),
                         photostim_event_time=t.state_times[delay_period_idx],
                         power=5.5))

            # end of trial loop.

        # Session Insertion

        log.info('BehaviorIngest.make(): adding session record')
        experiment.Session().insert1(skey)

        # Behavior Insertion

        log.info('BehaviorIngest.make(): bulk insert phase')

        log.info('BehaviorIngest.make(): saving ingest {d}'.format(d=key))
        self.insert1(key, ignore_extra_fields=True, allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
        experiment.SessionTrial().insert(rows['trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
        experiment.BehaviorTrial().insert(rows['behavior_trial'],
                                          ignore_extra_fields=True,
                                          allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialNote')
        experiment.TrialNote().insert(rows['trial_note'],
                                      ignore_extra_fields=True,
                                      allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
        experiment.TrialEvent().insert(rows['trial_event'],
                                       ignore_extra_fields=True,
                                       allow_direct_insert=True,
                                       skip_duplicates=True)

        log.info('BehaviorIngest.make(): ... CorrectedTrialEvents')
        BehaviorIngest().CorrectedTrialEvents().insert(
            rows['corrected_trial_event'],
            ignore_extra_fields=True,
            allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
        experiment.ActionEvent().insert(rows['action_event'],
                                        ignore_extra_fields=True,
                                        allow_direct_insert=True)

        # Photostim Insertion

        photostim_ids = np.unique(
            [r['photo_stim'] for r in rows['photostim_trial_event']])

        unknown_photostims = np.setdiff1d(photostim_ids,
                                          list(photostims.keys()))

        # Test .size explicitly: array truthiness is ambiguous for >1
        # element, False for a single id of 0, and an error for empty
        # arrays in recent NumPy.
        if unknown_photostims.size:
            raise ValueError(
                'Unknown photostim protocol: {}'.format(unknown_photostims))

        if photostim_ids.size > 0:
            log.info('BehaviorIngest.make(): ... experiment.Photostim')
            for stim in photostim_ids:
                experiment.Photostim.insert1(dict(skey, **photostims[stim]),
                                             ignore_extra_fields=True)

                experiment.Photostim.PhotostimLocation.insert(
                    (dict(
                        skey, **loc, photo_stim=photostims[stim]['photo_stim'])
                     for loc in photostims[stim]['locations']),
                    ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrial')
        experiment.PhotostimTrial.insert(rows['photostim_trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrialEvent')
        experiment.PhotostimEvent.insert(rows['photostim_trial_event'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        # Behavior Ingest Insertion

        log.info('BehaviorIngest.make(): ... BehaviorIngest.BehaviorFile')
        BehaviorIngest.BehaviorFile().insert1(
            dict(key, behavior_file=os.path.basename(key['subpath'])),
            ignore_extra_fields=True,
            allow_direct_insert=True)
コード例 #3
0
ファイル: export.py プロジェクト: Yi-61/map-ephys
def _export_recording(insert_key,
                      output_dir='./',
                      filename=None,
                      overwrite=False):
    '''
    Export a 'recording' (probe specific data + related events) to a .mat
    file written with scipy.io.savemat.

    Parameters:

      - insert_key: an ephys.ProbeInsertion.primary_key
        currently: {'subject_id', 'session', 'insertion_number'}

      - output_dir: directory to save the file at (defaults to the current
        working directory)

      - filename: an optional output file name. If not provided, the
        filename is autogenerated using the 'mkfilename' function.

      - overwrite: when False (default) and the target file already exists,
        nothing is fetched or written.

    Returns None. The exported variables are listed in ``exports`` below.
    '''

    if filename is None:
        filename = mkfilename(insert_key)

    filepath = pathlib.Path(output_dir) / filename

    if filepath.exists() and not overwrite:
        print('{} already exists, skipping...'.format(filepath))
        return

    print('exporting {} to {}'.format(insert_key, filepath))

    print('fetching spike/behavior data')

    insertion = (ephys.ProbeInsertion.InsertionLocation & insert_key).fetch1()
    units = (ephys.Unit & insert_key).fetch()

    trial_spikes = (ephys.Unit.TrialSpikes
                    & insert_key).fetch(order_by='trial asc')

    behav = (experiment.BehaviorTrial & insert_key).fetch(order_by='trial asc')

    trials = behav['trial']

    exports = [
        'neuron_single_units', 'neuron_unit_info', 'behavior_report',
        'behavior_early_report', 'behavior_lick_times', 'task_trial_type',
        'task_stimulation', 'task_cue_time'
    ]

    edata = {k: None for k in exports}

    print('reshaping/processing for export')

    # neuron_single_units
    # -------------------
    # One (n_trials x 1) cell per unit, units sorted by id:
    # [[u0t0.spikes, ..., u0tN.spikes], ..., [uNt0.spikes, ..., uNtN.spikes]]
    print('... neuron_single_units:', end='')

    # Index spike trains by (unit, trial) once, instead of scanning the whole
    # table for every unit/trial pair. On duplicates the first row wins, as
    # in the previous per-pair scan.
    spikes_by_unit_trial = {}
    for unit_id, trial_id, spikes in zip(trial_spikes['unit'],
                                         trial_spikes['trial'],
                                         trial_spikes['spike_times']):
        spikes_by_unit_trial.setdefault((unit_id, trial_id), spikes)

    _su = defaultdict(list)
    for t in trials:
        for u in units['unit']:
            # trials without spikes for this unit get an empty train
            _su[u].append(spikes_by_unit_trial.get((u, t), np.array([])))

    unit_ids = sorted(_su.keys())
    ndarray_object = np.empty((len(unit_ids), 1), dtype=object)
    for idx, u in enumerate(unit_ids):
        # Fill an object column explicitly. np.array() over differently-sized
        # spike arrays raises on NumPy >= 1.24, and for equal sizes it would
        # silently build a numeric matrix instead of a cell array.
        unit_cell = np.empty((len(_su[u]), 1), dtype=object)
        for tidx, spikes in enumerate(_su[u]):
            unit_cell[tidx, 0] = spikes
        ndarray_object[idx, 0] = unit_cell

    edata['neuron_single_units'] = ndarray_object

    print('ok.')

    # neuron_unit_info
    # ----------------
    #
    # [[depth_in_um, cell_type, recording_location] ...]
    print('... neuron_unit_info:', end='')

    dv = float(insertion['depth']) if insertion['depth'] else np.nan
    loc = (ephys.ProbeInsertion & insert_key).aggr(
        ephys.ProbeInsertion.RecordableBrainRegion.proj(
            brain_region='CONCAT(hemisphere, " ", brain_area)'),
        brain_regions='GROUP_CONCAT(brain_region SEPARATOR ", ")').fetch1(
            'brain_regions')

    cell_types = {
        u['unit']: u['cell_type']
        for u in (ephys.UnitCellType & insert_key).fetch(as_dict=True)
    }

    edata['neuron_unit_info'] = np.array(
        [[u['unit_posy'] + dv,
          cell_types.get(u['unit'], 'unknown'), loc] for u in units],
        dtype='O')

    print('ok.')

    # behavior_report
    # ---------------
    print('... behavior_report:', end='')

    behavior_report_map = {'hit': 1, 'miss': 0, 'ignore': 0}  # XXX: ignore ok?
    edata['behavior_report'] = np.array(
        [behavior_report_map[i] for i in behav['outcome']])

    print('ok.')

    # behavior_early_report
    # ---------------------
    print('... behavior_early_report:', end='')

    early_report_map = {'early': 1, 'no early': 0}
    edata['behavior_early_report'] = np.array(
        [early_report_map[i] for i in behav['early_lick']])

    print('ok.')

    # behavior_touch_times
    # --------------------

    behavior_touch_times = None  # NOQA no data (see ActionEventType())

    # behavior_lick_times
    # -------------------
    print('... behavior_lick_times:', end='')

    licks = (experiment.ActionEvent() & insert_key
             & "action_event_type in ('left lick', 'right lick')").fetch()

    # One entry (list of lick times) per trial. An explicit 1-D object array
    # avoids NumPy's ragged-array ValueError when trials differ in lick count.
    lick_times = np.empty(len(trials), dtype=object)
    for idx, t in enumerate(trials):
        lick_times[idx] = [
            float(x)  # decimal -> float
            for x in licks['action_event_time'][licks['trial'] == t]
        ]

    edata['behavior_lick_times'] = lick_times

    behavior_whisker_angle = None  # NOQA no data
    behavior_whisker_dist2pol = None  # NOQA no data

    print('ok.')

    # task_trial_type
    # ---------------
    print('... task_trial_type:', end='')

    task_trial_type_map = {'left': 'l', 'right': 'r'}
    edata['task_trial_type'] = np.array(
        [task_trial_type_map[i] for i in behav['trial_instruction']],
        dtype='O')

    print('ok.')

    # task_stimulation
    # ----------------
    print('... task_stimulation:', end='')

    _ts = []  # [[power, type, on-time, off-time], ...]

    photostim = (experiment.Photostim * experiment.PhotostimBrainRegion.proj(
        stim_brain_region='CONCAT(stim_laterality, " ", stim_brain_area)')
                 & insert_key).fetch()

    photostim_map = {}  # photo_stim id -> exported stimulation type code
    photostim_dat = {}  # photo_stim id -> Photostim record
    photostim_keys = ['left ALM', 'right ALM', 'both ALM']
    photostim_vals = [1, 2, 6]

    # XXX: we don't detect duplicate presence of photostim_keys in data
    for fk, rk in zip(photostim_keys, photostim_vals):
        hits = np.flatnonzero(photostim['stim_brain_region'] == fk)
        if not hits.size:
            # region not stimulated in this session (or no photostim at all)
            continue
        stim = photostim[hits[0]]
        photostim_map[stim['photo_stim']] = rk
        photostim_dat[stim['photo_stim']] = stim

    photostim_ev = (experiment.PhotostimEvent & insert_key).fetch()

    for t in trials:
        ev = photostim_ev[photostim_ev['trial'] == t]

        if not len(ev):
            _ts.append([0, math.nan, math.nan, math.nan])
            continue

        # only the first photostim event of a trial is exported
        first = ev[0]
        stim_type = photostim_map[first['photo_stim']]
        stim_info = photostim_dat[first['photo_stim']]
        onset = first['photostim_event_time']

        _ts.append([
            float(first['power']), stim_type,
            float(onset),
            float(onset + stim_info['duration'])
        ])

    edata['task_stimulation'] = np.array(_ts)

    print('ok.')

    # task_pole_time
    # --------------

    task_pole_time = None  # NOQA no data

    # task_cue_time
    # -------------

    print('... task_cue_time:', end='')

    # ordered by trial so entries line up with the other per-trial exports
    _tct = (experiment.TrialEvent()
            & {
                **insert_key, 'trial_event_type': 'go'
            }).fetch('trial_event_time', order_by='trial')

    edata['task_cue_time'] = np.array([float(i) for i in _tct])

    print('ok.')

    # savemat
    # -------

    print('... saving to {}:'.format(filepath), end='')

    scio.savemat(filepath, edata)

    print('ok.')
コード例 #4
0
    def make(self, key):
        '''
        Ingest the behavior session identified by ``key``.

        ``key`` must provide 'subject_id' and 'session_date'. The method:
          1. globs '<rig_data_path>/<h2o>/tw2/Session Data/<h2o>_tw2_<date>*.mat'
             under every RigDataPath whose rig name contains 'RRig', and
             picks the first rig with matches;
          2. loads and concatenates the trials of all matching files larger
             than 100 KiB;
          3. synthesizes the next free session number for the subject and
             inserts an experiment.Session record;
          4. parses each trial into SessionTrial / BehaviorTrial / TrialNote /
             TrialEvent / ActionEvent rows and bulk-inserts them;
          5. records this ingest and its behavior file(s).

        Returns early without inserting anything if no file matches, if a
        Session already exists for (subject_id, session_date, username, rig),
        or if no matching file yields any trials.
        '''
        log.info('BehaviorIngest.make(): key: {key}'.format(key=key))
        rigpaths = [
            p for p in RigDataPath().fetch(order_by='rig_data_path')
            if 'RRig' in p['rig']
        ]  # change between TRig and RRig

        subject_id = key['subject_id']
        h2o = (lab.WaterRestriction() & {
            'subject_id': subject_id
        }).fetch1('water_restriction_number')
        date = key['session_date']
        datestr = date.strftime('%Y%m%d')
        log.debug('h2o: {h2o}, date: {d}'.format(h2o=h2o, d=datestr))

        # session record key
        skey = {}
        skey['subject_id'] = subject_id
        skey['session_date'] = date
        skey['username'] = '******'  # username has to be changed

        # e.g: dl7/TW_autoTrain/Session Data/dl7_TW_autoTrain_20180104_132813.mat
        # (the earlier program protocol used 'TW_autoTrain' where 'tw2' is
        # used below)
        matches = []  # stays empty when there are no rig paths at all
        for rp in rigpaths:
            root = rp['rig_data_path']
            path = os.path.join(
                root, h2o, 'tw2', 'Session Data',
                '{h2o}_tw2_{d}*.mat'.format(
                    h2o=h2o, d=datestr))  # later program protocol

            log.debug('rigpath {p}'.format(p=path))

            matches = glob.glob(path)
            if matches:
                log.debug('found files, this is the rig')
                skey['rig'] = rp['rig']
                break
            else:
                log.info('no file matches found in {p}'.format(p=path))

        if not matches:
            log.warning('no file matches found for {h2o} / {d}'.format(
                h2o=h2o, d=datestr))
            return

        #
        # Find files & Check for split files
        # XXX: not checking rig.. 2+ sessions on 2+ rigs possible for date?
        #

        if len(matches) > 1:
            log.warning(
                'split session case detected for {h2o} on {date}'.format(
                    h2o=h2o, date=date))

        # session:date relationship is 1:1; skip if we have a session
        if experiment.Session() & skey:
            log.warning("Warning! session exists for {h2o} on {d}".format(
                h2o=h2o, d=date))
            return

        #
        # Extract trial data from file(s) & prepare trial loop
        #

        trials = zip()

        trial = namedtuple(  # simple structure to track per-trial vars
            'trial', ('ttype', 'settings', 'state_times', 'state_names',
                      'state_data', 'event_data', 'event_times'))

        for f in matches:

            if os.stat(f).st_size / 1024 < 100:
                log.info('skipping file {f} - too small'.format(f=f))
                continue

            mat = spio.loadmat(f, squeeze_me=True)
            SessionData = mat['SessionData'].flatten()

            AllTrialTypes = SessionData['TrialTypes'][0]
            AllTrialSettings = SessionData['TrialSettings'][0]

            RawData = SessionData['RawData'][0].flatten()
            AllStateNames = RawData['OriginalStateNamesByNumber'][0]
            AllStateData = RawData['OriginalStateData'][0]
            AllEventData = RawData['OriginalEventData'][0]
            AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
            AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

            # verify trial-related data arrays are all same length
            assert (all(
                (x.shape[0] == AllStateTimestamps.shape[0]
                 for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                           AllStateData, AllEventData, AllEventTimestamps))))

            z = zip(AllTrialTypes, AllTrialSettings, AllStateTimestamps,
                    AllStateNames, AllStateData, AllEventData,
                    AllEventTimestamps)

            trials = chain(trials, z)  # concatenate the files

        trials = list(trials)

        # all files were internally invalid or size < 100k: stop here rather
        # than creating an empty Session and marking the key as ingested
        if not trials:
            log.warning('skipping date {d}, no valid files'.format(d=date))
            return

        #
        # Trial data seems valid; synthesize session id & add session record
        # XXX: note - later breaks can result in Sessions without valid trials
        #

        log.debug('synthesizing session ID')
        session = (dj.U().aggr(experiment.Session() & {
            'subject_id': subject_id
        },
                               n='max(session)').fetch1('n') or 0) + 1
        log.info('generated session id: {session}'.format(session=session))
        skey['session'] = session
        key = dict(key, **skey)

        log.debug('BehaviorIngest.make(): adding session record')
        experiment.Session().insert1(skey)

        #
        # Actually load the per-trial data
        #
        log.info('BehaviorIngest.make(): trial parsing phase')

        # lists of various records for batch-insert
        rows = {
            k: list()
            for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                      'action_event')
        }

        i = -1
        for t in trials:

            #
            # Misc
            #

            t = trial(*t)  # convert list of items to a 'trial' structure
            i += 1  # increment trial counter (also counts skipped trials)

            log.info('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

            # convert state data names into a lookup dictionary
            # (state numbers in state_data are 1-based)
            #
            # names (seem to be? are?):
            #
            # Trigtrialstart
            # PreSamplePeriod
            # SamplePeriod
            # DelayPeriod
            # EarlyLickDelay
            # EarlyLickSample
            # ResponseCue
            # GiveLeftDrop
            # GiveRightDrop
            # GiveLeftDropShort
            # GiveRightDropShort
            # AnswerPeriod
            # Reward
            # RewardConsumption
            # NoResponse
            # TimeOut
            # StopLicking
            # StopLickingReturn
            # TrialEnd

            states = {k: (v + 1) for v, k in enumerate(t.state_names)}
            required_states = ('PreSamplePeriod', 'SamplePeriod',
                               'DelayPeriod', 'ResponseCue', 'StopLicking',
                               'TrialEnd')

            missing = [k for k in required_states if k not in states]

            if missing:
                log.info('skipping trial {i}; missing {m}'.format(i=i,
                                                                  m=missing))
                continue

            gui = t.settings['GUI'].flatten()

            # ProtocolType - only ingest protocol >= 3
            #
            # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
            # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
            #

            if 'ProtocolType' not in gui.dtype.names:
                log.info('skipping trial {i}; protocol undefined'.format(i=i))
                continue

            protocol_type = gui['ProtocolType'][0]
            if protocol_type < 3:
                log.info('skipping trial {i}; protocol {n} < 3'.format(
                    i=i, n=protocol_type))
                continue

            #
            # Top-level 'Trial' record
            #

            tkey = dict(skey)
            startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

            # should be only end of 1st StopLicking;
            # rest of data is irrelevant w/r/t separately ingested ephys
            endindex = np.where(t.state_data == states['StopLicking'])[0]

            log.debug('states\n' + str(states))
            log.debug('state_data\n' + str(t.state_data))
            log.debug('startindex\n' + str(startindex))
            log.debug('endendex\n' + str(endindex))

            if not (len(startindex) and len(endindex)):
                log.info('skipping trial {i}: start/end index error: {s}/{e}'.
                         format(i=i, s=str(startindex), e=str(endindex)))
                continue

            try:
                tkey['trial'] = i
                tkey['trial_uid'] = i
                tkey['start_time'] = t.state_times[startindex][0]
            except IndexError:
                log.info('skipping trial {i}: error indexing {s}/{e} into {t}'.
                         format(i=i,
                                s=str(startindex),
                                e=str(endindex),
                                t=str(t.state_times)))
                continue

            log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
            log.debug('tkey' + str(tkey))
            rows['trial'].append(tkey)

            #
            # Specific BehaviorTrial information for this trial
            #

            bkey = dict(tkey)
            bkey['task'] = 'audio delay'
            bkey['task_protocol'] = 1

            # determine trial instruction; 'Reversal' swaps which trial type
            # maps to which side
            trial_instruction = 'left'

            if gui['Reversal'][0] == 1:
                if t.ttype == 1:
                    trial_instruction = 'left'
                elif t.ttype == 0:
                    trial_instruction = 'right'
            elif gui['Reversal'][0] == 2:
                if t.ttype == 1:
                    trial_instruction = 'right'
                elif t.ttype == 0:
                    trial_instruction = 'left'

            bkey['trial_instruction'] = trial_instruction

            # determine early lick
            early_lick = 'no early'

            if (protocol_type >= 5 and 'EarlyLickDelay' in states
                    and np.any(t.state_data == states['EarlyLickDelay'])):
                early_lick = 'early'
            if (protocol_type > 5 and
                ('EarlyLickSample' in states
                 and np.any(t.state_data == states['EarlyLickSample']))):
                early_lick = 'early'

            bkey['early_lick'] = early_lick

            # determine outcome
            outcome = 'ignore'

            if ('Reward' in states
                    and np.any(t.state_data == states['Reward'])):
                outcome = 'hit'
            elif ('TimeOut' in states
                  and np.any(t.state_data == states['TimeOut'])):
                outcome = 'miss'
            elif ('NoResponse' in states
                  and np.any(t.state_data == states['NoResponse'])):
                outcome = 'ignore'

            bkey['outcome'] = outcome

            # add behavior record
            log.debug('BehaviorIngest.make(): BehaviorTrial()')
            rows['behavior_trial'].append(bkey)

            #
            # Add 'protocol' note
            #

            nkey = dict(tkey)
            nkey['trial_note_type'] = 'protocol #'
            nkey['trial_note'] = str(protocol_type)

            log.debug('BehaviorIngest.make(): TrialNote().insert1')
            rows['trial_note'].append(nkey)

            #
            # Add 'autolearn' note
            #

            nkey = dict(tkey)
            nkey['trial_note_type'] = 'autolearn'
            nkey['trial_note'] = str(gui['Autolearn'][0])
            rows['trial_note'].append(nkey)

            #
            # Add 'bitcode' note
            #
            if 'randomID' in gui.dtype.names:
                nkey = dict(tkey)
                nkey['trial_note_type'] = 'bitcode'
                nkey['trial_note'] = str(gui['randomID'][0])
                rows['trial_note'].append(nkey)

            #
            # Add presample event
            #

            ekey = dict(tkey)
            sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

            ekey['trial_event_type'] = 'presample'
            ekey['trial_event_time'] = t.state_times[startindex][0]
            ekey['duration'] = (t.state_times[sampleindex[0]] -
                                t.state_times[startindex])[0]

            log.debug('BehaviorIngest.make(): presample')
            rows['trial_event'].append(ekey)

            #
            # Add 'go' event
            #

            ekey = dict(tkey)
            responseindex = np.where(t.state_data == states['ResponseCue'])[0]

            ekey['trial_event_type'] = 'go'
            ekey['trial_event_time'] = t.state_times[responseindex][0]
            ekey['duration'] = gui['AnswerPeriod'][0]

            log.debug('BehaviorIngest.make(): go')
            rows['trial_event'].append(ekey)

            #
            # Add other 'sample' events
            #

            log.debug('BehaviorIngest.make(): sample events')
            for s in sampleindex:  # in protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_type'] = 'sample'
                ekey['trial_event_time'] = t.state_times[s]
                ekey['duration'] = gui['SamplePeriod'][0]
                rows['trial_event'].append(ekey)

            #
            # Add 'delay' events
            #

            delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

            log.debug('BehaviorIngest.make(): delay events')
            for d in delayindex:  # protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_type'] = 'delay'
                ekey['trial_event_time'] = t.state_times[d]
                ekey['duration'] = gui['DelayPeriod'][0]
                rows['trial_event'].append(ekey)

            #
            # Add lick events (event code 69 = left lick, 70 = right lick,
            # per the mapping used below)
            #

            lickleft = np.where(t.event_data == 69)[0]
            log.debug('... lickleft: {r}'.format(r=str(lickleft)))
            rows['action_event'].extend(
                dict(**tkey,
                     action_event_type='left lick',
                     action_event_time=t.event_times[idx])
                for idx in lickleft)

            lickright = np.where(t.event_data == 70)[0]
            log.debug('... lickright: {r}'.format(r=str(lickright)))
            rows['action_event'].extend(
                dict(**tkey,
                     action_event_type='right lick',
                     action_event_time=t.event_times[idx])
                for idx in lickright)

            # end of trial loop.

        log.info('BehaviorIngest.make(): bulk insert phase')

        log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
        experiment.SessionTrial().insert(rows['trial'],
                                         ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
        experiment.BehaviorTrial().insert(rows['behavior_trial'],
                                          ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialNote')
        experiment.TrialNote().insert(rows['trial_note'],
                                      ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
        experiment.TrialEvent().insert(rows['trial_event'],
                                       ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
        experiment.ActionEvent().insert(rows['action_event'],
                                        ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): saving ingest {d}'.format(d=key))
        self.insert1(key, ignore_extra_fields=True)

        # record every matched file, stored relative to the matching rig root
        BehaviorIngest.BehaviorFile().insert(
            (dict(key, behavior_file=f.split(root)[1]) for f in matches),
            ignore_extra_fields=True)
コード例 #5
0
def fix_session(session_key):
    paths = behavior_ingest.RigDataPath.fetch(as_dict=True)
    files = (behavior_ingest.BehaviorIngest *
             behavior_ingest.BehaviorIngest.BehaviorFile
             & session_key).fetch(as_dict=True, order_by='behavior_file asc')

    filelist = []
    for pf in [(p, f) for f in files for p in paths]:
        p, f = pf
        found = find_path(p['rig_data_path'], f['behavior_file'])
        if found:
            filelist.append(found)

    if len(filelist) != len(files):
        log.warning("behavior files missing in {} ({}/{}). skipping".format(
            session_key, len(filelist), len(files)))
        return False

    log.info('filelist: {}'.format(filelist))

    #
    # Prepare PhotoStim
    #
    photosti_duration = 0.5  # (s) Hard-coded here
    photostims = {
        4: {
            'photo_stim': 4,
            'photostim_device': 'OBIS470',
            'brain_location_name': 'left_alm',
            'duration': photosti_duration
        },
        5: {
            'photo_stim': 5,
            'photostim_device': 'OBIS470',
            'brain_location_name': 'right_alm',
            'duration': photosti_duration
        },
        6: {
            'photo_stim': 6,
            'photostim_device': 'OBIS470',
            'brain_location_name': 'both_alm',
            'duration': photosti_duration
        }
    }

    #
    # Extract trial data from file(s) & prepare trial loop
    #

    trials = zip()

    trial = namedtuple(  # simple structure to track per-trial vars
        'trial', ('ttype', 'stim', 'settings', 'state_times', 'state_names',
                  'state_data', 'event_data', 'event_times'))

    for f in filelist:

        if os.stat(f).st_size / 1024 < 1000:
            log.info('skipping file {f} - too small'.format(f=f))
            continue

        log.debug('loading file {}'.format(f))

        mat = spio.loadmat(f, squeeze_me=True)
        SessionData = mat['SessionData'].flatten()

        AllTrialTypes = SessionData['TrialTypes'][0]
        AllTrialSettings = SessionData['TrialSettings'][0]

        RawData = SessionData['RawData'][0].flatten()
        AllStateNames = RawData['OriginalStateNamesByNumber'][0]
        AllStateData = RawData['OriginalStateData'][0]
        AllEventData = RawData['OriginalEventData'][0]
        AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
        AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

        # verify trial-related data arrays are all same length
        assert (all(
            (x.shape[0] == AllStateTimestamps.shape[0]
             for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                       AllStateData, AllEventData, AllEventTimestamps))))

        if 'StimTrials' in SessionData.dtype.fields:
            log.debug('StimTrials detected in session - will include')
            AllStimTrials = SessionData['StimTrials'][0]
            assert (AllStimTrials.shape[0] == AllStateTimestamps.shape[0])
        else:
            log.debug('StimTrials not detected in session - will skip')
            AllStimTrials = np.array(
                [None for i in enumerate(range(AllStateTimestamps.shape[0]))])

        z = zip(AllTrialTypes, AllStimTrials, AllTrialSettings,
                AllStateTimestamps, AllStateNames, AllStateData, AllEventData,
                AllEventTimestamps)

        trials = chain(trials, z)  # concatenate the files

    trials = list(trials)

    # all files were internally invalid or size < 100k
    if not trials:
        log.warning('skipping ., no valid files')
        return False

    key = session_key
    skey = (experiment.Session & key).fetch1()

    #
    # Actually load the per-trial data
    #
    log.info('BehaviorIngest.make(): trial parsing phase')

    # lists of various records for batch-insert
    rows = {
        k: list()
        for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                  'corrected_trial_event', 'action_event', 'photostim',
                  'photostim_location', 'photostim_trial',
                  'photostim_trial_event')
    }

    i = -1
    for t in trials:

        #
        # Misc
        #

        t = trial(*t)  # convert list of items to a 'trial' structure
        i += 1  # increment trial counter

        log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

        # covert state data names into a lookup dictionary
        #
        # names (seem to be? are?):
        #
        # Trigtrialstart
        # PreSamplePeriod
        # SamplePeriod
        # DelayPeriod
        # EarlyLickDelay
        # EarlyLickSample
        # ResponseCue
        # GiveLeftDrop
        # GiveRightDrop
        # GiveLeftDropShort
        # GiveRightDropShort
        # AnswerPeriod
        # Reward
        # RewardConsumption
        # NoResponse
        # TimeOut
        # StopLicking
        # StopLickingReturn
        # TrialEnd

        states = {k: (v + 1) for v, k in enumerate(t.state_names)}
        required_states = ('PreSamplePeriod', 'SamplePeriod', 'DelayPeriod',
                           'ResponseCue', 'StopLicking', 'TrialEnd')

        missing = list(k for k in required_states if k not in states)

        if len(missing):
            log.warning('skipping trial {i}; missing {m}'.format(i=i,
                                                                 m=missing))
            continue

        gui = t.settings['GUI'].flatten()

        # ProtocolType - only ingest protocol >= 3
        #
        # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
        # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
        #

        if 'ProtocolType' not in gui.dtype.names:
            log.warning('skipping trial {i}; protocol undefined'.format(i=i))
            continue

        protocol_type = gui['ProtocolType'][0]
        if gui['ProtocolType'][0] < 3:
            log.warning('skipping trial {i}; protocol {n} < 3'.format(
                i=i, n=gui['ProtocolType'][0]))
            continue

        #
        # Top-level 'Trial' record
        #

        tkey = dict(skey)
        startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

        # should be only end of 1st StopLicking;
        # rest of data is irrelevant w/r/t separately ingested ephys
        endindex = np.where(t.state_data == states['StopLicking'])[0]

        log.debug('states\n' + str(states))
        log.debug('state_data\n' + str(t.state_data))
        log.debug('startindex\n' + str(startindex))
        log.debug('endindex\n' + str(endindex))

        if not (len(startindex) and len(endindex)):
            log.warning(
                'skipping trial {i}: start/end index error: {s}/{e}'.format(
                    i=i, s=str(startindex), e=str(endindex)))
            continue

        try:
            tkey['trial'] = i
            tkey[
                'trial_uid'] = i  # Arseny has unique id to identify some trials
            tkey['start_time'] = t.state_times[startindex][0]
            tkey['stop_time'] = t.state_times[endindex][0]
        except IndexError:
            log.warning(
                'skipping trial {i}: error indexing {s}/{e} into {t}'.format(
                    i=i,
                    s=str(startindex),
                    e=str(endindex),
                    t=str(t.state_times)))
            continue

        log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
        log.debug('tkey' + str(tkey))
        rows['trial'].append(tkey)

        #
        # Specific BehaviorTrial information for this trial
        #

        bkey = dict(tkey)
        bkey['task'] = 'audio delay'  # hard-coded here
        bkey['task_protocol'] = 1  # hard-coded here

        # determine trial instruction
        trial_instruction = 'left'  # hard-coded here

        if gui['Reversal'][0] == 1:
            if t.ttype == 1:
                trial_instruction = 'left'
            elif t.ttype == 0:
                trial_instruction = 'right'
        elif gui['Reversal'][0] == 2:
            if t.ttype == 1:
                trial_instruction = 'right'
            elif t.ttype == 0:
                trial_instruction = 'left'

        bkey['trial_instruction'] = trial_instruction

        # determine early lick
        early_lick = 'no early'

        if (protocol_type >= 5 and 'EarlyLickDelay' in states
                and np.any(t.state_data == states['EarlyLickDelay'])):
            early_lick = 'early'
        if (protocol_type > 5
                and ('EarlyLickSample' in states
                     and np.any(t.state_data == states['EarlyLickSample']))):
            early_lick = 'early'

        bkey['early_lick'] = early_lick

        # determine outcome
        outcome = 'ignore'

        if ('Reward' in states and np.any(t.state_data == states['Reward'])):
            outcome = 'hit'
        elif ('TimeOut' in states
              and np.any(t.state_data == states['TimeOut'])):
            outcome = 'miss'
        elif ('NoResponse' in states
              and np.any(t.state_data == states['NoResponse'])):
            outcome = 'ignore'

        bkey['outcome'] = outcome
        rows['behavior_trial'].append(bkey)

        #
        # Add 'protocol' note
        #
        nkey = dict(tkey)
        nkey['trial_note_type'] = 'protocol #'
        nkey['trial_note'] = str(protocol_type)
        rows['trial_note'].append(nkey)

        #
        # Add 'autolearn' note
        #
        nkey = dict(tkey)
        nkey['trial_note_type'] = 'autolearn'
        nkey['trial_note'] = str(gui['Autolearn'][0])
        rows['trial_note'].append(nkey)

        #
        # Add 'bitcode' note
        #
        if 'randomID' in gui.dtype.names:
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'bitcode'
            nkey['trial_note'] = str(gui['randomID'][0])
            rows['trial_note'].append(nkey)

        #
        # Add presample event
        #
        log.debug('BehaviorIngest.make(): presample')

        ekey = dict(tkey)
        sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'presample'
        ekey['trial_event_time'] = t.state_times[startindex][0]
        ekey['duration'] = (t.state_times[sampleindex[0]] -
                            t.state_times[startindex])[0]

        if math.isnan(ekey['duration']):
            log.debug('BehaviorIngest.make(): fixing presample duration')
            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial

        rows['trial_event'].append(ekey)

        #
        # Add other 'sample' events
        #

        log.debug('BehaviorIngest.make(): sample events')

        last_dur = None

        for s in sampleindex:  # in protocol > 6 ~-> n>1
            # todo: batch events
            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'sample'
            ekey['trial_event_time'] = t.state_times[s]
            ekey['duration'] = gui['SamplePeriod'][0]

            if math.isnan(ekey['duration']) and last_dur is None:
                log.warning('... trial {} bad duration, no last_edur'.format(
                    i, last_dur))
                ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                rows['corrected_trial_event'].append(ekey)

            elif math.isnan(ekey['duration']) and last_dur is not None:
                log.warning('... trial {} duration using last_edur {}'.format(
                    i, last_dur))
                ekey['duration'] = last_dur
                rows['corrected_trial_event'].append(ekey)

            else:
                last_dur = ekey['duration']  # only track 'good' values.

            rows['trial_event'].append(ekey)

        #
        # Add 'delay' events
        #

        log.debug('BehaviorIngest.make(): delay events')

        last_dur = None
        delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

        for d in delayindex:  # protocol > 6 ~-> n>1
            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'delay'
            ekey['trial_event_time'] = t.state_times[d]
            ekey['duration'] = gui['DelayPeriod'][0]

            if math.isnan(ekey['duration']) and last_dur is None:
                log.warning('... {} bad duration, no last_edur'.format(
                    i, last_dur))
                ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                rows['corrected_trial_event'].append(ekey)

            elif math.isnan(ekey['duration']) and last_dur is not None:
                log.warning('... {} duration using last_edur {}'.format(
                    i, last_dur))
                ekey['duration'] = last_dur
                rows['corrected_trial_event'].append(ekey)

            else:
                last_dur = ekey['duration']  # only track 'good' values.

            log.debug('delay event duration: {}'.format(ekey['duration']))
            rows['trial_event'].append(ekey)

        #
        # Add 'go' event
        #
        log.debug('BehaviorIngest.make(): go')

        ekey = dict(tkey)
        responseindex = np.where(t.state_data == states['ResponseCue'])[0]

        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'go'
        ekey['trial_event_time'] = t.state_times[responseindex][0]
        ekey['duration'] = gui['AnswerPeriod'][0]

        if math.isnan(ekey['duration']):
            log.debug('BehaviorIngest.make(): fixing go duration')
            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
            rows['corrected_trial_event'].append(ekey)

        rows['trial_event'].append(ekey)

        #
        # Add 'trialEnd' events
        #

        log.debug('BehaviorIngest.make(): trialend events')

        last_dur = None
        trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

        ekey = dict(tkey)
        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'trialend'
        ekey['trial_event_time'] = t.state_times[trialendindex][0]
        ekey['duration'] = 0.0

        rows['trial_event'].append(ekey)

        #
        # Add lick events
        #

        lickleft = np.where(t.event_data == 69)[0]
        log.debug('... lickleft: {r}'.format(r=str(lickleft)))

        action_event_count = len(rows['action_event'])
        if len(lickleft):
            [
                rows['action_event'].append(
                    dict(tkey,
                         action_event_id=action_event_count + idx,
                         action_event_type='left lick',
                         action_event_time=t.event_times[l]))
                for idx, l in enumerate(lickleft)
            ]

        lickright = np.where(t.event_data == 71)[0]
        log.debug('... lickright: {r}'.format(r=str(lickright)))

        action_event_count = len(rows['action_event'])
        if len(lickright):
            [
                rows['action_event'].append(
                    dict(tkey,
                         action_event_id=action_event_count + idx,
                         action_event_type='right lick',
                         action_event_time=t.event_times[r]))
                for idx, r in enumerate(lickright)
            ]

        # Photostim Events
        #
        # TODO:
        #
        # - base stimulation parameters:
        #
        #   - should be loaded elsewhere - where
        #   - actual ccf locations - cannot be known apriori apparently?
        #   - Photostim.Profile: what is? fix/add
        #
        # - stim data
        #
        #   - how retrieve power from file (didn't see) or should
        #     be statically coded here?
        #   - how encode stim type 6?
        #     - we have hemisphere as boolean or
        #     - but adding an event 4 and event 5 means querying
        #       is less straightforwrard (e.g. sessions with 5 & 6)

        if t.stim:
            log.debug('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
            rows['photostim_trial'].append(tkey)
            delay_period_idx = np.where(
                t.state_data == states['DelayPeriod'])[0][0]
            rows['photostim_trial_event'].append(
                dict(tkey,
                     **photostims[t.stim],
                     photostim_event_id=len(rows['photostim_trial_event']),
                     photostim_event_time=t.state_times[delay_period_idx],
                     power=5.5))

        # end of trial loop.

    log.info('BehaviorIngest.make(): ... experiment.TrialEvent')

    fix_events = rows['trial_event']

    ref_events = (experiment.TrialEvent() & skey).fetch(
        order_by='trial, trial_event_id', as_dict=True)

    if False:
        for e in ref_events:
            log.debug('ref_events: t: {}, e: {}, event_type: {}'.format(
                e['trial'], e['trial_event_id'], e['trial_event_type']))
        for e in fix_events:
            log.debug('fix_events: t: {}, e: {}, type: {}'.format(
                e['trial'], e['trial_event_id'], e['trial_event_type']))

    log.info('deleting old events')

    with dj.config(safemode=False):

        log.info('... TrialEvent')
        (experiment.TrialEvent() & session_key).delete()

        log.info('... CorrectedTrialEvents')
        (behavior_ingest.BehaviorIngest.CorrectedTrialEvents()
         & session_key).delete_quick()

    log.info('adding new records')

    log.info('... experiment.TrialEvent')
    experiment.TrialEvent().insert(rows['trial_event'],
                                   ignore_extra_fields=True,
                                   allow_direct_insert=True,
                                   skip_duplicates=True)

    log.info('... CorrectedTrialEvents')
    behavior_ingest.BehaviorIngest.CorrectedTrialEvents().insert(
        rows['corrected_trial_event'],
        ignore_extra_fields=True,
        allow_direct_insert=True)

    return True
コード例 #6
0
def extract_trials(plottype='2lickport',
                   wr_name='FOR01',
                   sessions=(5, 11),
                   show_bias_check_trials=True,
                   kernel=np.ones(10) / 10,
                   filters=None,
                   local_matching={'calculate_local_matching': False}):
    """Build a per-trial DataFrame of foraging behavior for one subject.

    Fetches the trials of sessions ``sessions[0]`` .. ``sessions[1]``
    (inclusive) of the subject with water restriction number ``wr_name``,
    joined with their "go" trial event, and adds per-session derived columns
    (``iti``, ``delay``, ``early_count``, ``ignore_rate``,
    ``reaction_time_smoothed``), optional trial filtering, and local
    efficiency / local matching statistics.

    Parameters
    ----------
    plottype : str
        Not used by this function; kept for interface compatibility.
    wr_name : str
        Water restriction number of the subject.
    sessions : tuple of int
        Inclusive ``(first, last)`` session numbers.
    show_bias_check_trials : bool
        If False, keep only trials whose left/right (and middle, when set)
        reward probabilities are all below 1.
    kernel : numpy.ndarray
        Smoothing kernel applied per session to the ignore rate and to the
        interpolated reaction time.
    filters : dict or None
        If a dict, ``filters['ignore_rate_max']`` (in percent) drops, in each
        session, every trial from the first one whose smoothed ignore rate
        exceeds that threshold onward.
    local_matching : dict
        Window sizes (in trials) for the local statistics, e.g.
        ``{'calculate_local_matching': True, 'sliding_window': 50,
        'matching_window': 500, 'matching_step': 100}``.  The window keys are
        required when ``'calculate_local_matching'`` is truthy; when they are
        absent and the flag is falsy (the default), the
        ``local_efficiency``, ``local_matching_slope`` and
        ``local_matching_bias`` columns are filled with NaN.

    Returns
    -------
    pandas.DataFrame
        One row per kept trial.  ``trial`` is offset by the total trial count
        of the subject's earlier sessions (from ``SessionStats``).

    Raises
    ------
    KeyError
        If local matching is requested but a window key is missing.
    """
    # The default dict carries no window sizes, so they are only required when
    # local matching is requested (explicitly, or by supplying the windows).
    # The dict is only read here, never mutated.
    if local_matching is None:
        local_matching = {}
    window_keys = ('sliding_window', 'matching_window', 'matching_step')
    compute_local_matching = bool(
        local_matching.get('calculate_local_matching', False)
        or all(key in local_matching for key in window_keys))
    if compute_local_matching:
        # Read eagerly so a missing key fails before any database work.
        movingwindow = local_matching['sliding_window']
        fit_window = local_matching['matching_window']
        fit_step = local_matching['matching_step']

    subject_id = (lab.WaterRestriction()
                  & 'water_restriction_number = "{}"'.format(wr_name)
                  ).fetch1('subject_id')

    trial_columns = ['session', 'trial', 'early_lick', 'trial_start_time',
                     'reaction_time', 'p_reward_left', 'p_reward_right',
                     'p_reward_middle', 'trial_event_time', 'trial_choice',
                     'outcome']
    # Trials joined with their "go" TrialEvent (trial_event_time).
    trial_query = (experiment.BehaviorTrial() * experiment.SessionTrial() *
                   experiment.TrialEvent() * experiment.SessionBlock() *
                   behavior_foraging.TrialReactionTime
                   & 'subject_id = {}'.format(subject_id)
                   & 'session >= {}'.format(sessions[0])
                   & 'session <= {}'.format(sessions[1])
                   & 'trial_event_type = "go"')
    df_behaviortrial = pd.DataFrame(
        np.asarray(trial_query.fetch(*trial_columns)).T,
        columns=trial_columns)

    unique_sessions = df_behaviortrial['session'].unique()
    df_behaviortrial['iti'] = np.nan
    df_behaviortrial['delay'] = np.nan
    df_behaviortrial['early_count'] = 0
    df_behaviortrial.loc[df_behaviortrial['early_lick'] == 'early',
                         'early_count'] = 1
    df_behaviortrial['ignore_rate'] = np.nan
    df_behaviortrial['reaction_time_smoothed'] = np.nan
    apply_filters = isinstance(filters, dict)
    if apply_filters:
        df_behaviortrial['keep_trial'] = 1

    for session in unique_sessions:
        in_session = df_behaviortrial['session'] == session

        # Offset trial numbers by all trials of this subject's earlier sessions.
        total_trials_so_far = sum(
            (behavior_foraging.SessionStats()
             & 'subject_id = {}'.format(subject_id)
             & 'session < {}'.format(session)).fetch('session_total_trial_num'))

        gotime = df_behaviortrial.loc[in_session, 'trial_event_time']
        trialtime = df_behaviortrial.loc[in_session, 'trial_start_time']
        # ITI: difference between consecutive go-cue times (trial start + go).
        df_behaviortrial.loc[in_session, 'iti'] = np.concatenate(
            [[np.nan], np.diff(np.asarray(trialtime + gotime, float))])
        df_behaviortrial.loc[in_session, 'delay'] = np.asarray(gotime, float)

        df_behaviortrial.loc[in_session, 'ignore_rate'] = np.convolve(
            df_behaviortrial.loc[in_session, 'outcome'] == 'ignore', kernel,
            'same')
        # Linearly interpolate missing reaction times and scale by 1000
        # (presumably seconds -> ms; confirm), then smooth.
        reaction_time_scaled = pd.Series(
            np.asarray(df_behaviortrial.loc[in_session, 'reaction_time'].values,
                       float)).interpolate().to_numpy() * 1000
        df_behaviortrial.loc[in_session, 'reaction_time_smoothed'] = (
            np.convolve(reaction_time_scaled, kernel, 'same'))
        df_behaviortrial.loc[in_session, 'trial'] += total_trials_so_far

        if apply_filters:
            ignore_rate_max = filters['ignore_rate_max'] / 100
            above_max = (df_behaviortrial.loc[in_session, 'ignore_rate']
                         > ignore_rate_max)
            max_idx = above_max.idxmax()
            session_first_trial_idx = in_session.idxmax()
            # idxmax() returns the first row when nothing exceeds the
            # threshold, so the session's first trial is checked explicitly.
            if (max_idx > session_first_trial_idx
                    or df_behaviortrial.at[session_first_trial_idx,
                                           'ignore_rate'] > ignore_rate_max):
                df_behaviortrial.loc[
                    (df_behaviortrial.index >= max_idx) & in_session,
                    'keep_trial'] = 0

    if apply_filters:
        df_behaviortrial = df_behaviortrial[
            df_behaviortrial['keep_trial'] == 1].reset_index(drop=True)

    if not show_bias_check_trials:
        realtraining = ((df_behaviortrial['p_reward_left'] < 1)
                        & (df_behaviortrial['p_reward_right'] < 1)
                        & ((df_behaviortrial['p_reward_middle'] < 1)
                           | df_behaviortrial['p_reward_middle'].isnull()))
        df_behaviortrial = df_behaviortrial[realtraining].reset_index(
            drop=True)

    if compute_local_matching:
        _add_local_matching(df_behaviortrial, movingwindow, fit_window,
                            fit_step)
    else:
        for column in ('local_efficiency', 'local_matching_slope',
                       'local_matching_bias'):
            df_behaviortrial[column] = np.nan
    return df_behaviortrial


def _add_local_matching(df_behaviortrial, movingwindow, fit_window, fit_step):
    """Add local efficiency and local matching columns to the frame in place.

    ``local_efficiency`` is the ``movingwindow``-trial sliding mean of hits
    divided by the sliding mean of a per-trial reference reward rate derived
    from the left/right reward probabilities.  ``local_matching_slope`` and
    ``local_matching_bias`` are the slope and intercept of a linear fit of
    log2(smoothed right choices / smoothed left+middle choices) against
    log2(smoothed right rewards / smoothed left+middle rewards) over
    ``fit_window`` trials, evaluated every ``fit_step`` trials and stored on
    the window's centre row; every other row is NaN.
    """
    window_kernel = np.ones(movingwindow) / movingwindow
    outcome = df_behaviortrial['outcome']
    choice = df_behaviortrial['trial_choice']

    # Zero counts inside a window are expected; the resulting inf/NaN values
    # are dropped or kept deliberately, so numpy's warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        reward_probs = [df_behaviortrial['p_reward_right'],
                        df_behaviortrial['p_reward_left']]
        p1 = np.asarray(np.max(reward_probs, 0), float)
        p0 = np.asarray(np.min(reward_probs, 0), float)
        # NOTE(review): named as a greedy-strategy reward rate for the richer
        # (p1) vs poorer (p0) port; confirm the formula against its source.
        m_star_greedy = np.floor(np.log(1 - p1) / np.log(1 - p0))
        p_star_greedy = p1 + (1 - (1 - p0)**(m_star_greedy + 1)
                              - p1**2) / (m_star_greedy + 1)
        local_reward_rate = np.convolve(outcome == 'hit', window_kernel,
                                        'same')
        max_reward_rate = np.convolve(p_star_greedy, window_kernel, 'same')
        local_efficiency = local_reward_rate / max_reward_rate

        hit = outcome == 'hit'
        choice_right = np.asarray(choice == 'right')
        choice_left = np.asarray(choice == 'left')
        choice_middle = np.asarray(choice == 'middle')
        reward_right = np.asarray((choice == 'right') & hit)
        reward_left = np.asarray((choice == 'left') & hit)
        reward_middle = np.asarray((choice == 'middle') & hit)

        choice_rate_right = (
            np.convolve(choice_right, window_kernel, 'same')
            / np.convolve(choice_left + choice_middle, window_kernel, 'same'))
        reward_rate_right = (
            np.convolve(reward_right, window_kernel, 'same')
            / np.convolve(reward_left + reward_middle, window_kernel, 'same'))

    half_window = fit_window / 2
    slopes, intercepts, fitted_rows = [], [], []
    for center_trial in np.arange(np.round(half_window),
                                  len(df_behaviortrial), fit_step):
        start = int(np.round(center_trial - half_window))
        stop = int(np.round(center_trial + half_window))
        reward_rates_now = reward_rate_right[start:stop]
        choice_rates_now = choice_rate_right[start:stop]
        # log2(0) is undefined, so windows' zero-rate trials are excluded.
        todel = (reward_rates_now == 0) | (choice_rates_now == 0)
        reward_rates_now = reward_rates_now[~todel]
        choice_rates_now = choice_rates_now[~todel]
        try:
            slope_now, intercept_now = np.polyfit(np.log2(reward_rates_now),
                                                  np.log2(choice_rates_now), 1)
        except (TypeError, ValueError, np.linalg.LinAlgError):
            # No points left after dropping zeros, or the fit failed on
            # non-finite rates: leave this window's row as NaN.
            continue
        slopes.append(slope_now)
        intercepts.append(intercept_now)
        # Integer row labels (np.arange yields floats here).
        fitted_rows.append(int(center_trial))

    df_behaviortrial['local_efficiency'] = local_efficiency
    df_behaviortrial['local_matching_slope'] = np.nan
    df_behaviortrial.loc[fitted_rows, 'local_matching_slope'] = slopes
    df_behaviortrial['local_matching_bias'] = np.nan
    df_behaviortrial.loc[fitted_rows, 'local_matching_bias'] = intercepts
コード例 #7
0
def populatebehavior_core(IDs = None):
    if IDs:
        print('subject started:')
        print(IDs.keys())
        print(IDs.values())
        
    rigpath_1 = 'E:/Projects/Ablation/datajoint/Behavior'
    
    #df_surgery = pd.read_csv(dj.config['locations.metadata']+'Surgery.csv')
    if IDs == None:
        IDs = {k: v for k, v in zip(*lab.WaterRestriction().fetch('water_restriction_number', 'subject_id'))}   

    for subject_now,subject_id_now in zip(IDs.keys(),IDs.values()): # iterating over subjects
        print('subject: ',subject_now)
    # =============================================================================
    #         if drop_last_session_for_mice_in_training:
    #             delete_last_session_before_upload = True
    #         else:
    #             delete_last_session_before_upload = False
    #         #df_wr = online_notebook.fetch_water_restriction_metadata(subject_now)
    # =============================================================================
        try:
            df_wr = pd.read_csv(dj.config['locations.metadata_behavior']+subject_now+'.csv')
        except:
            print(subject_now + ' has no metadata available')
            df_wr = pd.DataFrame()
        for df_wr_row in df_wr.iterrows():
            date_now = df_wr_row[1].Date.replace('-','')
            print('subject: ',subject_now,'  date: ',date_now)
            session_date = datetime(int(date_now[0:4]),int(date_now[4:6]),int(date_now[6:8]))
            if len(experiment.Session() & 'subject_id = "'+str(subject_id_now)+'"' & 'session_date > "'+str(session_date)+'"') != 0: # if it is not the last
                print('session already imported, skipping: ' + str(session_date))
                dotheupload = False
            elif len(experiment.Session() & 'subject_id = "'+str(subject_id_now)+'"' & 'session_date = "'+str(session_date)+'"') != 0: # if it is the last
                dotheupload = False
            else: # reuploading new session that is not present on the server
                dotheupload = True
                
            # if dotheupload is True, meaning that there are new mat file hasn't been uploaded
            # => needs to find which mat file hasn't been uploaded
            
            if dotheupload:
                found = set()
                rigpath_2 = subject_now
                rigpath_3 = rigpath_1 + '/' + rigpath_2
                rigpath = pathlib.Path(rigpath_3)
                
                def buildrec(rigpath, root, f):
                    try:
                        fullpath = pathlib.Path(root, f)
                        subpath = fullpath.relative_to(rigpath)
                        fsplit = subpath.stem.split('_')
                        h2o = fsplit[0]
                        ymd = fsplit[-2:-1][0]
                        animal = IDs[h2o]
                        if ymd == date_now:
                            return {
                                    'subject_id': animal,
                                    'session_date': date(int(ymd[0:4]), int(ymd[4:6]), int(ymd[6:8])),
                                    'rig_data_path': rigpath.as_posix(),
                                    'subpath': subpath.as_posix(),
                                    }
                    except:
                        pass
                for root, dirs, files in os.walk(rigpath):
                    for f in files:
                        r = buildrec(rigpath, root, f)
                        if r:
                            found.add(r['subpath'])
                            file = r
                
                # now start insert data
            
                path = pathlib.Path(file['rig_data_path'], file['subpath'])
                mat = spio.loadmat(path, squeeze_me=True)
                SessionData = mat['SessionData'].flatten()
                            
                # session record key
                skey = {}
                skey['subject_id'] = file['subject_id']
                skey['session_date'] = file['session_date']
                skey['username'] = '******'
                #skey['rig'] = key['rig']
            
                trial = namedtuple(  # simple structure to track per-trial vars
                        'trial', ('ttype', 'settings', 'state_times',
                                  'state_names', 'state_data', 'event_data',
                                  'event_times', 'trial_start'))
            
                # parse session datetime
                session_datetime_str = str('').join((str(SessionData['Info'][0]['SessionDate']),' ', str(SessionData['Info'][0]['SessionStartTime_UTC'])))
                session_datetime = datetime.strptime(session_datetime_str, '%d-%b-%Y %H:%M:%S')
            
                AllTrialTypes = SessionData['TrialTypes'][0]
                AllTrialSettings = SessionData['TrialSettings'][0]
                AllTrialStarts = SessionData['TrialStartTimestamp'][0]
                AllTrialStarts = AllTrialStarts - AllTrialStarts[0]
            
                RawData = SessionData['RawData'][0].flatten()
                AllStateNames = RawData['OriginalStateNamesByNumber'][0]
                AllStateData = RawData['OriginalStateData'][0]
                AllEventData = RawData['OriginalEventData'][0]
                AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
                AllEventTimestamps = RawData['OriginalEventTimestamps'][0]
            
                trials = list(zip(AllTrialTypes, AllTrialSettings,
                                  AllStateTimestamps, AllStateNames, AllStateData,
                                  AllEventData, AllEventTimestamps, AllTrialStarts))
                
                if not trials:
                    log.warning('skipping date {d}, no valid files'.format(d=date))
                    return    
                #
                # Trial data seems valid; synthesize session id & add session record
                # XXX: note - later breaks can result in Sessions without valid trials
                #
            
                assert skey['session_date'] == session_datetime.date()
                
                skey['session_date'] = session_datetime.date()
                #skey['session_time'] = session_datetime.time()
            
                if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"' & 'session_date = "'+str(file['session_date'])+'"') == 0:
                    if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"') == 0:
                        skey['session'] = 1
                    else:
                        skey['session'] = len((experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"').fetch()['session']) + 1
            
                #
                # Actually load the per-trial data
                #
                log.info('BehaviorIngest.make(): trial parsing phase')

                # lists of various records for batch-insert
                rows = {k: list() for k in ('trial', 'behavior_trial', 'trial_note',
                                        'trial_event', 'corrected_trial_event',
                                        'action_event')} #, 'photostim',
                                    #'photostim_location', 'photostim_trial',
                                    #'photostim_trial_event')}

                i = 0  # trial numbering starts at 1
                for t in trials:
                    t = trial(*t)  # convert list of items to a 'trial' structure
                    i += 1  # increment trial counter

                    log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

                    states = {k: (v+1) for v, k in enumerate(t.state_names)}
                    required_states = ('PreSamplePeriod', 'SamplePeriod',
                                       'DelayPeriod', 'ResponseCue', 'StopLicking',
                                       'TrialEnd')
                
                    missing = list(k for k in required_states if k not in states)
                    if len(missing) and missing =='PreSamplePeriod':
                        log.warning('skipping trial {i}; missing {m}'.format(i=i, m=missing))
                        continue

                    gui = t.settings['GUI'].flatten()
                    if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"' & 'session_date = "'+str(file['session_date'])+'"') == 0:
                        if len(experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"') == 0:
                            skey['session'] = 1
                        else:
                            skey['session'] = len((experiment.Session() & 'subject_id = "'+str(file['subject_id'])+'"').fetch()['session']) + 1
                
                    #
                    # Top-level 'Trial' record
                    #
                    protocol_type = gui['ProtocolType'][0]
                    tkey = dict(skey)
                    has_presample = 1
                    try:
                        startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]
                        has_presample = 1
                    except:
                        startindex = np.where(t.state_data == states['SamplePeriod'])[0]
                        has_presample = 0
                
                    # should be only end of 1st StopLicking;
                    # rest of data is irrelevant w/r/t separately ingested ephys
                    endindex = np.where(t.state_data == states['StopLicking'])[0]
                    log.debug('states\n' + str(states))
                    log.debug('state_data\n' + str(t.state_data))
                    log.debug('startindex\n' + str(startindex))
                    log.debug('endindex\n' + str(endindex))
                
                    if not(len(startindex) and len(endindex)):
                        log.warning('skipping {}: start/end mismatch: {}/{}'.format(i, str(startindex), str(endindex)))
                        continue
                    
                    try:
                        tkey['trial'] = i
                        tkey['trial_uid'] = i
                        tkey['trial_start_time'] = t.trial_start
                        tkey['trial_stop_time'] = t.trial_start + t.state_times[endindex][0]
                    except IndexError:
                        log.warning('skipping {}: IndexError: {}/{} -> {}'.format(i, str(startindex), str(endindex), str(t.state_times)))
                        continue
                    
                    log.debug('tkey' + str(tkey))
                    rows['trial'].append(tkey)
                
                    #
                    # Specific BehaviorTrial information for this trial
                    #                              
                    
                    bkey = dict(tkey)
                    bkey['task'] = 'audio delay'  # hard-coded here
                    bkey['task_protocol'] = 1     # hard-coded here
                
                    # determine trial instruction
                    trial_instruction = 'left'    # hard-coded here

                    if gui['Reversal'][0] == 1:
                        if t.ttype == 1:
                            trial_instruction = 'left'
                        elif t.ttype == 0:
                            trial_instruction = 'right'
                        elif t.ttype == 2:
                            trial_instruction = 'catch_right_autowater'
                        elif t.ttype == 3:
                            trial_instruction = 'catch_left_autowater'
                        elif t.ttype == 4:
                            trial_instruction = 'catch_right_noDelay'
                        elif t.ttype == 5:
                            trial_instruction = 'catch_left_noDelay'    
                    elif gui['Reversal'][0] == 2:
                        if t.ttype == 1:
                            trial_instruction = 'right'
                        elif t.ttype == 0:
                            trial_instruction = 'left'
                        elif t.ttype == 2:
                            trial_instruction = 'catch_left_autowater'
                        elif t.ttype == 3:
                            trial_instruction = 'catch_right_autowater'
                        elif t.ttype == 4:
                            trial_instruction = 'catch_left_noDelay'
                        elif t.ttype == 5:
                            trial_instruction = 'catch_right_noDelay'
                
                    bkey['trial_instruction'] = trial_instruction
                    # determine early lick
                    early_lick = 'no early'
                    
                    if (protocol_type >= 5 and 'EarlyLickDelay' in states and np.any(t.state_data == states['EarlyLickDelay'])):
                        early_lick = 'early'
                    if (protocol_type >= 5 and ('EarlyLickSample' in states and np.any(t.state_data == states['EarlyLickSample']))):
                        early_lick = 'early'
                        
                    bkey['early_lick'] = early_lick
                
                    # determine outcome
                    outcome = 'ignore'
                    if ('Reward' in states and np.any(t.state_data == states['Reward'])):
                        outcome = 'hit'
                    elif ('TimeOut' in states and np.any(t.state_data == states['TimeOut'])):
                        outcome = 'miss'
                    elif ('NoResponse' in states and np.any(t.state_data == states['NoResponse'])):
                        outcome = 'ignore'    
                    bkey['outcome'] = outcome
                    rows['behavior_trial'].append(bkey)
                    
                    #
                    # Add 'protocol' note
                    #
                    nkey = dict(tkey)
                    nkey['trial_note_type'] = 'protocol #'
                    nkey['trial_note'] = str(protocol_type)
                    rows['trial_note'].append(nkey)

                    #
                    # Add 'autolearn' note
                    #
                    nkey = dict(tkey)
                    nkey['trial_note_type'] = 'autolearn'
                    nkey['trial_note'] = str(gui['Autolearn'][0])
                    rows['trial_note'].append(nkey)
                    
                    #
                    # Add 'bitcode' note
                    #
                    if 'randomID' in gui.dtype.names:
                        nkey = dict(tkey)
                        nkey['trial_note_type'] = 'bitcode'
                        nkey['trial_note'] = str(gui['randomID'][0])
                        rows['trial_note'].append(nkey)
               
                
                    #
                    # Add presample event
                    #
                    sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]
                    
                    if has_presample == 1:
                        log.debug('BehaviorIngest.make(): presample')
                        ekey = dict(tkey)                    
    
                        ekey['trial_event_id'] = len(rows['trial_event'])
                        ekey['trial_event_type'] = 'presample'
                        ekey['trial_event_time'] = t.state_times[startindex][0]
                        ekey['duration'] = (t.state_times[sampleindex[0]]- t.state_times[startindex])[0]
    
                        if math.isnan(ekey['duration']):
                            log.debug('BehaviorIngest.make(): fixing presample duration')
                            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial
    
                        rows['trial_event'].append(ekey)
                
                    #
                    # Add other 'sample' events
                    #
    
                    log.debug('BehaviorIngest.make(): sample events')
    
                    last_dur = None
    
                    for s in sampleindex:  # in protocol > 6 ~-> n>1
                        # todo: batch events
                        ekey = dict(tkey)
                        ekey['trial_event_id'] = len(rows['trial_event'])
                        ekey['trial_event_type'] = 'sample'
                        ekey['trial_event_time'] = t.state_times[s]
                        ekey['duration'] = gui['SamplePeriod'][0]
    
                        if math.isnan(ekey['duration']) and last_dur is None:
                            log.warning('... trial {} bad duration, no last_edur'.format(i, last_dur))
                            ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                            rows['corrected_trial_event'].append(ekey)
    
                        elif math.isnan(ekey['duration']) and last_dur is not None:
                            log.warning('... trial {} duration using last_edur {}'.format(i, last_dur))
                            ekey['duration'] = last_dur
                            rows['corrected_trial_event'].append(ekey)
    
                        else:
                            last_dur = ekey['duration']  # only track 'good' values.
    
                        rows['trial_event'].append(ekey)
                
                    #
                    # Add 'delay' events
                    #
    
                    log.debug('BehaviorIngest.make(): delay events')
    
                    last_dur = None
                    delayindex = np.where(t.state_data == states['DelayPeriod'])[0]
    
                    for d in delayindex:  # protocol > 6 ~-> n>1
                        ekey = dict(tkey)
                        ekey['trial_event_id'] = len(rows['trial_event'])
                        ekey['trial_event_type'] = 'delay'
                        ekey['trial_event_time'] = t.state_times[d]
                        ekey['duration'] = gui['DelayPeriod'][0]
    
                        if math.isnan(ekey['duration']) and last_dur is None:
                            log.warning('... {} bad duration, no last_edur'.format(i, last_dur))
                            ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                            rows['corrected_trial_event'].append(ekey)
    
                        elif math.isnan(ekey['duration']) and last_dur is not None:
                            log.warning('... {} duration using last_edur {}'.format(i, last_dur))
                            ekey['duration'] = last_dur
                            rows['corrected_trial_event'].append(ekey)
    
                        else:
                            last_dur = ekey['duration']  # only track 'good' values.
    
                        log.debug('delay event duration: {}'.format(ekey['duration']))
                        rows['trial_event'].append(ekey)
                         
                    #
                    # Add 'go' event
                    #
                    log.debug('BehaviorIngest.make(): go')
    
                    ekey = dict(tkey)
                    responseindex = np.where(t.state_data == states['ResponseCue'])[0]
    
                    ekey['trial_event_id'] = len(rows['trial_event'])
                    ekey['trial_event_type'] = 'go'
                    ekey['trial_event_time'] = t.state_times[responseindex][0]
                    ekey['duration'] = gui['AnswerPeriod'][0]
    
                    if math.isnan(ekey['duration']):
                        log.debug('BehaviorIngest.make(): fixing go duration')
                        ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
                        rows['corrected_trial_event'].append(ekey)
    
                    rows['trial_event'].append(ekey)
                
                    #
                    # Add 'trialEnd' events
                    #

                    log.debug('BehaviorIngest.make(): trialend events')

                    last_dur = None
                    trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

                    ekey = dict(tkey)
                    ekey['trial_event_id'] = len(rows['trial_event'])
                    ekey['trial_event_type'] = 'trialend'
                    ekey['trial_event_time'] = t.state_times[trialendindex][0]
                    ekey['duration'] = 0.0
    
                    rows['trial_event'].append(ekey)
                    
                    #
                    # Add lick events
                    #
                       
                    lickleft = np.where(t.event_data == 69)[0]
                    log.debug('... lickleft: {r}'.format(r=str(lickleft)))
    
                    action_event_count = len(rows['action_event'])
                    if len(lickleft):
                        [rows['action_event'].append(
                                dict(tkey, action_event_id=action_event_count+idx,
                                     action_event_type='left lick',
                                     action_event_time=t.event_times[l]))
                        for idx, l in enumerate(lickleft)]
    
                    lickright = np.where(t.event_data == 71)[0]
                    log.debug('... lickright: {r}'.format(r=str(lickright)))
    
                    action_event_count = len(rows['action_event'])
                    if len(lickright):
                        [rows['action_event'].append(
                                dict(tkey, action_event_id=action_event_count+idx,
                                     action_event_type='right lick',
                                     action_event_time=t.event_times[r]))
                        for idx, r in enumerate(lickright)]
                    
                    # end of trial loop..    
        
                    # Session Insertion                     
                    log.info('BehaviorIngest.make(): adding session record')
                    skey['session_date'] = df_wr_row[1].Date 
                    skey['rig'] = 'Old Recording rig'
                    skey['username']  = '******'
                    experiment.Session().insert1(skey,skip_duplicates=True)

                # Behavior Insertion                

                log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
                experiment.SessionTrial().insert(
                        rows['trial'], ignore_extra_fields=True, allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
                experiment.BehaviorTrial().insert(
                        rows['behavior_trial'], ignore_extra_fields=True,
                        allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.TrialNote')
                experiment.TrialNote().insert(
                        rows['trial_note'], ignore_extra_fields=True,
                        allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
                experiment.TrialEvent().insert(
                        rows['trial_event'], ignore_extra_fields=True,
                        allow_direct_insert=True, skip_duplicates=True)
        
#        log.info('BehaviorIngest.make(): ... CorrectedTrialEvents')
#        BehaviorIngest().CorrectedTrialEvents().insert(
#            rows['corrected_trial_event'], ignore_extra_fields=True,
#            allow_direct_insert=True)

                log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
                experiment.ActionEvent().insert(
                        rows['action_event'], ignore_extra_fields=True,
                        allow_direct_insert=True)
                            
#%% for ingest tracking                
                if IDs:
                    print('subject started:')
                    print(IDs.keys())
                    print(IDs.values())
                    
                rigpath_tracking_1 = 'E:/Projects/Ablation/datajoint/video/'
                rigpath_tracking_2 = subject_now
                VideoDate1 = str(df_wr_row[1].VideoDate)
                if len(VideoDate1)==5:
                    VideoDate = '0'+ VideoDate1
                elif len(VideoDate1)==7:
                    VideoDate = '0'+ VideoDate1
                rigpath_tracking_3 = rigpath_tracking_1 + rigpath_tracking_2 + '/' + rigpath_tracking_2 + '_'+ VideoDate + '_front'
                
                rigpath_tracking = pathlib.Path(rigpath_tracking_3)
                
                #df_surgery = pd.read_csv(dj.config['locations.metadata']+'Surgery.csv')
                if IDs == None:
                    IDs = {k: v for k, v in zip(*lab.WaterRestriction().fetch('water_restriction_number', 'subject_id'))}   
                
                h2o = subject_now
                session = df_wr_row[1].Date
                trials = (experiment.SessionTrial() & session).fetch('trial')
                
                log.info('got session: {} ({} trials)'.format(session, len(trials)))
                
                #sdate = session['session_date']
                #sdate_sml = date_now #"{}{:02d}{:02d}".format(sdate.year, sdate.month, sdate.day)

                paths = rigpath_tracking
                devices = tracking.TrackingDevice().fetch(as_dict=True)
                
                # paths like: <root>/<h2o>/YYYY-MM-DD/tracking
                tracking_files = []
                for d in (d for d in devices):
                    tdev = d['tracking_device']
                    tpos = d['tracking_position']
                    tdat = paths
                    log.info('checking {} for tracking data'.format(tdat))               
                    
                    
#                    if not tpath.exists():
#                        log.warning('tracking path {} n/a - skipping'.format(tpath))
#                        continue
#                    
#                    camtrial = '{}_{}_{}.txt'.format(h2o, sdate_sml, tpos)
#                    campath = tpath / camtrial
#                    
#                    log.info('trying camera position trial map: {}'.format(campath))
#                    
#                    if not campath.exists():
#                        log.info('skipping {} - does not exist'.format(campath))
#                        continue
#                    
#                    tmap = load_campath(campath)  # file:trial
#                    n_tmap = len(tmap)
#                    log.info('loading tracking data for {} trials'.format(n_tmap))

                    i = 0                    
                    VideoTrialNum = df_wr_row[1].VideoTrialNum
                    
                    #tpath = pathlib.Path(tdat, h2o, VideoDate, 'tracking')
                    ppp = list(range(0,VideoTrialNum))
                    for tt in reversed(range(VideoTrialNum)):  # load tracking for trial
                        
                        i += 1
#                        if i % 50 == 0:
#                            log.info('item {}/{}, trial #{} ({:.2f}%)'
#                                     .format(i, n_tmap, t, (i/n_tmap)*100))
#                        else:
#                            log.debug('item {}/{}, trial #{} ({:.2f}%)'
#                                      .format(i, n_tmap, t, (i/n_tmap)*100))
        
                        # ex: dl59_side_1-0000.csv / h2o_position_tn-0000.csv
                        tfile = '{}_{}_{}_{}-*.csv'.format(h2o, VideoDate ,tpos, tt)
                        tfull = list(tdat.glob(tfile))
                        if not tfull or len(tfull) > 1:
                            log.info('file mismatch: file: {} trial: ({})'.format(
                                tt, tfull))
                            continue
        
                        tfull = tfull[-1]
                        trk = load_tracking(tfull)
                        
        
                        recs = {}
                        
                        #key_source = experiment.Session - tracking.Tracking                        
                        rec_base = dict(trial=ppp[tt], tracking_device=tdev)
                        #print(rec_base)
                        for k in trk:
                            if k == 'samples':
                                recs['tracking'] = {
                                    'subject_id' : skey['subject_id'], 
                                    'session' : skey['session'],
                                    **rec_base,
                                    'tracking_samples': len(trk['samples']['ts']),
                                }
                                
                            else:
                                rec = dict(rec_base)
        
                                for attr in trk[k]:
                                    rec_key = '{}_{}'.format(k, attr)
                                    rec[rec_key] = np.array(trk[k][attr])
        
                                recs[k] = rec
                        
                        
                        tracking.Tracking.insert1(
                            recs['tracking'], allow_direct_insert=True)
                        
                        #if len(recs['nose']) > 3000:
                            #continue
                            
                        recs['nose'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['nose'],
                                }
                        
                        #print(recs['nose']['nose_x'])
                        if 'nose' in recs:
                            tracking.Tracking.NoseTracking.insert1(
                                recs['nose'], allow_direct_insert=True)
                            
                        recs['tongue_mid'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['tongue_mid'],
                                }
        
                        if 'tongue_mid' in recs:
                            tracking.Tracking.TongueTracking.insert1(
                                recs['tongue_mid'], allow_direct_insert=True)
                            
                        recs['jaw'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['jaw'],
                                }
        
                        if 'jaw' in recs:
                            tracking.Tracking.JawTracking.insert1(
                                recs['jaw'], allow_direct_insert=True)
                        
                        recs['tongue_left'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['tongue_left'],
                                }
        
                        if 'tongue_left' in recs:
                            tracking.Tracking.LeftTongueTracking.insert1(
                                recs['tongue_left'], allow_direct_insert=True)
                            
                        recs['tongue_right'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['tongue_right'],
                                }
        
                        if 'tongue_right' in recs:
                            tracking.Tracking.RightTongueTracking.insert1(
                                recs['tongue_right'], allow_direct_insert=True)
#                            fmap = {'paw_left_x': 'left_paw_x',  # remap field names
#                                    'paw_left_y': 'left_paw_y',
#                                    'paw_left_likelihood': 'left_paw_likelihood'}
        
#                            tracking.Tracking.LeftPawTracking.insert1({
#                                **{k: v for k, v in recs['paw_left'].items()
#                                   if k not in fmap},
#                                **{fmap[k]: v for k, v in recs['paw_left'].items()
#                                   if k in fmap}}, allow_direct_insert=True)
                        
                        recs['right_lickport'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['right_lickport'],
                                }
                        
                        if 'right_lickport' in recs:
                            tracking.Tracking.RightLickPortTracking.insert1(
                                recs['right_lickport'], allow_direct_insert=True)
#                            fmap = {'paw_right_x': 'right_paw_x',  # remap field names
#                                    'paw_right_y': 'right_paw_y',
#                                    'paw_right_likelihood': 'right_paw_likelihood'}
#        
#                            tracking.Tracking.RightPawTracking.insert1({
#                                **{k: v for k, v in recs['paw_right'].items()
#                                   if k not in fmap},
#                                **{fmap[k]: v for k, v in recs['paw_right'].items()
#                                   if k in fmap}}, allow_direct_insert=True)
                        
                        recs['left_lickport'] = {
                                'subject_id' : skey['subject_id'], 
                                'session' : skey['session'],
                                **recs['left_lickport'],
                                }
                        
                        if 'left_lickport' in recs:
                            tracking.Tracking.LeftLickPortTracking.insert1(
                                recs['left_lickport'], allow_direct_insert=True)
        
#                        tracking_files.append({**key, 'trial': tmap[t], 'tracking_device': tdev,
#                             'tracking_file': str(tfull.relative_to(tdat))})
#        
#                    log.info('... completed {}/{} items.'.format(i, n_tmap))
#        
#                self.insert1(key)
#                self.TrackingFile.insert(tracking_files)
#                   
                            
                        tracking.VideoFiducialsTrial.populate()
                        bottom_tongue.Camera_pixels.populate()
                        print('start!')               
                        bottom_tongue.VideoTongueTrial.populate()
                        sessiontrialdata={              'subject_id':skey['subject_id'],
                                                        'session':skey['session'],
                                                        'trial': tt
                                                        }
                        if len(bottom_tongue.VideoTongueTrial* experiment.Session & experiment.BehaviorTrial  & 'session_date = "'+str(file['session_date'])+'"' &{'trial':tt})==0:
                            print('trial couldn''t be exported, deleting trial')
                            print(tt)
                            dj.config['safemode'] = False
                            (experiment.SessionTrial()&sessiontrialdata).delete()
                            dj.config['safemode'] = True  
                        
                        
                log.info('... done.')
コード例 #8
0
    def make(self, key):
        log.info('BehaviorIngest.make(): key: {key}'.format(key=key))

        subject_id = key['subject_id']
        h2o = (lab.WaterRestriction() & {
            'subject_id': subject_id
        }).fetch1('water_restriction_number')

        date = key['session_date']
        datestr = date.strftime('%Y%m%d')
        log.info('h2o: {h2o}, date: {d}'.format(h2o=h2o, d=datestr))

        # session record key
        skey = {}
        skey['subject_id'] = subject_id
        skey['session_date'] = date
        skey['username'] = self.get_session_user()

        # File paths conform to the pattern:
        # dl7/TW_autoTrain/Session Data/dl7_TW_autoTrain_20180104_132813.mat
        # which is, more generally:
        # {h2o}/{training_protocol}/Session Data/{h2o}_{training protocol}_{YYYYMMDD}_{HHMMSS}.mat
        root = pathlib.Path(key['rig_data_path'],
                            os.path.dirname(key['subpath']))
        path = root / '{h2o}_*_{d}*.mat'.format(h2o=h2o, d=datestr)

        log.info('rigpath {p}'.format(p=path))

        matches = sorted(
            root.glob('{h2o}_*_{d}*.mat'.format(h2o=h2o, d=datestr)))
        if matches:
            log.info('found files: {}, this is the rig'.format(matches))
            skey['rig'] = key['rig']
        else:
            log.info('no file matches found in {p}'.format(p=path))

        if not len(matches):
            log.warning('no file matches found for {h2o} / {d}'.format(
                h2o=h2o, d=datestr))
            return

        #
        # Find files & Check for split files
        # XXX: not checking rig.. 2+ sessions on 2+ rigs possible for date?
        #

        if len(matches) > 1:
            log.warning(
                'split session case detected for {h2o} on {date}'.format(
                    h2o=h2o, date=date))

        # session:date relationship is 1:1; skip if we have a session
        if experiment.Session() & skey:
            log.warning("Warning! session exists for {h2o} on {d}".format(
                h2o=h2o, d=date))
            return

        #
        # Prepare PhotoStim
        #
        photosti_duration = 0.5  # (s) Hard-coded here
        photostims = {
            4: {
                'photo_stim': 4,
                'photostim_device': 'OBIS470',
                'brain_location_name': 'left_alm',
                'duration': photosti_duration
            },
            5: {
                'photo_stim': 5,
                'photostim_device': 'OBIS470',
                'brain_location_name': 'right_alm',
                'duration': photosti_duration
            },
            6: {
                'photo_stim': 6,
                'photostim_device': 'OBIS470',
                'brain_location_name': 'both_alm',
                'duration': photosti_duration
            }
        }

        #
        # Extract trial data from file(s) & prepare trial loop
        #

        trials = zip()

        trial = namedtuple(  # simple structure to track per-trial vars
            'trial',
            ('ttype', 'stim', 'settings', 'state_times', 'state_names',
             'state_data', 'event_data', 'event_times'))

        for f in matches:

            if os.stat(f).st_size / 1024 < 1000:
                log.info('skipping file {f} - too small'.format(f=f))
                continue

            log.debug('loading file {}'.format(f))

            mat = spio.loadmat(f, squeeze_me=True)
            SessionData = mat['SessionData'].flatten()

            AllTrialTypes = SessionData['TrialTypes'][0]
            AllTrialSettings = SessionData['TrialSettings'][0]

            RawData = SessionData['RawData'][0].flatten()
            AllStateNames = RawData['OriginalStateNamesByNumber'][0]
            AllStateData = RawData['OriginalStateData'][0]
            AllEventData = RawData['OriginalEventData'][0]
            AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
            AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

            # verify trial-related data arrays are all same length
            assert (all(
                (x.shape[0] == AllStateTimestamps.shape[0]
                 for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                           AllStateData, AllEventData, AllEventTimestamps))))

            if 'StimTrials' in SessionData.dtype.fields:
                log.debug('StimTrials detected in session - will include')
                AllStimTrials = SessionData['StimTrials'][0]
                assert (AllStimTrials.shape[0] == AllStateTimestamps.shape[0])
            else:
                log.debug('StimTrials not detected in session - will skip')
                AllStimTrials = np.array([
                    None for i in enumerate(range(AllStateTimestamps.shape[0]))
                ])

            z = zip(AllTrialTypes, AllStimTrials, AllTrialSettings,
                    AllStateTimestamps, AllStateNames, AllStateData,
                    AllEventData, AllEventTimestamps)

            trials = chain(trials, z)  # concatenate the files

        trials = list(trials)

        # all files were internally invalid or size < 100k
        if not trials:
            log.warning('skipping date {d}, no valid files'.format(d=date))
            return

        #
        # Trial data seems valid; synthesize session id & add session record
        # XXX: note - later breaks can result in Sessions without valid trials
        #

        log.debug('synthesizing session ID')
        session = (dj.U().aggr(experiment.Session() & {
            'subject_id': subject_id
        },
                               n='max(session)').fetch1('n') or 0) + 1
        log.info('generated session id: {session}'.format(session=session))
        skey['session'] = session
        key = dict(key, **skey)

        #
        # Actually load the per-trial data
        #
        log.info('BehaviorIngest.make(): trial parsing phase')

        # lists of various records for batch-insert
        rows = {
            k: list()
            for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                      'corrected_trial_event', 'action_event', 'photostim',
                      'photostim_location', 'photostim_trial',
                      'photostim_trial_event')
        }

        i = -1
        for t in trials:

            #
            # Misc
            #

            t = trial(*t)  # convert list of items to a 'trial' structure
            i += 1  # increment trial counter

            log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

            # covert state data names into a lookup dictionary
            #
            # names (seem to be? are?):
            #
            # Trigtrialstart
            # PreSamplePeriod
            # SamplePeriod
            # DelayPeriod
            # EarlyLickDelay
            # EarlyLickSample
            # ResponseCue
            # GiveLeftDrop
            # GiveRightDrop
            # GiveLeftDropShort
            # GiveRightDropShort
            # AnswerPeriod
            # Reward
            # RewardConsumption
            # NoResponse
            # TimeOut
            # StopLicking
            # StopLickingReturn
            # TrialEnd

            states = {k: (v + 1) for v, k in enumerate(t.state_names)}
            required_states = ('PreSamplePeriod', 'SamplePeriod',
                               'DelayPeriod', 'ResponseCue', 'StopLicking',
                               'TrialEnd')

            missing = list(k for k in required_states if k not in states)

            if len(missing):
                log.warning('skipping trial {i}; missing {m}'.format(
                    i=i, m=missing))
                continue

            gui = t.settings['GUI'].flatten()

            # ProtocolType - only ingest protocol >= 3
            #
            # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
            # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
            #

            if 'ProtocolType' not in gui.dtype.names:
                log.warning(
                    'skipping trial {i}; protocol undefined'.format(i=i))
                continue

            protocol_type = gui['ProtocolType'][0]
            if gui['ProtocolType'][0] < 3:
                log.warning('skipping trial {i}; protocol {n} < 3'.format(
                    i=i, n=gui['ProtocolType'][0]))
                continue

            #
            # Top-level 'Trial' record
            #

            tkey = dict(skey)
            startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

            # should be only end of 1st StopLicking;
            # rest of data is irrelevant w/r/t separately ingested ephys
            endindex = np.where(t.state_data == states['StopLicking'])[0]

            log.debug('states\n' + str(states))
            log.debug('state_data\n' + str(t.state_data))
            log.debug('startindex\n' + str(startindex))
            log.debug('endindex\n' + str(endindex))

            if not (len(startindex) and len(endindex)):
                log.warning(
                    'skipping trial {i}: start/end index error: {s}/{e}'.
                    format(i=i, s=str(startindex), e=str(endindex)))
                continue

            try:
                tkey['trial'] = i
                tkey[
                    'trial_uid'] = i  # Arseny has unique id to identify some trials
                tkey['start_time'] = t.state_times[startindex][0]
                tkey['stop_time'] = t.state_times[endindex][0]
            except IndexError:
                log.warning(
                    'skipping trial {i}: error indexing {s}/{e} into {t}'.
                    format(i=i,
                           s=str(startindex),
                           e=str(endindex),
                           t=str(t.state_times)))
                continue

            log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
            log.debug('tkey' + str(tkey))
            rows['trial'].append(tkey)

            #
            # Specific BehaviorTrial information for this trial
            #

            bkey = dict(tkey)
            bkey['task'] = 'audio delay'  # hard-coded here
            bkey['task_protocol'] = 1  # hard-coded here

            # determine trial instruction
            trial_instruction = 'left'  # hard-coded here

            if gui['Reversal'][0] == 1:
                if t.ttype == 1:
                    trial_instruction = 'left'
                elif t.ttype == 0:
                    trial_instruction = 'right'
            elif gui['Reversal'][0] == 2:
                if t.ttype == 1:
                    trial_instruction = 'right'
                elif t.ttype == 0:
                    trial_instruction = 'left'

            bkey['trial_instruction'] = trial_instruction

            # determine early lick
            early_lick = 'no early'

            if (protocol_type >= 5 and 'EarlyLickDelay' in states
                    and np.any(t.state_data == states['EarlyLickDelay'])):
                early_lick = 'early'
            if (protocol_type > 5 and
                ('EarlyLickSample' in states
                 and np.any(t.state_data == states['EarlyLickSample']))):
                early_lick = 'early'

            bkey['early_lick'] = early_lick

            # determine outcome
            outcome = 'ignore'

            if ('Reward' in states
                    and np.any(t.state_data == states['Reward'])):
                outcome = 'hit'
            elif ('TimeOut' in states
                  and np.any(t.state_data == states['TimeOut'])):
                outcome = 'miss'
            elif ('NoResponse' in states
                  and np.any(t.state_data == states['NoResponse'])):
                outcome = 'ignore'

            bkey['outcome'] = outcome
            rows['behavior_trial'].append(bkey)

            #
            # Add 'protocol' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'protocol #'
            nkey['trial_note'] = str(protocol_type)
            rows['trial_note'].append(nkey)

            #
            # Add 'autolearn' note
            #
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'autolearn'
            nkey['trial_note'] = str(gui['Autolearn'][0])
            rows['trial_note'].append(nkey)

            #
            # Add 'bitcode' note
            #
            if 'randomID' in gui.dtype.names:
                nkey = dict(tkey)
                nkey['trial_note_type'] = 'bitcode'
                nkey['trial_note'] = str(gui['randomID'][0])
                rows['trial_note'].append(nkey)

            #
            # Add presample event
            #
            log.debug('BehaviorIngest.make(): presample')

            ekey = dict(tkey)
            sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'presample'
            ekey['trial_event_time'] = t.state_times[startindex][0]
            ekey['duration'] = (t.state_times[sampleindex[0]] -
                                t.state_times[startindex])[0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing presample duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial

            rows['trial_event'].append(ekey)

            #
            # Add other 'sample' events
            #

            log.debug('BehaviorIngest.make(): sample events')

            last_dur = None

            for s in sampleindex:  # in protocol > 6 ~-> n>1
                # todo: batch events
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'sample'
                ekey['trial_event_time'] = t.state_times[s]
                ekey['duration'] = gui['SamplePeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning(
                        '... trial {} bad duration, no last_edur'.format(
                            i, last_dur))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning(
                        '... trial {} duration using last_edur {}'.format(
                            i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                rows['trial_event'].append(ekey)

            #
            # Add 'delay' events
            #

            log.debug('BehaviorIngest.make(): delay events')

            last_dur = None
            delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

            for d in delayindex:  # protocol > 6 ~-> n>1
                ekey = dict(tkey)
                ekey['trial_event_id'] = len(rows['trial_event'])
                ekey['trial_event_type'] = 'delay'
                ekey['trial_event_time'] = t.state_times[d]
                ekey['duration'] = gui['DelayPeriod'][0]

                if math.isnan(ekey['duration']) and last_dur is None:
                    log.warning('... {} bad duration, no last_edur'.format(
                        i, last_dur))
                    ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                    rows['corrected_trial_event'].append(ekey)

                elif math.isnan(ekey['duration']) and last_dur is not None:
                    log.warning('... {} duration using last_edur {}'.format(
                        i, last_dur))
                    ekey['duration'] = last_dur
                    rows['corrected_trial_event'].append(ekey)

                else:
                    last_dur = ekey['duration']  # only track 'good' values.

                log.debug('delay event duration: {}'.format(ekey['duration']))
                rows['trial_event'].append(ekey)

            #
            # Add 'go' event
            #
            log.debug('BehaviorIngest.make(): go')

            ekey = dict(tkey)
            responseindex = np.where(t.state_data == states['ResponseCue'])[0]

            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'go'
            ekey['trial_event_time'] = t.state_times[responseindex][0]
            ekey['duration'] = gui['AnswerPeriod'][0]

            if math.isnan(ekey['duration']):
                log.debug('BehaviorIngest.make(): fixing go duration')
                ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
                rows['corrected_trial_event'].append(ekey)

            rows['trial_event'].append(ekey)

            #
            # Add 'trialEnd' events
            #

            log.debug('BehaviorIngest.make(): trialend events')

            last_dur = None
            trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'trialend'
            ekey['trial_event_time'] = t.state_times[trialendindex][0]
            ekey['duration'] = 0.0

            rows['trial_event'].append(ekey)

            #
            # Add lick events
            #

            lickleft = np.where(t.event_data == 69)[0]
            log.debug('... lickleft: {r}'.format(r=str(lickleft)))

            action_event_count = len(rows['action_event'])
            if len(lickleft):
                [
                    rows['action_event'].append(
                        dict(tkey,
                             action_event_id=action_event_count + idx,
                             action_event_type='left lick',
                             action_event_time=t.event_times[l]))
                    for idx, l in enumerate(lickleft)
                ]

            lickright = np.where(t.event_data == 71)[0]
            log.debug('... lickright: {r}'.format(r=str(lickright)))

            action_event_count = len(rows['action_event'])
            if len(lickright):
                [
                    rows['action_event'].append(
                        dict(tkey,
                             action_event_id=action_event_count + idx,
                             action_event_type='right lick',
                             action_event_time=t.event_times[r]))
                    for idx, r in enumerate(lickright)
                ]

            # Photostim Events
            #
            # TODO:
            #
            # - base stimulation parameters:
            #
            #   - should be loaded elsewhere - where
            #   - actual ccf locations - cannot be known apriori apparently?
            #   - Photostim.Profile: what is? fix/add
            #
            # - stim data
            #
            #   - how retrieve power from file (didn't see) or should
            #     be statically coded here?
            #   - how encode stim type 6?
            #     - we have hemisphere as boolean or
            #     - but adding an event 4 and event 5 means querying
            #       is less straightforwrard (e.g. sessions with 5 & 6)

            if t.stim:
                log.info('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
                rows['photostim_trial'].append(tkey)
                delay_period_idx = np.where(
                    t.state_data == states['DelayPeriod'])[0][0]
                rows['photostim_trial_event'].append(
                    dict(tkey,
                         **photostims[t.stim],
                         photostim_event_id=len(rows['photostim_trial_event']),
                         photostim_event_time=t.state_times[delay_period_idx],
                         power=5.5))

            # end of trial loop.

        # Session Insertion

        log.info('BehaviorIngest.make(): adding session record')
        experiment.Session().insert1(skey)

        # Behavior Insertion

        log.info('BehaviorIngest.make(): bulk insert phase')

        log.info('BehaviorIngest.make(): saving ingest {d}'.format(d=key))
        self.insert1(key, ignore_extra_fields=True, allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.Session.Trial')
        experiment.SessionTrial().insert(rows['trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.BehaviorTrial')
        experiment.BehaviorTrial().insert(rows['behavior_trial'],
                                          ignore_extra_fields=True,
                                          allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialNote')
        experiment.TrialNote().insert(rows['trial_note'],
                                      ignore_extra_fields=True,
                                      allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.TrialEvent')
        experiment.TrialEvent().insert(rows['trial_event'],
                                       ignore_extra_fields=True,
                                       allow_direct_insert=True,
                                       skip_duplicates=True)

        log.info('BehaviorIngest.make(): ... CorrectedTrialEvents')
        BehaviorIngest().CorrectedTrialEvents().insert(
            rows['corrected_trial_event'],
            ignore_extra_fields=True,
            allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.ActionEvent')
        experiment.ActionEvent().insert(rows['action_event'],
                                        ignore_extra_fields=True,
                                        allow_direct_insert=True)

        BehaviorIngest.BehaviorFile().insert(
            (dict(key, behavior_file=f.name) for f in matches),
            ignore_extra_fields=True,
            allow_direct_insert=True)

        # Photostim Insertion

        photostim_ids = set(
            [r['photo_stim'] for r in rows['photostim_trial_event']])
        if photostim_ids:
            log.info('BehaviorIngest.make(): ... experiment.Photostim')
            experiment.Photostim.insert(
                (dict(skey, **photostims[stim]) for stim in photostim_ids),
                ignore_extra_fields=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrial')
        experiment.PhotostimTrial.insert(rows['photostim_trial'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)

        log.info('BehaviorIngest.make(): ... experiment.PhotostimTrialEvent')
        experiment.PhotostimEvent.insert(rows['photostim_trial_event'],
                                         ignore_extra_fields=True,
                                         allow_direct_insert=True)