Exemple #1
0
def undo_amplitude_scaling():
    """Revert the unit-amplitude scaling applied by fix_0007.

    Multiplies ``unit_amp`` by 1/3.01 for every unit listed in ``FixedAmpUnit``
    (units that underwent fix_0007). Units already recorded with ``fixed=1`` in
    ``UndoFixedAmpUnit`` are skipped. Each reverted unit is recorded in
    ``UndoFixedAmpUnit`` under a new ``FixHistory`` entry, so re-running is a
    no-op. Afterwards the affected ``report.ProbeLevelReport`` entries are
    deleted so they can be remade with the updated ``unit_amp``.
    """
    amp_scale = 1 / 3.01

    units2fix = ephys.Unit & FixedAmpUnit  # only fix those units that underwent fix_0007
    # exclude those that were already un-fixed
    units2fix = units2fix - (UndoFixedAmpUnit & 'fixed=1')

    if not units2fix:
        return

    # safety check, no jrclust results and no npx 1.0
    assert len(units2fix & 'clustering_method LIKE "jrclust%"') == 0
    assert len(units2fix.proj() * ephys.ProbeInsertion
               & 'probe_type LIKE "neuropixels 1.0%"') == 0

    fix_hist_key = {
        'fix_name': pathlib.Path(__file__).name,
        'fix_timestamp': datetime.now()
    }
    FixHistory.insert1(fix_hist_key)

    for unit in tqdm(units2fix.proj('unit_amp').fetch(as_dict=True)):
        amp = unit.pop('unit_amp')
        with dj.conn().transaction:
            (ephys.Unit & unit)._update('unit_amp', amp * amp_scale)
            # Record the undo in UndoFixedAmpUnit - the table used above to
            # exclude already-reverted units - so that a re-run does not divide
            # the amplitude by 3.01 a second time.
            UndoFixedAmpUnit.insert1({
                **fix_hist_key,
                **unit, 'fixed': True,
                'scale': amp_scale
            })

    # delete cluster_quality figures and remake figures with updated unit_amp
    with dj.config(safemode=False):
        (report.ProbeLevelReport & units2fix).delete()
 def __init__(self):
     # Empty the NanTest table, then fill it with (index, value) rows drawn from
     # an array that mixes finite values and NaNs. The array is kept as self.a
     # (presumably for comparison against fetched values - confirm in tests).
     self.rel = NanTest()
     with dj.config(safemode=False):
         self.rel.delete()
     values = np.array([0, 1/3, np.nan, np.pi, np.nan])
     rows = ((idx, val) for idx, val in enumerate(values))
     self.rel.insert(rows)
     self.a = values
Exemple #3
0
    def test_query_caching(self):
        """Cached queries return stale results and block inserts until disabled."""
        cache_dir = os.path.expanduser('~/dj_query_cache')
        # initialize cache directory
        os.mkdir(cache_dir)

        with dj.config(query_cache=cache_dir):
            conn = schema.TTest3.connection
            # insert sample data and load cache
            schema.TTest3.insert(
                [dict(key=100 + i, value=200 + i) for i in range(2)])
            conn.set_query_cache(query_cache='main')
            cached_res = schema.TTest3().fetch()
            # attempt to insert while caching enabled; it must be rejected
            try:
                schema.TTest3.insert(
                    [dict(key=200 + i, value=400 + i) for i in range(2)])
            except dj.DataJointError:
                conn.set_query_cache()
            else:
                raise AssertionError(
                    'Insert allowed while query caching enabled')
            # insert new data
            schema.TTest3.insert(
                [dict(key=600 + i, value=800 + i) for i in range(2)])
            # re-enable cache to access old results
            conn.set_query_cache(query_cache='main')
            previous_cache = schema.TTest3().fetch()
            # verify properly cached; compare lengths first because zip()
            # silently stops at the shorter sequence
            assert len(previous_cache) == len(cached_res)
            assert all(c == p for c, p in zip(cached_res, previous_cache))
            # verify how to refresh results
            conn.set_query_cache()
            uncached_res = schema.TTest3().fetch()
            assert len(uncached_res) > len(cached_res)
            # purge query cache
            conn.purge_query_cache()

        # reset cache directory state (will fail if purge was unsuccessful)
        os.rmdir(cache_dir)
def test_contextmanager():
    """Testing context manager"""
    dj.config['arbitrary.stuff'] = 7

    # the double underscore in the keyword addresses the dotted key
    # 'arbitrary.stuff'; the override only lasts for the with-block
    with dj.config(arbitrary__stuff=10):
        assert_true(dj.config['arbitrary.stuff'] == 10)
    # previous value is restored on exit
    assert_true(dj.config['arbitrary.stuff'] == 7)
def delete_non_published_records():
    """Delete probe insertions, sessions and subjects that are not published.

    Runs with safemode off (no confirmation prompts). Probe insertions and
    sessions are deleted through a QueryBuffer in batches of 100, logging at
    level 25 after each flushed batch. Subjects that are neither public nor
    referenced by any acquisition.Session are then deleted one by one.
    """

    def _buffered_delete(buffer, keys, batch_message):
        # queue keys one at a time, flushing (and logging) every 100 keys,
        # then flush whatever remains in the buffer
        for k in tqdm(keys):
            buffer.add_to_queue1(k)
            if buffer.flush_delete(quick=False, chunksz=100):
                logger.log(25, batch_message)
        buffer.flush_delete(quick=False)

    with dj.config(safemode=False):

        logger.log(25, 'Deleting non-published probe insertions...')
        insertion_buffer = QueryBuffer(ephys.ProbeInsertion)
        _buffered_delete(
            insertion_buffer,
            (ephys.ProbeInsertion - public.PublicProbeInsertion -
             ephys.DefaultCluster).fetch('KEY'),
            'Deleted 100 probe insertions')
        logger.log(25, 'Deleted the rest of the probe insertions')

        logger.log(25, 'Deleting non-published sessions...')
        session_buffer = QueryBuffer(acquisition.Session)
        _buffered_delete(
            session_buffer,
            (acquisition.Session - public.PublicSession -
             behavior.TrialSet).fetch('KEY'),
            'Deleted 100 sessions')
        logger.log(25, 'Deleted the rest of the sessions')

        logger.log(25, 'Deleting non-published subjects...')
        subjects_with_sessions = subject.Subject & acquisition.Session
        orphan_keys = (subject.Subject - public.PublicSubjectUuid -
                       subjects_with_sessions.proj()).fetch('KEY')
        for subj_key in tqdm(orphan_keys):
            (subject.Subject & subj_key).delete()
Exemple #6
0
    def test_fetch_format(self):
        """test fetch_format='frame'"""
        with dj.config(fetch_format='frame'):
            # test if lists are both dicts
            list1 = sorted(self.subject.proj().fetch(as_dict=True),
                           key=itemgetter('subject_id'))
            list2 = sorted(self.subject.fetch(dj.key),
                           key=itemgetter('subject_id'))
            # zip() stops at the shorter list, so compare lengths explicitly
            assert_true(len(list1) == len(list2),
                        'Primary key is not returned correctly')
            for l1, l2 in zip(list1, list2):
                assert_dict_equal(l1, l2,
                                  'Primary key is not returned correctly')

            # tests if pandas dataframe
            tmp = self.subject.fetch(order_by='subject_id')
            assert_true(isinstance(tmp, pandas.DataFrame))
            tmp = tmp.to_records()

            subject_notes, key, real_id = self.subject.fetch(
                'subject_notes', dj.key, 'real_id')

            np.testing.assert_array_equal(sorted(subject_notes),
                                          sorted(tmp['subject_notes']))
            np.testing.assert_array_equal(sorted(real_id),
                                          sorted(tmp['real_id']))
            list1 = sorted(key, key=itemgetter('subject_id'))
            assert_true(len(list1) == len(list2),
                        'Primary key is not returned correctly')
            for l1, l2 in zip(list1, list2):
                assert_dict_equal(l1, l2,
                                  'Primary key is not returned correctly')
 def setup_class(cls):
     # Class-level fixture: empty the NanTest table and reload it with
     # (index, value) rows from an array containing both finite values and NaNs.
     # The array is stored as cls.a (presumably for comparisons in the tests).
     cls.rel = NanTest()
     with dj.config(safemode=False):
         cls.rel.delete()
     values = np.array([0, 1/3, np.nan, np.pi, np.nan])
     cls.rel.insert((pos, val) for pos, val in enumerate(values))
     cls.a = values
Exemple #8
0
def delete_outdated_project_plots(project_name='MAP'):
    """Drop a project's probe-track plot if its track count is stale.

    Compares ``track_count`` stored in ProjectLevelProbeTrack for
    ``project_name`` with the total ``probe_track_count`` over all
    SessionLevelProbeTrack entries. When they differ, the project entry and
    its external-store file (``report_store``) are deleted in one transaction,
    without a confirmation prompt. Does nothing if the project has no entry.
    """
    project_key = {'project_name': project_name}
    if project_key not in ProjectLevelProbeTrack.proj():
        return

    project_plot = ProjectLevelProbeTrack & project_key
    plotted_track_count = project_plot.fetch1('track_count')
    latest_track_count = SessionLevelProbeTrack.fetch(
        'probe_track_count').sum().astype(int)

    if plotted_track_count == latest_track_count:
        print('ProjectLevelProbeTrack is up-to-date')
        return

    # raw bytes of the external-store UUID referenced by the plot attribute
    uuid_byte = project_plot.proj(
        ub='CONCAT(tracks_plot , "")').fetch1('ub')
    print('project_name', project_name, 'uuid_byte:', str(uuid_byte))
    ext_key = {'hash': uuid.UUID(bytes=uuid_byte)}

    with dj.config(safemode=False), \
            ProjectLevelProbeTrack.connection.transaction:
        # delete the outdated Probe Tracks
        project_plot.delete()
        # delete from external store
        (schema.external['report_store']
         & ext_key).delete(delete_external_files=True)
        print('Outdated ProjectLevelProbeTrack deleted')
Exemple #9
0
 def setup_class(cls):
     # Class-level fixture: reset NanTest and load (index, value) rows taken
     # from an array mixing finite numbers and NaNs; keep the array as cls.a.
     cls.rel = NanTest()
     with dj.config(safemode=False):
         cls.rel.delete()
     samples = np.array([0, 1 / 3, np.nan, np.pi, np.nan])
     cls.rel.insert((i, x) for i, x in enumerate(samples))
     cls.a = samples
 def __init__(self):
     # Reset the NanTest table and insert (index, value) rows from an array
     # that contains NaNs; the array itself is kept on the instance as self.a.
     self.rel = NanTest()
     with dj.config(safemode=False):
         self.rel.delete()
     samples = np.array([0, 1/3, np.nan, np.pi, np.nan])
     self.rel.insert((pos, val) for pos, val in enumerate(samples))
     self.a = samples
Exemple #11
0
def test_contextmanager():
    """Testing context manager"""
    dj.config['arbitrary.stuff'] = 7

    # the double underscore in the keyword addresses the dotted key
    # 'arbitrary.stuff'; the override only lasts for the with-block
    with dj.config(arbitrary__stuff=10):
        assert_true(dj.config['arbitrary.stuff'] == 10)
    # previous value is restored on exit
    assert_true(dj.config['arbitrary.stuff'] == 7)
def process_daily_summary():
    """Delete existing DailyLabSummary entries for labs that have one, then repopulate.

    Runs with safemode off, so the delete does not prompt for confirmation.
    """
    # populate() options, matching those used by process_cumulative_plots;
    # defined locally so the populate() call below does not rely on an
    # undefined (or module-level) ``kwargs`` name.
    kwargs = dict(suppress_errors=True, display_progress=True)

    with dj.config(safemode=False):
        print('Populating plotting.DailyLabSummary...')
        # NOTE(review): fetch('KEY') returns only reference.Lab's primary key,
        # so the computed last_session_time is dropped and the delete below
        # removes all DailyLabSummary entries of each lab, not only the latest
        # one. Confirm whether that is intended.
        last_sessions = (reference.Lab.aggr(
            behavior_plotting.DailyLabSummary,
            last_session_time='max(last_session_time)')).fetch('KEY')
        (behavior_plotting.DailyLabSummary & last_sessions).delete()
        behavior_plotting.DailyLabSummary.populate(**kwargs)
Exemple #13
0
def _update_one_session(key):
    """Reload clustering results for one session and update spike sites/depths.

    Finds the session's ephys folder under the configured rig paths, matches
    clustering files to probes, and calls ``_add_spike_sites_and_depths`` for
    every probe inside a single transaction. The session's
    ``ephys.UnitCellType`` entries are then deleted (safemode off).

    :param key: restriction identifying a single ``experiment.Session``
    :return: True on success; False if no session folder or clustering files
        are found, or if loading any probe fails (in which case the whole
        transaction is cancelled, so either all probes are fixed or none).
    """
    log.info('\n======================================================')
    log.info('Waveform update for key: {k}'.format(k=key))

    #
    # Find Ephys Recording
    #
    key = (experiment.Session & key).fetch1()
    sinfo = ((lab.WaterRestriction
              * lab.Subject.proj()
              * experiment.Session.proj(..., '-session_time')) & key).fetch1()

    rigpaths = get_ephys_paths()
    h2o = sinfo['water_restriction_number']

    sess_time = (datetime.min + key['session_time']).time()
    sess_datetime = datetime.combine(key['session_date'], sess_time)

    # Initialized so an empty rigpaths list leads to the "not found" branch
    # below instead of a NameError on dpath.
    dpath, dglob = None, None
    for rigpath in rigpaths:
        dpath, dglob = _get_sess_dir(rigpath, h2o, sess_datetime)
        if dpath is not None:
            break

    if dpath is not None:
        log.info('Found session folder: {}'.format(dpath))
    else:
        log.warning('Error - No session folder found for {}/{}. Skipping...'.format(h2o, key['session_date']))
        return False

    try:
        clustering_files = _match_probe_to_ephys(h2o, dpath, dglob)
    except FileNotFoundError as e:
        log.warning(str(e) + '. Skipping...')
        return False

    with ephys.Unit.connection.transaction:
        for probe_no, (f, cluster_method, npx_meta) in clustering_files.items():
            try:
                log.info('------ Start loading clustering results for probe: {} ------'.format(probe_no))
                loader = cluster_loader_map[cluster_method]
                dj.conn().ping()
                # rigpath is the rig path where the session folder was found
                _add_spike_sites_and_depths(loader(sinfo, *f), probe_no, npx_meta, rigpath)
            except (ProbeInsertionError, ClusterMetricError, FileNotFoundError) as e:
                dj.conn().cancel_transaction()  # either successful fix of all probes, or none at all
                if isinstance(e, ProbeInsertionError):
                    log.warning('Probe Insertion Error: \n{}. \nSkipping...'.format(str(e)))
                else:
                    log.warning('Error: {}'.format(str(e)))
                return False

        with dj.config(safemode=False):
            (ephys.UnitCellType & key).delete()

    return True
Exemple #14
0
def delete_outdated_session_plots():
    """Delete session-level plots whose recorded counts no longer match the data.

    Two tables are checked:
      + SessionLevelProbeTrack - stale when ``probe_track_count`` differs from
        the session's current number of histology.LabeledProbeTrack entries;
      + SessionLevelCDReport - stale when ``cd_probe_count`` differs from the
        session's current number of ephys.ProbeInsertion entries.
    Stale entries and their files in the ``report_store`` external store are
    deleted in one transaction per table, without a confirmation prompt.
    """

    def _purge_stale(table, outdated, plot_expr, label):
        # delete `table` rows matching `outdated` plus their external plot files
        stale = table & outdated.proj()
        uuid_bytes = stale.proj(ub=plot_expr).fetch('ub')
        if not len(uuid_bytes):
            print('All {} are up-to-date'.format(label))
            return
        ext_keys = [{'hash': uuid.UUID(bytes=raw)} for raw in uuid_bytes]
        with dj.config(safemode=False):
            with table.connection.transaction:
                stale.delete()
                (schema.external['report_store']
                 & ext_keys).delete(delete_external_files=True)
                print('{} outdated {} deleted'.format(len(uuid_bytes), label))

    # ------------- SessionLevelProbeTrack ----------------
    hist_counts = experiment.Session.aggr(
        histology.LabeledProbeTrack, probe_hist_count='count(*)')
    _purge_stale(
        SessionLevelProbeTrack,
        SessionLevelProbeTrack * hist_counts & 'probe_track_count != probe_hist_count',
        '(session_tracks_plot)',
        'SessionLevelProbeTrack')

    # ------------- SessionLevelCDReport ----------------
    probe_counts = experiment.Session.aggr(
        ephys.ProbeInsertion, current_probe_count='count(*)')
    _purge_stale(
        SessionLevelCDReport,
        SessionLevelCDReport * probe_counts & 'cd_probe_count != current_probe_count',
        '(coding_direction)',
        'SessionLevelCDReport')
 def update(self, row, primary_keys=None, classdef=None, **kwargs):
     """Insert ``row``, replacing the existing entry if it is a duplicate.

     :param row: the entry passed to ``table.insert1``
     :param primary_keys: restriction selecting the entry to delete when the
         insert hits a DuplicateError. When omitted (None or ``{}``), the
         primary-key fields of ``row`` are used, which requires ``row`` to
         support ``row[attribute]`` lookup (e.g. a dict). An empty restriction
         would otherwise match - and delete - every row in the table.
     :param classdef: table definition; defaults to ``self.classdef``
     :param kwargs: forwarded to ``insert1``
     :return: the value returned by the final ``insert1`` call
     """
     if classdef is None:
         classdef = self.classdef
     with closing(self.get_conn()) as conn:
         table = self.create_table(conn, classdef=classdef)
         try:
             ret = table.insert1(row, skip_duplicates=False, **kwargs)
         except DuplicateError:
             restriction = primary_keys
             if restriction is None or (isinstance(restriction, dict)
                                        and not restriction):
                 # never delete with an empty restriction (whole table)
                 restriction = {attr: row[attr] for attr in table.primary_key}
             with dj.config(safemode=False):
                 (table & restriction).delete()
             ret = table.insert1(row, skip_duplicates=False, **kwargs)
     return ret
Exemple #16
0
    def update(self, key):
        """Re-run the manual tracker on stored contours and optionally overwrite them.

        The stored contours for ``key`` are loaded into a ManualTracker, whose
        backup file is a per-key pickle under /tmp. If the tracker crashes, the
        user answers y/n: 'n' re-raises the error, 'y' continues. Only after the
        user types "YES" are the entries for ``key`` deleted and re-inserted
        (with min_lambda, Frame and Parameter rows) in a single transaction.
        """
        print("Populating", key)

        video_path = (Eye() & key).get_video_path()
        tracker = ManualTracker(video_path)

        stored = (self.Frame() & key).fetch('contour', order_by='frame_id')
        tracker.contours = np.array(stored)
        tracker.contours_detected = np.array([c is not None for c in stored])
        tracker.backup_file = '/tmp/tracker_update_state{animal_id}-{session}-{scan_idx}.pkl'.format(
            **key)

        try:
            tracker.run()
        except Exception as err:
            print(str(err))
            choice = None
            # keep asking until a valid answer is given
            while choice not in ['y', 'n']:
                choice = input(
                    'Tracker crashed. Do you want to save the content anyway [y/n]?'
                ).lower()
            if choice == 'n':
                raise

        confirmation = input(
            'Do you want to delete and replace the existing entries? Type "YES" for acknowledgement.'
        )
        if confirmation != "YES":
            return

        with dj.config(safemode=False), self.connection.transaction:
            (self & key).delete()

            logtrace = tracker.mixing_constant.logtrace.astype(float)
            self.insert1(dict(key, min_lambda=logtrace[logtrace > 0].min()))
            self.log_key(key)

            frame_table = self.Frame()
            param_table = self.Parameter()
            per_frame = zip(count(), tracker.contours_detected,
                            tracker.contours, tracker.parameter_iter())
            for frame_id, detected, contour, params in tqdm(
                    per_frame, total=len(tracker.contours)):
                assert frame_id == params['frame_id']
                frame_row = dict(key, frame_id=frame_id)
                if detected:
                    frame_row['contour'] = contour
                frame_table.insert1(frame_row)
                param_table.insert1(dict(key, **params),
                                    ignore_extra_fields=True)
Exemple #17
0
def _fix_one_session(key):
    """Delete PhotostimTrial entries of one session that fail the photostim criteria.

    The photostim period is 'late-delay' when the water-restriction number
    starts with 'SC' and the behavior file path names Rig3; otherwise it is
    'early-delay'. A PhotostimTrial is kept only if all of these hold:
      + its "protocol #" trial note is 5 (early-delay) or > 4 (late-delay);
      + its "autolearn" trial note is 4;
      + the duration of its latest "delay" trial event equals 1.2.
    Invalid trials are deleted together with the session's
    report.ProbeLevelPhotostimEffectReport entries (safemode off).

    :param key: restriction identifying one session
    :return: (True, list of keys of the deleted PhotostimTrial entries)
    """
    log.info('Running fix for session: {}'.format(key))
    # determine if this session's photostim is: `early-delay` or `late-delay`
    path = (behavior_ingest.BehaviorIngest.BehaviorFile
            & key).fetch('behavior_file')[0]
    h2o = (lab.WaterRestriction & key).fetch1('water_restriction_number')

    photostim_period = 'early-delay'
    # raw string: '\d' in a plain string literal is an invalid escape sequence
    rig_name = re.search(r'Recording(Rig\d)_', path)
    if re.match('SC', h2o) and rig_name:
        rig_name = rig_name.groups()[0]
        if rig_name == "Rig3":
            photostim_period = 'late-delay'

    invalid_photostim_trials = []
    for trial_key in (experiment.PhotostimTrial & key).fetch('KEY'):
        protocol_type = int(
            (experiment.TrialNote & trial_key
             & 'trial_note_type = "protocol #"').fetch1('trial_note'))
        autolearn = int(
            (experiment.TrialNote & trial_key
             & 'trial_note_type = "autolearn"').fetch1('trial_note'))

        if photostim_period == 'early-delay':
            valid_protocol = protocol_type == 5
        else:  # 'late-delay'
            valid_protocol = protocol_type > 4

        delay_duration = (experiment.TrialEvent & trial_key
                          & 'trial_event_type = "delay"').fetch(
                              'duration',
                              order_by='trial_event_time DESC',
                              limit=1)[0]

        if not (valid_protocol and autolearn == 4
                and delay_duration == Decimal('1.2')):
            # all criteria not met, this trial should not have been a photostim trial
            invalid_photostim_trials.append(trial_key)

    log.info('Deleting {} incorrectly labeled PhotostimTrial'.format(
        len(invalid_photostim_trials)))
    if len(invalid_photostim_trials):
        with dj.config(safemode=False):
            # delete invalid photostim trials
            (experiment.PhotostimTrial & invalid_photostim_trials).delete()
            # delete ProbeLevelPhotostimEffectReport figures associated with this session
            (report.ProbeLevelPhotostimEffectReport & key).delete()

    return True, invalid_photostim_trials
Exemple #18
0
def delete_empty_ingestion_tables():
    """Delete ingestion entries that have no corresponding downstream data.

    Deletes, without a confirmation prompt:
      + EphysIngest entries with no ephys.ProbeInsertion
      + TrackingIngest entries with no tracking.Tracking
      + HistologyIngest entries with no histology.ElectrodeCCFPosition

    A MySQL deadlock (OperationalError code 1213) stops the remaining deletes
    and is ignored; any other OperationalError is re-raised.
    """
    from pipeline.ingest import ephys as ephys_ingest
    from pipeline.ingest import tracking as tracking_ingest
    from pipeline.ingest import histology as histology_ingest

    with dj.config(safemode=False):
        try:
            (ephys_ingest.EphysIngest & (ephys_ingest.EphysIngest
                                         - ephys.ProbeInsertion).fetch('KEY')).delete()
            (tracking_ingest.TrackingIngest & (tracking_ingest.TrackingIngest
                                               - tracking.Tracking).fetch('KEY')).delete()
            (histology_ingest.HistologyIngest & (histology_ingest.HistologyIngest
                                                 - histology.ElectrodeCCFPosition).fetch('KEY')).delete()
        except OperationalError as e:  # in case of mysql deadlock - code: 1213
            # tolerate only deadlocks; do not silently swallow other DB errors
            if not e.args or e.args[0] != 1213:
                raise
Exemple #19
0
def delete_entries_from_membership(pks_to_be_deleted):
    """Delete entries referencing the given UUIDs from the shadow membership tables.

    For each entry of MEMBERSHIP_TABLES, the real parent table is resolved
    from the ingest parent table's module and class name. It is restricted to
    the valid UUIDs in ``pks_to_be_deleted`` (matched on the
    ``dj_parent_uuid_name`` attribute), and the matching rows of
    ``dj_current_table`` are deleted without a confirmation prompt.
    """
    for spec in MEMBERSHIP_TABLES:
        parent_cls = spec['dj_parent_table']
        current_table = spec['dj_current_table']

        print(f'Deleting from table {current_table.__name__} ...')
        # e.g. 'ibl_pipeline.ingest.subject' -> 'subject.<ClassName>'
        module_alias = parent_cls.__module__.replace('ibl_pipeline.ingest.', '')
        real_table = eval(module_alias + '.' + parent_cls.__name__)
        with dj.config(safemode=False):
            uuid_restriction = [{spec['dj_parent_uuid_name']: pk}
                                for pk in pks_to_be_deleted
                                if is_valid_uuid(pk)]
            matched_keys = (real_table & uuid_restriction).fetch('KEY')
            (current_table & matched_keys).delete()
Exemple #20
0
def apply_amplitude_scaling(insertion_keys=None):
    """
    This is a one-time operation only - April 2020
    Kilosort2 results from neuropixels probe 2.0 requires an additionally scaling factor of 3.01 applied
    to the unit amplitude and mean waveform.
    Future version of quality control pipeline will apply this scaling.

    Only quality-controlled units (``quality_control = 1``) of neuropixels 2.0
    probe insertions are scaled. Units already recorded with ``fixed=1`` in
    FixedAmpUnit are skipped, and every scaled unit is recorded there under a
    new FixHistory entry. The ProbeLevelReport entries of the selected
    insertions are deleted afterwards so the figures can be remade.

    :param insertion_keys: optional restriction on ephys.ProbeInsertion;
        None (default) or ``{}`` selects all probe insertions.
    """
    # None instead of a mutable ``{}`` default; an empty dict restricts nothing
    if insertion_keys is None:
        insertion_keys = {}

    amp_scale = 3.01

    npx2_inserts = ephys.ProbeInsertion & insertion_keys & 'probe_type LIKE "neuropixels 2.0%"'

    units2fix = ephys.Unit * ephys.ClusteringLabel & npx2_inserts.proj(
    ) & 'quality_control = 1'
    # exclude those that were already fixed
    units2fix = units2fix - (FixedAmpUnit & 'fixed=1')

    if not units2fix:
        return

    # safety check, no jrclust results
    assert len(units2fix & 'clustering_method LIKE "jrclust%"') == 0

    fix_hist_key = {
        'fix_name': pathlib.Path(__file__).name,
        'fix_timestamp': datetime.now()
    }
    FixHistory.insert1(fix_hist_key)

    for unit in tqdm(
            units2fix.proj('unit_amp', 'waveform').fetch(as_dict=True)):
        amp = unit.pop('unit_amp')
        wf = unit.pop('waveform')
        with dj.conn().transaction:
            (ephys.Unit & unit)._update('unit_amp', amp * amp_scale)
            (ephys.Unit & unit)._update('waveform', wf * amp_scale)
            FixedAmpUnit.insert1({
                **fix_hist_key,
                **unit, 'fixed': True,
                'scale': amp_scale
            })

    # delete cluster_quality figures and remake figures with updated unit_amp
    with dj.config(safemode=False):
        (report.ProbeLevelReport & npx2_inserts).delete()
    def test_manual_insert(self):
        """
        Test whether manual insert of gitkey works with different formats,
        """
        with dj.config(safemode=False):
            schema.Man().delete()

        man = schema.Man
        # insert from a dictionary
        man().insert1(dict(idx=0, value=2.))
        # positional (tuple) insert
        man().insert1((1, 2.))
        # fetch an np.void record, modify it, and insert it back under a new idx
        record = man().fetch.limit(1)()[0]
        record['idx'] = 2
        man().insert1(record)

        # each of the three inserts must also have produced a GitKey part entry
        assert_true(len(man()) == len(man.GitKey()) == 3,
                    "Inserting with different datatypes did not work")
def process_cumulative_plots(backtrack_days=30):
    """Re-make CumulativeSummary plots for subjects selected via LatestDate.

    Subjects whose nickname contains "human" are excluded. When the module-level
    ``mode`` is not 'public', the selection is further restricted to rows where
    ``latest_date`` falls within the last ``backtrack_days`` days OR the subject
    has no subject.Death entry (a list restriction is an OR). For each selected
    subject already in CumulativeSummary, its matching entries are deleted
    (safemode off) and repopulated, and SubjectLatestDate is updated or inserted
    with the newest CumulativeSummary ``latest_date``. Finally CumulativeSummary
    is populated for any remaining keys.
    """
    populate_kwargs = dict(suppress_errors=True, display_progress=True)

    non_human = (subject.Subject & 'subject_nickname not like "%human%"').proj()
    last_checked = subject.Subject.aggr(
        behavior_plotting.LatestDate, checking_ts='MAX(checking_ts)')

    if mode == 'public':
        latest = last_checked & non_human
    else:
        recent_or_alive = [
            f'latest_date between curdate() - interval {backtrack_days} day and curdate()',
            (subject.Subject - subject.Death)]
        latest = last_checked * behavior_plotting.LatestDate & recent_or_alive & non_human

    subj_keys = (subject.Subject & behavior_plotting.CumulativeSummary
                 & latest).fetch('KEY')

    # delete and repopulate subject by subject
    with dj.config(safemode=False):
        for subj_key in tqdm(subj_keys, position=0):
            (behavior_plotting.CumulativeSummary & subj_key & latest).delete()
            print('populating...')
            behavior_plotting.CumulativeSummary.populate(
                latest & subj_key, **populate_kwargs)

            # --- update the latest date of the subject -----
            newest = (subject.Subject & subj_key).aggr(
                behavior_plotting.CumulativeSummary,
                latest_date='max(latest_date)')
            if not len(newest):
                continue
            new_date = newest.fetch1('latest_date')
            stored = behavior_plotting.SubjectLatestDate & subj_key
            if len(stored):
                stored._update('latest_date', new_date)
            else:
                behavior_plotting.SubjectLatestDate.insert1(newest.fetch1())

        behavior_plotting.CumulativeSummary.populate(**populate_kwargs)
Exemple #23
0
    def make(self, key):
        """Mark ``key`` complete if all required ephys datasets exist on flatiron.

        Complete means every name in ``self.required_datasets`` is present among
        the existing flatiron FileRecords for ``key`` and at least one dataset
        name contains 'spikes.times'. On success, ``key`` is inserted and its
        EphysMissingDataLog entries are removed. Otherwise, each missing
        required dataset is logged in EphysMissingDataLog.
        """
        datasets = (data.FileRecord & key & 'repo_name LIKE "flatiron_%"' & {
            'exists': 1
        }).fetch('dataset_name')

        is_complete = (
            all([ds in datasets for ds in self.required_datasets])
            and any(['spikes.times' in name for name in datasets]))

        if not is_complete:
            missing = [ds for ds in self.required_datasets if ds not in datasets]
            for ds in missing:
                EphysMissingDataLog.insert1(dict(**key, missing_data=ds),
                                            skip_duplicates=True)
            return

        self.insert1(key)
        with dj.config(safemode=False):
            (EphysMissingDataLog & key).delete_quick()
Exemple #24
0
    def _make_tuples(self, key):
        """Interactively create the eye-tracking entry for ``key``.

        Grabs the eye-camera timestamps and frames, lets the user draw an ROI
        on the mean frame (CVROIGrabber when OpenCV is importable, ROIGrabber
        otherwise), asks for an eye-quality rating, and inserts the entry. If
        the user then chooses to stop, the current transaction is committed
        before PipelineException is raised (presumably so the entry just
        inserted survives the interruption).
        """
        key['eye_time'], frames, key[
            'total_frames'] = self.grab_timestamps_and_frames(key)

        try:
            import cv2  # noqa: F401 -- only probes whether OpenCV is installed
            print('Drag window and press q when done')
            rg = CVROIGrabber(frames.mean(axis=2))
            rg.grab()
        except ImportError:
            rg = ROIGrabber(frames.mean(axis=2))

        with dj.config(display__width=50):
            print(EyeQuality())
        key['eye_quality'] = int(input("Enter the quality of the eye: "))
        key['eye_roi'] = rg.roi
        self.insert1(key)
        print('[Done]')
        if input('Do you want to stop? y/N: ') == 'y':
            self.connection.commit_transaction()
            raise PipelineException('User interrupted population.')
Exemple #25
0
    def update(self, key):
        """Re-run the manual tracker on stored contours and optionally overwrite them.

        Stored contours for ``key`` seed a ManualTracker, which uses a per-key
        pickle under /tmp as its backup file. On a tracker crash, the user answers y/n:
        'n' re-raises, 'y' continues. Only after an explicit "YES" are the
        entries for ``key`` deleted and re-inserted (with min_lambda, Frame and
        Parameter rows) within one transaction.
        """
        print("Populating", key)

        video_path = (Eye() & key).get_video_path()
        tracker = ManualTracker(video_path)

        stored = (self.Frame() & key).fetch('contour', order_by='frame_id')
        tracker.contours = np.array(stored)
        tracker.contours_detected = np.array([c is not None for c in stored])
        tracker.backup_file = '/tmp/tracker_update_state{animal_id}-{session}-{scan_idx}.pkl'.format(**key)

        try:
            tracker.run()
        except Exception as err:
            print(str(err))
            choice = None
            # repeat the question until the answer is 'y' or 'n'
            while choice not in ['y', 'n']:
                choice = input('Tracker crashed. Do you want to save the content anyway [y/n]?').lower()
            if choice == 'n':
                raise

        confirmation = input('Do you want to delete and replace the existing entries? Type "YES" for acknowledgement.')
        if confirmation != "YES":
            return

        with dj.config(safemode=False), self.connection.transaction:
            (self & key).delete()

            logtrace = tracker.mixing_constant.logtrace.astype(float)
            self.insert1(dict(key, min_lambda=logtrace[logtrace > 0].min()))
            self.log_key(key)

            frame_table = self.Frame()
            param_table = self.Parameter()
            per_frame = zip(count(), tracker.contours_detected, tracker.contours,
                            tracker.parameter_iter())
            for frame_id, detected, contour, params in tqdm(per_frame, total=len(tracker.contours)):
                assert frame_id == params['frame_id']
                frame_row = dict(key, frame_id=frame_id)
                if detected:
                    frame_row['contour'] = contour
                frame_table.insert1(frame_row)
                param_table.insert1(dict(key, **params), ignore_extra_fields=True)
Exemple #26
0
def archive_electrode_histology(insertion_key, note='', delete=False):
    """
    For the specified "insertion_key" copy from histology.ElectrodeCCFPosition and histology.LabeledProbeTrack
     (and their respective part tables) to histology.ArchivedElectrodeHistology
    If "delete" == True - delete the records associated with the "insertion_key" from:
        + histology.ElectrodeCCFPosition
        + histology.LabeledProbeTrack
        + report.ProbeLevelDriftMap
        + report.ProbeLevelCoronalSlice
        + HistologyIngest

    Archives are identified by a hash of the electrode CCF positions and
    position errors. If an identical set is already archived, it is not
    archived again: with delete=False the function prints a notice and
    returns; with delete=True the user must answer 'yes' to the prompt, and
    then only the deletion is performed.
    """

    e_ccfs = {
        d['electrode']: d
        for d in (histology.ElectrodeCCFPosition.ElectrodePosition
                  & insertion_key).fetch(as_dict=True)
    }
    e_error_ccfs = {
        d['electrode']: d
        for d in (histology.ElectrodeCCFPosition.ElectrodePositionError
                  & insertion_key).fetch(as_dict=True)
    }
    e_ccfs_hash = dict_to_hash({**e_ccfs, **e_error_ccfs})

    if histology.ArchivedElectrodeHistology & {'archival_hash': e_ccfs_hash}:
        if not delete:
            print(
                'An identical set of the specified ElectrodeCCF has already been archived'
            )
            return
        if dj.utils.user_choice(
                'The specified ElectrodeCCF has already been archived!\nProceed with delete?'
        ) != 'yes':
            return
        # An identical archive already exists: do not store a duplicate copy,
        # only remove the live records.
        with histology.ArchivedElectrodeHistology.connection.transaction:
            _delete_electrode_histology(insertion_key)
        return

    archival_time = datetime.now()
    # quoted SQL literal that adds archival_time to every projected part row
    archival_time_sql = '"{}"'.format(archival_time)
    with histology.ArchivedElectrodeHistology.connection.transaction:
        histology.ArchivedElectrodeHistology.insert1({
            **insertion_key,
            'archival_time': archival_time,
            'archival_note': note,
            'archival_hash': e_ccfs_hash
        })
        histology.ArchivedElectrodeHistology.ElectrodePosition.insert(
            (histology.ElectrodeCCFPosition.ElectrodePosition
             & insertion_key).proj(..., archival_time=archival_time_sql))
        histology.ArchivedElectrodeHistology.ElectrodePositionError.insert(
            (histology.ElectrodeCCFPosition.ElectrodePositionError
             & insertion_key).proj(..., archival_time=archival_time_sql))
        histology.ArchivedElectrodeHistology.LabeledProbeTrack.insert(
            (histology.LabeledProbeTrack & insertion_key).proj(
                ..., archival_time=archival_time_sql))
        histology.ArchivedElectrodeHistology.ProbeTrackPoint.insert(
            (histology.LabeledProbeTrack.Point & insertion_key).proj(
                ..., archival_time=archival_time_sql))

        if delete:
            _delete_electrode_histology(insertion_key)


def _delete_electrode_histology(insertion_key):
    """Delete, without prompting, the histology, report and ingest entries of ``insertion_key``."""
    with dj.config(safemode=False):
        (histology.ElectrodeCCFPosition & insertion_key).delete()
        (histology.LabeledProbeTrack & insertion_key).delete()
        (report.ProbeLevelDriftMap & insertion_key).delete()
        (report.ProbeLevelCoronalSlice & insertion_key).delete()
        (HistologyIngest & insertion_key).delete()
Exemple #27
0
def update_fields(real_schema, shadow_schema, table_name, pks, insert_to_table=False):
    '''
    Given a table and the primary key of real table, update all the fields that have discrepancy.

    For every record of ``real_schema.<table_name>`` matching ``pks``, each
    secondary field (except the ``*_ts`` timestamp field) whose value differs
    from the matching record of ``shadow_schema.<table_name>`` is overwritten
    with the shadow value.

    Inputs: real_schema     : real schema module, e.g. reference
            shadow_schema   : shadow schema module, e.g. reference_ingest
            table_name      : string, name of a table, e.g. Subject
            pks             : list of dictionaries, primary keys of real table that contains modification
            insert_to_table : boolean, if True, log the update history in update.UpdateRecord
                              (and records missing from the shadow table in update.UpdateError)

    Notes:
        - Records absent from the shadow table are reported (printed) and skipped.
        - If the shadow table holds more than one record for a key, every record
          whose timestamp is not the latest one is deleted first.
        - An error while updating a single field is printed; the remaining
          fields and records are still processed.
    '''

    real_table = getattr(real_schema, table_name)
    shadow_table = getattr(shadow_schema, table_name)
    table_full_name = real_table.__module__ + '.' + real_table.__name__

    secondary_fields = set(real_table.heading.secondary_attributes)
    # The record timestamp: the (first) secondary attribute ending in "_ts".
    # It is excluded from the comparison. IndexError if the table has none.
    ts_field = [f for f in secondary_fields
                if f.endswith('_ts')][0]
    fields_to_update = secondary_fields - {ts_field}

    for r in (real_table & pks).fetch('KEY'):

        pk_hash = UUID(dj.hash.key_hash(r))

        if not shadow_table & r:
            # Defined unconditionally: it is printed whether or not we log.
            update_error_msg = 'Record does not exist in the shadow {}'.format(r)
            if insert_to_table:
                real_record = (real_table & r).fetch1()
                update_record = dict(
                    table=table_full_name,
                    attribute='unknown',
                    pk_hash=pk_hash,
                    original_ts=real_record[ts_field],
                    update_ts=datetime.datetime.now(),
                    pk_dict=r,
                )
                update.UpdateRecord.insert1(update_record)
                update_record.pop('pk_dict')

                update_record_error = dict(
                    **update_record,
                    update_action_ts=datetime.datetime.now(),
                    update_error_msg=update_error_msg
                )
                update.UpdateError.insert1(update_record_error)

            print(update_error_msg)
            continue

        # if there are more than 1 record
        if len(shadow_table & r) > 1:
            # Keep only the most recent shadow record. Aggregate on this
            # table's own timestamp field instead of a hard-coded name, so
            # tables other than sessions work too.
            latest_record = dj.U().aggr(
                shadow_table & r,
                **{ts_field: 'max({})'.format(ts_field)}).fetch()

            with dj.config(safemode=False):
                ((shadow_table & r) - latest_record).delete()

        shadow_record = (shadow_table & r).fetch1()
        real_record = (real_table & r).fetch1()

        for f in fields_to_update:
            if real_record[f] != shadow_record[f]:
                try:
                    (real_table & r)._update(f, shadow_record[f])
                    update_narrative = f'{table_name}.{f}: {shadow_record[f]} != {real_record[f]}'
                    print(update_narrative)
                    if insert_to_table:
                        update_record = dict(
                            table=table_full_name,
                            attribute=f,
                            pk_hash=pk_hash,
                            original_ts=real_record[ts_field],
                            update_ts=shadow_record[ts_field],
                            pk_dict=r,
                            original_value=real_record[f],
                            updated_value=shadow_record[f],
                            update_narrative=update_narrative
                        )
                        update.UpdateRecord.insert1(update_record)

                # Exception (not BaseException) so Ctrl-C / SystemExit still
                # abort the run instead of being printed and ignored.
                except Exception as e:
                    print(f'Error while updating record {r}: {str(e)}')
import datajoint as dj
from ibl_pipeline import subject, acquisition, data, behavior
from ibl_pipeline.ingest import acquisition as acquisition_ingest
from ibl_pipeline.ingest import data as data_ingest
from ibl_pipeline.ingest import alyxraw
import datetime
from oneibl.one import ONE
import numpy as np
from uuid import UUID

if __name__ == '__main__':

    with dj.config(safemode=False):
        uuids = ((acquisition_ingest.Session - behavior.TrialSet.proj())
                 & 'session_start_time > "2019-06-13"').fetch('session_uuid')
        uuid_str = [str(uuid) for uuid in uuids]
        for uuid in uuid_str:
            keys = (alyxraw.AlyxRaw.Field
                    & 'fvalue="{}"'.format(uuid)).fetch('KEY')
            (alyxraw.AlyxRaw & keys).delete()
            (alyxraw.AlyxRaw & {'uuid': UUID(uuid)}).delete()

            if len(acquisition_ingest.Session & {'session_uuid': UUID(uuid)}):
                subj_uuid, session_start_time = (acquisition_ingest.Session & {
                    'session_uuid': UUID(uuid)
                }).fetch1('subject_uuid', 'session_start_time')
            else:
                continue

            key = {
                'subject_uuid': subj_uuid,
Exemple #29
0
 def test_preview():
     """Check that preview() emits len(relation) + 2 lines with display limit 7."""
     with dj.config(display__limit=7):
         projected = A().proj(a='id_a')
         n_preview_lines = len(projected.preview().split('\n'))
         assert_equal(n_preview_lines, len(projected) + 2)
 def test_preview():
     """With display__limit=7, the preview has two more lines than the relation has rows."""
     with dj.config(display__limit=7):
         relation = A().proj(a='id_a')
         preview_text = relation.preview()
         line_count = len(preview_text.split('\n'))
         expected_count = len(relation) + 2
         assert_equal(line_count, expected_count)
Exemple #31
0
def fix_session(session_key):
    """
    Re-parse the raw behavior files of one session and rewrite its trial events.

    The behavior files registered for ``session_key`` in
    ``behavior_ingest.BehaviorIngest.BehaviorFile`` are located under the
    known rig data paths, loaded, and parsed trial by trial. The session's
    existing ``experiment.TrialEvent`` and
    ``behavior_ingest.BehaviorIngest.CorrectedTrialEvents`` rows are then
    deleted and replaced with the freshly parsed ones.

    Only trial events (and their duration corrections) are written back. The
    other per-trial rows built while parsing (trial, behavior_trial,
    trial_note, action_event, photostim_*) are collected but not inserted by
    this function.

    :param session_key: restriction identifying a single experiment.Session.
    :return: True if the events were rewritten, False if the session was
        skipped (behavior files missing or no valid trial data).
    """
    paths = behavior_ingest.RigDataPath.fetch(as_dict=True)
    files = (behavior_ingest.BehaviorIngest *
             behavior_ingest.BehaviorIngest.BehaviorFile
             & session_key).fetch(as_dict=True, order_by='behavior_file asc')

    # Resolve every behavior file against every rig data path
    # (files in the outer loop so the file order is preserved).
    filelist = []
    for f in files:
        for p in paths:
            found = find_path(p['rig_data_path'], f['behavior_file'])
            if found:
                filelist.append(found)

    if len(filelist) != len(files):
        log.warning("behavior files missing in {} ({}/{}). skipping".format(
            session_key, len(filelist), len(files)))
        return False

    log.info('filelist: {}'.format(filelist))

    #
    # Prepare PhotoStim
    #
    photostim_duration = 0.5  # (s) Hard-coded here
    # stim code -> photostim attributes (same device and duration for all codes)
    photostims = {
        stim: {
            'photo_stim': stim,
            'photostim_device': 'OBIS470',
            'brain_location_name': location,
            'duration': photostim_duration
        }
        for stim, location in ((4, 'left_alm'), (5, 'right_alm'),
                               (6, 'both_alm'))
    }

    #
    # Extract trial data from file(s) & prepare trial loop
    #

    trial = namedtuple(  # simple structure to track per-trial vars
        'trial', ('ttype', 'stim', 'settings', 'state_times', 'state_names',
                  'state_data', 'event_data', 'event_times'))

    trials = []  # per-trial tuples, concatenated across all files

    for f in filelist:

        # files smaller than 1000 KiB are skipped
        if os.stat(f).st_size / 1024 < 1000:
            log.info('skipping file {f} - too small'.format(f=f))
            continue

        log.debug('loading file {}'.format(f))

        mat = spio.loadmat(f, squeeze_me=True)
        SessionData = mat['SessionData'].flatten()

        AllTrialTypes = SessionData['TrialTypes'][0]
        AllTrialSettings = SessionData['TrialSettings'][0]

        RawData = SessionData['RawData'][0].flatten()
        AllStateNames = RawData['OriginalStateNamesByNumber'][0]
        AllStateData = RawData['OriginalStateData'][0]
        AllEventData = RawData['OriginalEventData'][0]
        AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
        AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

        # verify trial-related data arrays are all same length
        assert (all(
            (x.shape[0] == AllStateTimestamps.shape[0]
             for x in (AllTrialTypes, AllTrialSettings, AllStateNames,
                       AllStateData, AllEventData, AllEventTimestamps))))

        if 'StimTrials' in SessionData.dtype.fields:
            log.debug('StimTrials detected in session - will include')
            AllStimTrials = SessionData['StimTrials'][0]
            assert (AllStimTrials.shape[0] == AllStateTimestamps.shape[0])
        else:
            log.debug('StimTrials not detected in session - will skip')
            # one 'no stim' placeholder per trial
            AllStimTrials = np.array([None] * AllStateTimestamps.shape[0])

        # element order must match the fields of the 'trial' namedtuple
        trials.extend(
            zip(AllTrialTypes, AllStimTrials, AllTrialSettings,
                AllStateTimestamps, AllStateNames, AllStateData, AllEventData,
                AllEventTimestamps))

    # all files were internally invalid or smaller than 1000 KiB
    if not trials:
        log.warning('skipping {}, no valid files'.format(session_key))
        return False

    skey = (experiment.Session & session_key).fetch1()

    #
    # Actually load the per-trial data
    #
    log.info('BehaviorIngest.make(): trial parsing phase')

    # lists of various records for batch-insert
    rows = {
        k: list()
        for k in ('trial', 'behavior_trial', 'trial_note', 'trial_event',
                  'corrected_trial_event', 'action_event', 'photostim',
                  'photostim_location', 'photostim_trial',
                  'photostim_trial_event')
    }

    # i is the trial counter; it advances for skipped trials too
    for i, t in enumerate(trials):

        #
        # Misc
        #

        t = trial(*t)  # convert list of items to a 'trial' structure

        log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

        # covert state data names into a lookup dictionary
        #
        # names (seem to be? are?):
        #
        # Trigtrialstart
        # PreSamplePeriod
        # SamplePeriod
        # DelayPeriod
        # EarlyLickDelay
        # EarlyLickSample
        # ResponseCue
        # GiveLeftDrop
        # GiveRightDrop
        # GiveLeftDropShort
        # GiveRightDropShort
        # AnswerPeriod
        # Reward
        # RewardConsumption
        # NoResponse
        # TimeOut
        # StopLicking
        # StopLickingReturn
        # TrialEnd

        # state numbers in state_data are 1-based positions in state_names
        states = {k: (v + 1) for v, k in enumerate(t.state_names)}
        required_states = ('PreSamplePeriod', 'SamplePeriod', 'DelayPeriod',
                           'ResponseCue', 'StopLicking', 'TrialEnd')

        missing = [k for k in required_states if k not in states]

        if missing:
            log.warning('skipping trial {i}; missing {m}'.format(i=i,
                                                                 m=missing))
            continue

        gui = t.settings['GUI'].flatten()

        # ProtocolType - only ingest protocol >= 3
        #
        # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
        # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
        #

        if 'ProtocolType' not in gui.dtype.names:
            log.warning('skipping trial {i}; protocol undefined'.format(i=i))
            continue

        protocol_type = gui['ProtocolType'][0]
        if protocol_type < 3:
            log.warning('skipping trial {i}; protocol {n} < 3'.format(
                i=i, n=protocol_type))
            continue

        #
        # Top-level 'Trial' record
        #

        tkey = dict(skey)
        startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

        # should be only end of 1st StopLicking;
        # rest of data is irrelevant w/r/t separately ingested ephys
        endindex = np.where(t.state_data == states['StopLicking'])[0]

        log.debug('states\n' + str(states))
        log.debug('state_data\n' + str(t.state_data))
        log.debug('startindex\n' + str(startindex))
        log.debug('endindex\n' + str(endindex))

        if not (len(startindex) and len(endindex)):
            log.warning(
                'skipping trial {i}: start/end index error: {s}/{e}'.format(
                    i=i, s=str(startindex), e=str(endindex)))
            continue

        try:
            tkey['trial'] = i
            tkey['trial_uid'] = i  # Arseny has unique id to identify some trials
            tkey['start_time'] = t.state_times[startindex][0]
            tkey['stop_time'] = t.state_times[endindex][0]
        except IndexError:
            log.warning(
                'skipping trial {i}: error indexing {s}/{e} into {t}'.format(
                    i=i,
                    s=str(startindex),
                    e=str(endindex),
                    t=str(t.state_times)))
            continue

        log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
        log.debug('tkey' + str(tkey))
        rows['trial'].append(tkey)

        #
        # Specific BehaviorTrial information for this trial
        #

        bkey = dict(tkey)
        bkey['task'] = 'audio delay'  # hard-coded here
        bkey['task_protocol'] = 1  # hard-coded here

        # determine trial instruction
        trial_instruction = 'left'  # hard-coded here

        if gui['Reversal'][0] == 1:
            if t.ttype == 1:
                trial_instruction = 'left'
            elif t.ttype == 0:
                trial_instruction = 'right'
        elif gui['Reversal'][0] == 2:
            if t.ttype == 1:
                trial_instruction = 'right'
            elif t.ttype == 0:
                trial_instruction = 'left'

        bkey['trial_instruction'] = trial_instruction

        # determine early lick
        early_lick = 'no early'

        if (protocol_type >= 5 and 'EarlyLickDelay' in states
                and np.any(t.state_data == states['EarlyLickDelay'])):
            early_lick = 'early'
        if (protocol_type > 5
                and ('EarlyLickSample' in states
                     and np.any(t.state_data == states['EarlyLickSample']))):
            early_lick = 'early'

        bkey['early_lick'] = early_lick

        # determine outcome
        outcome = 'ignore'

        if ('Reward' in states and np.any(t.state_data == states['Reward'])):
            outcome = 'hit'
        elif ('TimeOut' in states
              and np.any(t.state_data == states['TimeOut'])):
            outcome = 'miss'
        elif ('NoResponse' in states
              and np.any(t.state_data == states['NoResponse'])):
            outcome = 'ignore'

        bkey['outcome'] = outcome
        rows['behavior_trial'].append(bkey)

        #
        # Add 'protocol' note
        #
        nkey = dict(tkey)
        nkey['trial_note_type'] = 'protocol #'
        nkey['trial_note'] = str(protocol_type)
        rows['trial_note'].append(nkey)

        #
        # Add 'autolearn' note
        #
        nkey = dict(tkey)
        nkey['trial_note_type'] = 'autolearn'
        nkey['trial_note'] = str(gui['Autolearn'][0])
        rows['trial_note'].append(nkey)

        #
        # Add 'bitcode' note
        #
        if 'randomID' in gui.dtype.names:
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'bitcode'
            nkey['trial_note'] = str(gui['randomID'][0])
            rows['trial_note'].append(nkey)

        #
        # Add presample event
        #
        log.debug('BehaviorIngest.make(): presample')

        ekey = dict(tkey)
        sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

        # trial_event_id is a running index over all events of the session
        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'presample'
        ekey['trial_event_time'] = t.state_times[startindex][0]
        # presample lasts from its start to the first SamplePeriod state
        ekey['duration'] = (t.state_times[sampleindex[0]] -
                            t.state_times[startindex])[0]

        if math.isnan(ekey['duration']):
            log.debug('BehaviorIngest.make(): fixing presample duration')
            # NOTE(review): unlike the other fixed-up events, this correction
            # is not recorded in corrected_trial_event - confirm intent.
            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial

        rows['trial_event'].append(ekey)

        #
        # Add other 'sample' events
        #

        log.debug('BehaviorIngest.make(): sample events')

        last_dur = None

        for s in sampleindex:  # in protocol > 6 ~-> n>1
            # todo: batch events
            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'sample'
            ekey['trial_event_time'] = t.state_times[s]
            ekey['duration'] = gui['SamplePeriod'][0]

            if math.isnan(ekey['duration']) and last_dur is None:
                log.warning('... trial {} bad duration, no last_edur'.format(
                    i))
                ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                rows['corrected_trial_event'].append(ekey)

            elif math.isnan(ekey['duration']) and last_dur is not None:
                log.warning('... trial {} duration using last_edur {}'.format(
                    i, last_dur))
                ekey['duration'] = last_dur
                rows['corrected_trial_event'].append(ekey)

            else:
                last_dur = ekey['duration']  # only track 'good' values.

            rows['trial_event'].append(ekey)

        #
        # Add 'delay' events
        #

        log.debug('BehaviorIngest.make(): delay events')

        last_dur = None
        delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

        for d in delayindex:  # protocol > 6 ~-> n>1
            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'delay'
            ekey['trial_event_time'] = t.state_times[d]
            ekey['duration'] = gui['DelayPeriod'][0]

            if math.isnan(ekey['duration']) and last_dur is None:
                log.warning('... {} bad duration, no last_edur'.format(i))
                ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                rows['corrected_trial_event'].append(ekey)

            elif math.isnan(ekey['duration']) and last_dur is not None:
                log.warning('... {} duration using last_edur {}'.format(
                    i, last_dur))
                ekey['duration'] = last_dur
                rows['corrected_trial_event'].append(ekey)

            else:
                last_dur = ekey['duration']  # only track 'good' values.

            log.debug('delay event duration: {}'.format(ekey['duration']))
            rows['trial_event'].append(ekey)

        #
        # Add 'go' event
        #
        log.debug('BehaviorIngest.make(): go')

        ekey = dict(tkey)
        responseindex = np.where(t.state_data == states['ResponseCue'])[0]

        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'go'
        ekey['trial_event_time'] = t.state_times[responseindex][0]
        ekey['duration'] = gui['AnswerPeriod'][0]

        if math.isnan(ekey['duration']):
            log.debug('BehaviorIngest.make(): fixing go duration')
            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
            rows['corrected_trial_event'].append(ekey)

        rows['trial_event'].append(ekey)

        #
        # Add 'trialEnd' events
        #

        log.debug('BehaviorIngest.make(): trialend events')

        trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

        ekey = dict(tkey)
        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'trialend'
        ekey['trial_event_time'] = t.state_times[trialendindex][0]
        ekey['duration'] = 0.0

        rows['trial_event'].append(ekey)

        #
        # Add lick events (event code 69 = left lick, 71 = right lick)
        #

        lickleft = np.where(t.event_data == 69)[0]
        log.debug('... lickleft: {r}'.format(r=str(lickleft)))

        action_event_count = len(rows['action_event'])
        rows['action_event'].extend(
            dict(tkey,
                 action_event_id=action_event_count + idx,
                 action_event_type='left lick',
                 action_event_time=t.event_times[l])
            for idx, l in enumerate(lickleft))

        lickright = np.where(t.event_data == 71)[0]
        log.debug('... lickright: {r}'.format(r=str(lickright)))

        action_event_count = len(rows['action_event'])
        rows['action_event'].extend(
            dict(tkey,
                 action_event_id=action_event_count + idx,
                 action_event_type='right lick',
                 action_event_time=t.event_times[r])
            for idx, r in enumerate(lickright))

        # Photostim Events
        #
        # TODO:
        #
        # - base stimulation parameters:
        #
        #   - should be loaded elsewhere - where
        #   - actual ccf locations - cannot be known apriori apparently?
        #   - Photostim.Profile: what is? fix/add
        #
        # - stim data
        #
        #   - how retrieve power from file (didn't see) or should
        #     be statically coded here?
        #   - how encode stim type 6?
        #     - we have hemisphere as boolean or
        #     - but adding an event 4 and event 5 means querying
        #       is less straightforwrard (e.g. sessions with 5 & 6)

        if t.stim:
            log.debug('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
            if t.stim not in photostims:
                # photostim rows are not written back by this function, so an
                # unknown code should not abort the whole event fix
                log.warning(
                    'skipping photostim for trial {i}: unknown stim {s}'.format(
                        i=i, s=t.stim))
            else:
                rows['photostim_trial'].append(tkey)
                delay_period_idx = np.where(
                    t.state_data == states['DelayPeriod'])[0][0]
                rows['photostim_trial_event'].append(
                    dict(tkey,
                         **photostims[t.stim],
                         photostim_event_id=len(rows['photostim_trial_event']),
                         photostim_event_time=t.state_times[delay_period_idx],
                         power=5.5))

        # end of trial loop.

    log.info('BehaviorIngest.make(): ... experiment.TrialEvent')

    log.info('deleting old events')

    # NOTE(review): the deletes below and the inserts after them are not in a
    # single transaction; if an insert fails, the session is left without
    # trial events.
    with dj.config(safemode=False):

        log.info('... TrialEvent')
        (experiment.TrialEvent() & session_key).delete()

        log.info('... CorrectedTrialEvents')
        (behavior_ingest.BehaviorIngest.CorrectedTrialEvents()
         & session_key).delete_quick()

    log.info('adding new records')

    log.info('... experiment.TrialEvent')
    experiment.TrialEvent().insert(rows['trial_event'],
                                   ignore_extra_fields=True,
                                   allow_direct_insert=True,
                                   skip_duplicates=True)

    log.info('... CorrectedTrialEvents')
    behavior_ingest.BehaviorIngest.CorrectedTrialEvents().insert(
        rows['corrected_trial_event'],
        ignore_extra_fields=True,
        allow_direct_insert=True)

    return True
# Test schema named PREFIX + '_fetch_same', bound to an explicit connection
# created from CONN_INFO (rather than the default global connection).
schema = dj.Schema(PREFIX + '_fetch_same', connection=dj.conn(**CONN_INFO))


@schema
class ProjData(dj.Manual):
    # Manual table keyed by an integer `id`, holding two float columns
    # (`resp`, `sim`), an arbitrary serialized value (`big`, longblob) and a
    # short string (`blah`, up to 10 characters).
    definition = """
    id : int
    ---
    resp : float
    sim  : float
    big : longblob
    blah : varchar(10)
    """


with dj.config(enable_python_native_blobs=True):
    ProjData().insert([{
        'id': 0,
        'resp': 20.33,
        'sim': 45.324,
        'big': 3,
        'blah': 'yes'
    }, {
        'id': 1,
        'resp': 94.3,
        'sim': 34.23,
        'big': {
            'key1': np.random.randn(20, 10)
        },
        'blah': 'si'
    }, {
Exemple #33
0
def test_suppress_dj_errors():
    """ test_suppress_dj_errors: DataJoint errors are suppressible without native Python blobs """
    # start from an empty jobs table so only this populate's errors are counted
    schema.schema.jobs.delete()
    with dj.config(enable_python_native_blobs=False):
        schema.ErrorClass.populate(reserve_jobs=True, suppress_errors=True)
    error_count = len(schema.DjExceptionName())
    job_count = len(schema.schema.jobs)
    # every expected error must have been logged as a job, and at least one
    assert_true(error_count == job_count > 0)
Exemple #34
0
def _add_spike_sites_and_depths(data, probe, npx_meta, rigpath):

    sinfo = data['sinfo']
    skey = data['skey']
    method = data['method']
    spikes = data['spikes']
    units = data['units']
    spike_sites = data['spike_sites']
    spike_depths = data['spike_depths']
    creation_time = data['creation_time']

    clustering_label = data['clustering_label']

    log.info(
        '-- Start insertions for probe: {} - Clustering method: {} - Label: {}'
        .format(probe, method, clustering_label))

    # probe insertion key
    insertion_key, e_config_key = _gen_probe_insert(
        sinfo, probe, npx_meta, probe_insertion_exists=True)

    # remove noise clusters
    if method in ['jrclust_v3', 'jrclust_v4']:
        units, spikes, spike_sites, spike_depths = (v[i] for v, i in zip((
            units, spikes, spike_sites, spike_depths), repeat((units > 0))))

    q_electrodes = lab.ProbeType.Electrode * lab.ElectrodeConfig.Electrode & e_config_key
    site2electrode_map = {}
    for recorded_site in np.unique(spike_sites):
        shank, shank_col, shank_row, _ = npx_meta.shankmap['data'][
            recorded_site -
            1]  # subtract 1 because npx_meta shankmap is 0-indexed
        site2electrode_map[recorded_site] = (
            q_electrodes
            & {
                'shank': shank + 1,  # this is a 1-indexed pipeline
                'shank_col': shank_col + 1,
                'shank_row': shank_row + 1
            }).fetch1('KEY')

    spike_sites = np.array(
        [site2electrode_map[s]['electrode'] for s in spike_sites])
    unit_spike_sites = np.array(
        [spike_sites[np.where(units == u)] for u in set(units)])
    unit_spike_depths = np.array(
        [spike_depths[np.where(units == u)] for u in set(units)])

    archive_key = {
        **skey, 'insertion_number': probe,
        'clustering_method': method,
        'clustering_time': creation_time
    }

    # delete and reinsert with spike_sites and spike_depths
    for i, u in enumerate(set(units)):
        ukey = {**archive_key, 'unit': u}
        unit_data = (ephys.ArchivedClustering.Unit.proj(
            ..., '-spike_sites', '-spike_depths') & ukey).fetch1()

        with dj.config(safemode=False):
            (ephys.ArchivedClustering.Unit & ukey).delete()

        ephys.ArchivedClustering.Unit.insert1({
            **unit_data, 'spike_sites':
            unit_spike_sites[i],
            'spike_depths':
            unit_spike_depths[i]
        })