def undo_amplitude_scaling():
    amp_scale = 1 / 3.01

    units2fix = ephys.Unit & FixedAmpUnit  # only fix those units that underwent fix_0007
    units2fix = units2fix - (UndoFixedAmpUnit & 'fixed=1')  # exclude those that were already fixed

    if not units2fix:
        return

    # safety check, no jrclust results and no npx 1.0
    assert len(units2fix & 'clustering_method LIKE "jrclust%"') == 0
    assert len(units2fix.proj() * ephys.ProbeInsertion & 'probe_type LIKE "neuropixels 1.0%"') == 0

    fix_hist_key = {'fix_name': pathlib.Path(__file__).name,
                    'fix_timestamp': datetime.now()}
    FixHistory.insert1(fix_hist_key)

    for unit in tqdm(units2fix.proj('unit_amp').fetch(as_dict=True)):
        amp = unit.pop('unit_amp')
        with dj.conn().transaction:
            (ephys.Unit & unit)._update('unit_amp', amp * amp_scale)
            FixedAmpUnit.insert1({**fix_hist_key, **unit, 'fixed': True, 'scale': amp_scale})

    # delete cluster_quality figures and remake figures with updated unit_amp
    with dj.config(safemode=False):
        (report.ProbeLevelReport & units2fix).delete()
def __init__(self):
    self.rel = NanTest()
    with dj.config(safemode=False):
        self.rel.delete()
    a = np.array([0, 1/3, np.nan, np.pi, np.nan])
    self.rel.insert(((i, value) for i, value in enumerate(a)))
    self.a = a
def test_query_caching(self):
    # initialize cache directory
    os.mkdir(os.path.expanduser('~/dj_query_cache'))

    with dj.config(query_cache=os.path.expanduser('~/dj_query_cache')):
        conn = schema.TTest3.connection

        # insert sample data and load cache
        schema.TTest3.insert([dict(key=100 + i, value=200 + i) for i in range(2)])
        conn.set_query_cache(query_cache='main')
        cached_res = schema.TTest3().fetch()

        # attempt to insert while caching enabled
        try:
            schema.TTest3.insert([dict(key=200 + i, value=400 + i) for i in range(2)])
            assert False, 'Insert allowed while query caching enabled'
        except dj.DataJointError:
            conn.set_query_cache()

        # insert new data
        schema.TTest3.insert([dict(key=600 + i, value=800 + i) for i in range(2)])

        # re-enable cache to access old results
        conn.set_query_cache(query_cache='main')
        previous_cache = schema.TTest3().fetch()

        # verify properly cached and how to refresh results
        assert all([c == p for c, p in zip(cached_res, previous_cache)])
        conn.set_query_cache()
        uncached_res = schema.TTest3().fetch()
        assert len(uncached_res) > len(cached_res)

        # purge query cache
        conn.purge_query_cache()

    # reset cache directory state (will fail if purge was unsuccessful)
    os.rmdir(os.path.expanduser('~/dj_query_cache'))
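# A condensed sketch of the query-cache workflow exercised by the test above.
# Assumes an activated DataJoint connection and an existing cache directory;
# only the calls that appear in the test (set_query_cache, purge_query_cache) are used.
import os
import datajoint as dj

dj.config['query_cache'] = os.path.expanduser('~/dj_query_cache')
conn = dj.conn()
conn.set_query_cache(query_cache='main')   # fetches now read from / write to the cache
# ... run read-only queries here; inserts/deletes raise DataJointError while caching is on ...
conn.set_query_cache()                     # turn caching off to allow writes again
conn.purge_query_cache()                   # remove all cached results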
def test_contextmanager():
    """Testing context manager"""
    dj.config['arbitrary.stuff'] = 7
    with dj.config(arbitrary__stuff=10) as cfg:
        assert_true(dj.config['arbitrary.stuff'] == 10)
    assert_true(dj.config['arbitrary.stuff'] == 7)
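# Minimal sketch of the pattern the test above verifies: calling dj.config(...) as a
# context manager temporarily overrides a setting (double underscores map to dots in
# the key name) and restores the previous value on exit. 'display.limit' is just an
# example key chosen for illustration.
import datajoint as dj

dj.config['display.limit'] = 12
with dj.config(display__limit=5):
    assert dj.config['display.limit'] == 5
assert dj.config['display.limit'] == 12  # restored after the block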
def delete_non_published_records():
    with dj.config(safemode=False):
        logger.log(25, 'Deleting non-published probe insertions...')
        probe_insertion_table = QueryBuffer(ephys.ProbeInsertion)
        for key in tqdm((ephys.ProbeInsertion - public.PublicProbeInsertion - ephys.DefaultCluster).fetch('KEY')):
            probe_insertion_table.add_to_queue1(key)
            if probe_insertion_table.flush_delete(quick=False, chunksz=100):
                logger.log(25, 'Deleted 100 probe insertions')
        probe_insertion_table.flush_delete(quick=False)
        logger.log(25, 'Deleted the rest of the probe insertions')

        logger.log(25, 'Deleting non-published sessions...')
        session_table = QueryBuffer(acquisition.Session)
        for key in tqdm((acquisition.Session - public.PublicSession - behavior.TrialSet).fetch('KEY')):
            session_table.add_to_queue1(key)
            if session_table.flush_delete(quick=False, chunksz=100):
                logger.log(25, 'Deleted 100 sessions')
        session_table.flush_delete(quick=False)
        logger.log(25, 'Deleted the rest of the sessions')

        logger.log(25, 'Deleting non-published subjects...')
        subjs = subject.Subject & acquisition.Session
        for key in tqdm((subject.Subject - public.PublicSubjectUuid - subjs.proj()).fetch('KEY')):
            (subject.Subject & key).delete()
def test_fetch_format(self):
    """test fetch_format='frame'"""
    with dj.config(fetch_format='frame'):
        # test if lists are both dicts
        list1 = sorted(self.subject.proj().fetch(as_dict=True), key=itemgetter('subject_id'))
        list2 = sorted(self.subject.fetch(dj.key), key=itemgetter('subject_id'))
        for l1, l2 in zip(list1, list2):
            assert_dict_equal(l1, l2, 'Primary key is not returned correctly')

        # tests if pandas dataframe
        tmp = self.subject.fetch(order_by='subject_id')
        assert_true(isinstance(tmp, pandas.DataFrame))
        tmp = tmp.to_records()

        subject_notes, key, real_id = self.subject.fetch('subject_notes', dj.key, 'real_id')

        np.testing.assert_array_equal(sorted(subject_notes), sorted(tmp['subject_notes']))
        np.testing.assert_array_equal(sorted(real_id), sorted(tmp['real_id']))
        list1 = sorted(key, key=itemgetter('subject_id'))
        for l1, l2 in zip(list1, list2):
            assert_dict_equal(l1, l2, 'Primary key is not returned correctly')
def setup_class(cls):
    cls.rel = NanTest()
    with dj.config(safemode=False):
        cls.rel.delete()
    a = np.array([0, 1/3, np.nan, np.pi, np.nan])
    cls.rel.insert(((i, value) for i, value in enumerate(a)))
    cls.a = a
def delete_outdated_project_plots(project_name='MAP'):
    if {'project_name': project_name} not in ProjectLevelProbeTrack.proj():
        return

    plotted_track_count = (ProjectLevelProbeTrack & {'project_name': project_name}).fetch1('track_count')
    latest_track_count = SessionLevelProbeTrack.fetch('probe_track_count').sum().astype(int)

    if plotted_track_count != latest_track_count:
        uuid_byte = (ProjectLevelProbeTrack & {'project_name': project_name}).proj(
            ub='CONCAT(tracks_plot , "")').fetch1('ub')
        print('project_name', project_name, 'uuid_byte:', str(uuid_byte))
        ext_key = {'hash': uuid.UUID(bytes=uuid_byte)}

        with dj.config(safemode=False):
            with ProjectLevelProbeTrack.connection.transaction:
                # delete the outdated Probe Tracks
                (ProjectLevelProbeTrack & {'project_name': project_name}).delete()
                # delete from external store
                (schema.external['report_store'] & ext_key).delete(delete_external_files=True)
                print('Outdated ProjectLevelProbeTrack deleted')
    else:
        print('ProjectLevelProbeTrack is up-to-date')
def process_daily_summary():
    # `kwargs` (populate settings such as suppress_errors/display_progress) is expected
    # to be defined at module level.
    with dj.config(safemode=False):
        print('Populating plotting.DailyLabSummary...')
        last_sessions = (reference.Lab.aggr(
            behavior_plotting.DailyLabSummary,
            last_session_time='max(last_session_time)')).fetch('KEY')
        (behavior_plotting.DailyLabSummary & last_sessions).delete()
        behavior_plotting.DailyLabSummary.populate(**kwargs)
def _update_one_session(key):
    log.info('\n======================================================')
    log.info('Waveform update for key: {k}'.format(k=key))

    #
    # Find Ephys Recording
    #
    key = (experiment.Session & key).fetch1()
    sinfo = ((lab.WaterRestriction
              * lab.Subject.proj()
              * experiment.Session.proj(..., '-session_time')) & key).fetch1()

    rigpaths = get_ephys_paths()
    h2o = sinfo['water_restriction_number']

    sess_time = (datetime.min + key['session_time']).time()
    sess_datetime = datetime.combine(key['session_date'], sess_time)

    for rigpath in rigpaths:
        dpath, dglob = _get_sess_dir(rigpath, h2o, sess_datetime)
        if dpath is not None:
            break

    if dpath is not None:
        log.info('Found session folder: {}'.format(dpath))
    else:
        log.warning('Error - No session folder found for {}/{}. Skipping...'.format(h2o, key['session_date']))
        return False

    try:
        clustering_files = _match_probe_to_ephys(h2o, dpath, dglob)
    except FileNotFoundError as e:
        log.warning(str(e) + '. Skipping...')
        return False

    with ephys.Unit.connection.transaction:
        for probe_no, (f, cluster_method, npx_meta) in clustering_files.items():
            try:
                log.info('------ Start loading clustering results for probe: {} ------'.format(probe_no))
                loader = cluster_loader_map[cluster_method]
                dj.conn().ping()
                _add_spike_sites_and_depths(loader(sinfo, *f), probe_no, npx_meta, rigpath)
            except (ProbeInsertionError, ClusterMetricError, FileNotFoundError) as e:
                dj.conn().cancel_transaction()  # either successful fix of all probes, or none at all
                if isinstance(e, ProbeInsertionError):
                    log.warning('Probe Insertion Error: \n{}. \nSkipping...'.format(str(e)))
                else:
                    log.warning('Error: {}'.format(str(e)))
                return False

        with dj.config(safemode=False):
            (ephys.UnitCellType & key).delete()

    return True
def delete_outdated_session_plots():
    # ------------- SessionLevelProbeTrack ----------------
    sess_probe_hist = experiment.Session.aggr(histology.LabeledProbeTrack, probe_hist_count='count(*)')
    outdated_sess_probe = SessionLevelProbeTrack * sess_probe_hist & 'probe_track_count != probe_hist_count'

    uuid_bytes = (SessionLevelProbeTrack & outdated_sess_probe.proj()).proj(
        ub='(session_tracks_plot)').fetch('ub')

    if len(uuid_bytes):
        ext_keys = [{'hash': uuid.UUID(bytes=uuid_byte)} for uuid_byte in uuid_bytes]

        with dj.config(safemode=False):
            with SessionLevelProbeTrack.connection.transaction:
                # delete the outdated Probe Tracks
                (SessionLevelProbeTrack & outdated_sess_probe.proj()).delete()
                # delete from external store
                (schema.external['report_store'] & ext_keys).delete(delete_external_files=True)
                print('{} outdated SessionLevelProbeTrack deleted'.format(len(uuid_bytes)))
    else:
        print('All SessionLevelProbeTrack are up-to-date')

    # ------------- SessionLevelCDReport ----------------
    sess_probe = experiment.Session.aggr(ephys.ProbeInsertion, current_probe_count='count(*)')
    outdated_sess_probe = SessionLevelCDReport * sess_probe & 'cd_probe_count != current_probe_count'

    uuid_bytes = (SessionLevelCDReport & outdated_sess_probe.proj()).proj(
        ub='(coding_direction)').fetch('ub')

    if len(uuid_bytes):
        ext_keys = [{'hash': uuid.UUID(bytes=uuid_byte)} for uuid_byte in uuid_bytes]

        with dj.config(safemode=False):
            with SessionLevelCDReport.connection.transaction:
                # delete the outdated CD reports
                (SessionLevelCDReport & outdated_sess_probe.proj()).delete()
                # delete from external store
                (schema.external['report_store'] & ext_keys).delete(delete_external_files=True)
                print('{} outdated SessionLevelCDReport deleted'.format(len(uuid_bytes)))
    else:
        print('All SessionLevelCDReport are up-to-date')
def update(self, row, primary_keys={}, classdef=None, **kwargs):
    if classdef is None:
        classdef = self.classdef
    with closing(self.get_conn()) as conn:
        table = self.create_table(conn, classdef=classdef)
        try:
            ret = table.insert1(row, skip_duplicates=False, **kwargs)
        except DuplicateError:
            with dj.config(safemode=False):
                (table & primary_keys).delete()
            ret = table.insert1(row, skip_duplicates=False, **kwargs)
    return ret
def update(self, key):
    print("Populating", key)
    avi_path = (Eye() & key).get_video_path()

    tracker = ManualTracker(avi_path)
    contours = (self.Frame() & key).fetch('contour', order_by='frame_id')
    tracker.contours = np.array(contours)
    tracker.contours_detected = np.array([e is not None for e in contours])
    tracker.backup_file = '/tmp/tracker_update_state{animal_id}-{session}-{scan_idx}.pkl'.format(**key)

    try:
        tracker.run()
    except Exception as e:
        print(str(e))
        answer = input('Tracker crashed. Do you want to save the content anyway [y/n]?').lower()
        while answer not in ['y', 'n']:
            answer = input('Tracker crashed. Do you want to save the content anyway [y/n]?').lower()
        if answer == 'n':
            raise

    if input('Do you want to delete and replace the existing entries? Type "YES" for acknowledgement.') == "YES":
        with dj.config(safemode=False):
            with self.connection.transaction:
                (self & key).delete()
                logtrace = tracker.mixing_constant.logtrace.astype(float)
                self.insert1(dict(key, min_lambda=logtrace[logtrace > 0].min()))
                self.log_key(key)
                frame = self.Frame()
                parameters = self.Parameter()
                for frame_id, ok, contour, params in tqdm(
                        zip(count(), tracker.contours_detected, tracker.contours, tracker.parameter_iter()),
                        total=len(tracker.contours)):
                    assert frame_id == params['frame_id']
                    if ok:
                        frame.insert1(dict(key, frame_id=frame_id, contour=contour))
                    else:
                        frame.insert1(dict(key, frame_id=frame_id))
                    parameters.insert1(dict(key, **params), ignore_extra_fields=True)
def _fix_one_session(key):
    log.info('Running fix for session: {}'.format(key))

    # determine if this session's photostim is: `early-delay` or `late-delay`
    path = (behavior_ingest.BehaviorIngest.BehaviorFile & key).fetch('behavior_file')[0]
    h2o = (lab.WaterRestriction & key).fetch1('water_restriction_number')

    photostim_period = 'early-delay'
    rig_name = re.search('Recording(Rig\d)_', path)
    if re.match('SC', h2o) and rig_name:
        rig_name = rig_name.groups()[0]
        if rig_name == "Rig3":
            photostim_period = 'late-delay'

    invalid_photostim_trials = []
    for trial_key in (experiment.PhotostimTrial & key).fetch('KEY'):
        protocol_type = int((experiment.TrialNote & trial_key
                             & 'trial_note_type = "protocol #"').fetch1('trial_note'))
        autolearn = int((experiment.TrialNote & trial_key
                         & 'trial_note_type = "autolearn"').fetch1('trial_note'))

        if photostim_period == 'early-delay':
            valid_protocol = protocol_type == 5
        elif photostim_period == 'late-delay':
            valid_protocol = protocol_type > 4

        delay_duration = (experiment.TrialEvent & trial_key
                          & 'trial_event_type = "delay"').fetch(
            'duration', order_by='trial_event_time DESC', limit=1)[0]

        if not (valid_protocol and autolearn == 4 and delay_duration == Decimal('1.2')):
            # all criteria not met, this trial should not have been a photostim trial
            invalid_photostim_trials.append(trial_key)

    log.info('Deleting {} incorrectly labeled PhotostimTrial'.format(len(invalid_photostim_trials)))
    if len(invalid_photostim_trials):
        with dj.config(safemode=False):
            # delete invalid photostim trials
            (experiment.PhotostimTrial & invalid_photostim_trials).delete()
            # delete ProbeLevelPhotostimEffectReport figures associated with this session
            (report.ProbeLevelPhotostimEffectReport & key).delete()

    return True, invalid_photostim_trials
def delete_empty_ingestion_tables():
    from pipeline.ingest import ephys as ephys_ingest
    from pipeline.ingest import tracking as tracking_ingest
    from pipeline.ingest import histology as histology_ingest

    with dj.config(safemode=False):
        try:
            (ephys_ingest.EphysIngest
             & (ephys_ingest.EphysIngest - ephys.ProbeInsertion).fetch('KEY')).delete()
            (tracking_ingest.TrackingIngest
             & (tracking_ingest.TrackingIngest - tracking.Tracking).fetch('KEY')).delete()
            (histology_ingest.HistologyIngest
             & (histology_ingest.HistologyIngest - histology.ElectrodeCCFPosition).fetch('KEY')).delete()
        except OperationalError as e:
            # in case of mysql deadlock - code: 1213
            if e.args[0] == 1213:
                pass
def delete_entries_from_membership(pks_to_be_deleted):
    '''
    Delete entries from shadow membership tables.
    '''
    for t in MEMBERSHIP_TABLES:
        ingest_mod = t['dj_parent_table'].__module__
        table_name = t['dj_parent_table'].__name__

        mem_table_name = t['dj_current_table'].__name__

        print(f'Deleting from table {mem_table_name} ...')
        real_table = eval(ingest_mod.replace('ibl_pipeline.ingest.', '') + '.' + table_name)
        with dj.config(safemode=False):
            (t['dj_current_table'] &
             (real_table &
              [{t['dj_parent_uuid_name']: pk}
               for pk in pks_to_be_deleted if is_valid_uuid(pk)]).fetch('KEY')).delete()
def apply_amplitude_scaling(insertion_keys={}):
    """
    This is a one-time operation only - April 2020.
    Kilosort2 results from Neuropixels probe 2.0 require an additional scaling factor of 3.01
    applied to the unit amplitude and mean waveform.
    Future versions of the quality control pipeline will apply this scaling.
    """
    amp_scale = 3.01

    npx2_inserts = ephys.ProbeInsertion & insertion_keys & 'probe_type LIKE "neuropixels 2.0%"'

    units2fix = ephys.Unit * ephys.ClusteringLabel & npx2_inserts.proj() & 'quality_control = 1'
    units2fix = units2fix - (FixedAmpUnit & 'fixed=1')  # exclude those that were already fixed

    if not units2fix:
        return

    # safety check, no jrclust results
    assert len(units2fix & 'clustering_method LIKE "jrclust%"') == 0

    fix_hist_key = {'fix_name': pathlib.Path(__file__).name,
                    'fix_timestamp': datetime.now()}
    FixHistory.insert1(fix_hist_key)

    for unit in tqdm(units2fix.proj('unit_amp', 'waveform').fetch(as_dict=True)):
        amp = unit.pop('unit_amp')
        wf = unit.pop('waveform')
        with dj.conn().transaction:
            (ephys.Unit & unit)._update('unit_amp', amp * amp_scale)
            (ephys.Unit & unit)._update('waveform', wf * amp_scale)
            FixedAmpUnit.insert1({**fix_hist_key, **unit, 'fixed': True, 'scale': amp_scale})

    # delete cluster_quality figures and remake figures with updated unit_amp
    with dj.config(safemode=False):
        (report.ProbeLevelReport & npx2_inserts).delete()
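# Hypothetical usage sketch for the one-time fix above: run it across all
# Neuropixels 2.0 insertions (the default), or restrict it with an insertion key.
# The key fields shown are assumptions for illustration, not taken from the source.
apply_amplitude_scaling()                                      # all npx 2.0 insertions
apply_amplitude_scaling({'subject_id': 123456, 'session': 1})  # a single session (hypothetical key)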
def test_manual_insert(self):
    """
    Test whether manual insert of gitkey works with different formats.
    """
    with dj.config(safemode=False):
        schema.Man().delete()
    # insert dictionary
    schema.Man().insert1(dict(idx=0, value=2.))
    # positional insert
    schema.Man().insert1((1, 2.))
    # fetch an np.void, modify, and insert
    k = schema.Man().fetch.limit(1)()[0]
    k['idx'] = 2
    schema.Man().insert1(k)
    assert_true(len(schema.Man()) == len(schema.Man.GitKey()) == 3,
                "Inserting with different datatypes did not work")
def process_cumulative_plots(backtrack_days=30):
    kwargs = dict(suppress_errors=True, display_progress=True)

    if mode != 'public':
        latest = subject.Subject.aggr(
            behavior_plotting.LatestDate,
            checking_ts='MAX(checking_ts)') * behavior_plotting.LatestDate & \
            [f'latest_date between curdate() - interval {backtrack_days} day and curdate()',
             (subject.Subject - subject.Death)] & \
            (subject.Subject & 'subject_nickname not like "%human%"').proj()
    else:
        latest = subject.Subject.aggr(
            behavior_plotting.LatestDate,
            checking_ts='MAX(checking_ts)') & \
            (subject.Subject & 'subject_nickname not like "%human%"').proj()

    subj_keys = (subject.Subject & behavior_plotting.CumulativeSummary & latest).fetch('KEY')

    # delete and repopulate subject by subject
    with dj.config(safemode=False):
        for subj_key in tqdm(subj_keys, position=0):
            (behavior_plotting.CumulativeSummary & subj_key & latest).delete()
            print('populating...')
            behavior_plotting.CumulativeSummary.populate(latest & subj_key, **kwargs)

            # --- update the latest date of the subject -----
            # get the latest date of the CumulativeSummary of the subject
            subj_with_latest_date = (subject.Subject & subj_key).aggr(
                behavior_plotting.CumulativeSummary, latest_date='max(latest_date)')
            if len(subj_with_latest_date):
                new_date = subj_with_latest_date.fetch1('latest_date')
                current_subj = behavior_plotting.SubjectLatestDate & subj_key
                if len(current_subj):
                    current_subj._update('latest_date', new_date)
                else:
                    behavior_plotting.SubjectLatestDate.insert1(subj_with_latest_date.fetch1())

    behavior_plotting.CumulativeSummary.populate(**kwargs)
def make(self, key):
    datasets = (data.FileRecord & key & 'repo_name LIKE "flatiron_%"' & {'exists': 1}).fetch('dataset_name')
    is_complete = bool(np.all([req_ds in datasets for req_ds in self.required_datasets])) \
        and bool(np.any(['spikes.times' in d for d in datasets]))

    if is_complete:
        self.insert1(key)
        with dj.config(safemode=False):
            (EphysMissingDataLog & key).delete_quick()
    else:
        for req_ds in self.required_datasets:
            if req_ds not in datasets:
                EphysMissingDataLog.insert1(dict(**key, missing_data=req_ds),
                                            skip_duplicates=True)
def _make_tuples(self, key):
    key['eye_time'], frames, key['total_frames'] = self.grab_timestamps_and_frames(key)

    try:
        import cv2
        print('Drag window and print q when done')
        rg = CVROIGrabber(frames.mean(axis=2))
        rg.grab()
    except ImportError:
        rg = ROIGrabber(frames.mean(axis=2))

    with dj.config(display__width=50):
        print(EyeQuality())

    key['eye_quality'] = int(input("Enter the quality of the eye: "))
    key['eye_roi'] = rg.roi
    self.insert1(key)
    print('[Done]')
    if input('Do you want to stop? y/N: ') == 'y':
        self.connection.commit_transaction()
        raise PipelineException('User interrupted population.')
def archive_electrode_histology(insertion_key, note='', delete=False):
    """
    For the specified "insertion_key", copy from histology.ElectrodeCCFPosition and
    histology.LabeledProbeTrack (and their respective part tables) to
    histology.ArchivedElectrodeHistology.
    If "delete" == True, delete the records associated with the "insertion_key" from:
        + histology.ElectrodeCCFPosition
        + histology.LabeledProbeTrack
        + report.ProbeLevelDriftMap
        + report.ProbeLevelCoronalSlice
    """
    e_ccfs = {d['electrode']: d for d in (
        histology.ElectrodeCCFPosition.ElectrodePosition & insertion_key).fetch(as_dict=True)}
    e_error_ccfs = {d['electrode']: d for d in (
        histology.ElectrodeCCFPosition.ElectrodePositionError & insertion_key).fetch(as_dict=True)}
    e_ccfs_hash = dict_to_hash({**e_ccfs, **e_error_ccfs})

    if histology.ArchivedElectrodeHistology & {'archival_hash': e_ccfs_hash}:
        if delete:
            if dj.utils.user_choice(
                    'The specified ElectrodeCCF has already been archived!\nProceed with delete?') != 'yes':
                return
        else:
            print('An identical set of the specified ElectrodeCCF has already been archived')
            return

    archival_time = datetime.now()
    with histology.ArchivedElectrodeHistology.connection.transaction:
        histology.ArchivedElectrodeHistology.insert1({**insertion_key,
                                                      'archival_time': archival_time,
                                                      'archival_note': note,
                                                      'archival_hash': e_ccfs_hash})
        histology.ArchivedElectrodeHistology.ElectrodePosition.insert(
            (histology.ElectrodeCCFPosition.ElectrodePosition & insertion_key).proj(
                ..., archival_time='"{}"'.format(archival_time)))
        histology.ArchivedElectrodeHistology.ElectrodePositionError.insert(
            (histology.ElectrodeCCFPosition.ElectrodePositionError & insertion_key).proj(
                ..., archival_time='"{}"'.format(archival_time)))
        histology.ArchivedElectrodeHistology.LabeledProbeTrack.insert(
            (histology.LabeledProbeTrack & insertion_key).proj(
                ..., archival_time='"{}"'.format(archival_time)))
        histology.ArchivedElectrodeHistology.ProbeTrackPoint.insert(
            (histology.LabeledProbeTrack.Point & insertion_key).proj(
                ..., archival_time='"{}"'.format(archival_time)))

        if delete:
            with dj.config(safemode=False):
                (histology.ElectrodeCCFPosition & insertion_key).delete()
                (histology.LabeledProbeTrack & insertion_key).delete()
                (report.ProbeLevelDriftMap & insertion_key).delete()
                (report.ProbeLevelCoronalSlice & insertion_key).delete()
                (HistologyIngest & insertion_key).delete()
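# Hypothetical usage sketch: archive the histology records of one probe insertion
# and delete the originals afterwards. The insertion_key fields are assumptions,
# modeled on the subject_id/session/insertion_number keys used elsewhere in these scripts.
insertion_key = {'subject_id': 123456, 'session': 1, 'insertion_number': 1}
archive_electrode_histology(insertion_key, note='re-annotated probe tracks', delete=True)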
def update_fields(real_schema, shadow_schema, table_name, pks, insert_to_table=False):
    '''
    Given a table and the primary keys of the real table, update all fields that have discrepancies.

    Inputs: real_schema     : real schema module, e.g. reference
            shadow_schema   : shadow schema module, e.g. reference_ingest
            table_name      : string, name of a table, e.g. Subject
            pks             : list of dictionaries, primary keys of real table that contain modifications
            insert_to_table : boolean, if True, log the update history in the table ibl_update.UpdateRecord
    '''
    real_table = getattr(real_schema, table_name)
    shadow_table = getattr(shadow_schema, table_name)

    secondary_fields = set(real_table.heading.secondary_attributes)
    ts_field = [f for f in secondary_fields if f.endswith('_ts')][0]
    fields_to_update = secondary_fields - {ts_field}

    for r in (real_table & pks).fetch('KEY'):
        pk_hash = UUID(dj.hash.key_hash(r))

        if not shadow_table & r:
            real_record = (real_table & r).fetch1()
            if insert_to_table:
                update_record = dict(
                    table=real_table.__module__ + '.' + real_table.__name__,
                    attribute='unknown',
                    pk_hash=pk_hash,
                    original_ts=real_record[ts_field],
                    update_ts=datetime.datetime.now(),
                    pk_dict=r,
                )
                update.UpdateRecord.insert1(update_record)
                update_record.pop('pk_dict')

                update_error_msg = 'Record does not exist in the shadow {}'.format(r)
                update_record_error = dict(
                    **update_record,
                    update_action_ts=datetime.datetime.now(),
                    update_error_msg=update_error_msg
                )
                update.UpdateError.insert1(update_record_error)
                print(update_error_msg)
            continue
        # if there is more than 1 record
        elif len(shadow_table & r) > 1:
            # delete the older record
            ts_field = [f for f in shadow_table.heading.names if '_ts' in f][0]
            latest_record = dj.U().aggr(shadow_table & r, session_ts='max(session_ts)').fetch()
            with dj.config(safemode=False):
                ((shadow_table & r) - latest_record).delete()

        shadow_record = (shadow_table & r).fetch1()
        real_record = (real_table & r).fetch1()

        for f in fields_to_update:
            if real_record[f] != shadow_record[f]:
                try:
                    (real_table & r)._update(f, shadow_record[f])
                    update_narrative = f'{table_name}.{f}: {shadow_record[f]} != {real_record[f]}'
                    print(update_narrative)
                    if insert_to_table:
                        update_record = dict(
                            table=real_table.__module__ + '.' + real_table.__name__,
                            attribute=f,
                            pk_hash=pk_hash,
                            original_ts=real_record[ts_field],
                            update_ts=shadow_record[ts_field],
                            pk_dict=r,
                            original_value=real_record[f],
                            updated_value=shadow_record[f],
                            update_narrative=update_narrative
                        )
                        update.UpdateRecord.insert1(update_record)
                except BaseException as e:
                    print(f'Error while updating record {r}: {str(e)}')
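# Usage sketch following the docstring's own example (reference / reference_ingest / Subject).
# The `pks` value is a placeholder; in practice it would come from a prior step that
# detected modified records in the real table.
from uuid import UUID
from ibl_pipeline import reference
from ibl_pipeline.ingest import reference as reference_ingest

pks = [{'subject_uuid': UUID('00000000-0000-0000-0000-000000000000')}]  # placeholder primary key
update_fields(reference, reference_ingest, 'Subject', pks, insert_to_table=True)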
import datajoint as dj
from ibl_pipeline import subject, acquisition, data, behavior
from ibl_pipeline.ingest import acquisition as acquisition_ingest
from ibl_pipeline.ingest import data as data_ingest
from ibl_pipeline.ingest import alyxraw
import datetime
from oneibl.one import ONE
import numpy as np
from uuid import UUID

if __name__ == '__main__':
    with dj.config(safemode=False):
        uuids = ((acquisition_ingest.Session - behavior.TrialSet.proj()) &
                 'session_start_time > "2019-06-13"').fetch('session_uuid')
        uuid_str = [str(uuid) for uuid in uuids]
        for uuid in uuid_str:
            keys = (alyxraw.AlyxRaw.Field & 'fvalue="{}"'.format(uuid)).fetch('KEY')
            (alyxraw.AlyxRaw & keys).delete()
            (alyxraw.AlyxRaw & {'uuid': UUID(uuid)}).delete()
            if len(acquisition_ingest.Session & {'session_uuid': UUID(uuid)}):
                subj_uuid, session_start_time = (acquisition_ingest.Session & {
                    'session_uuid': UUID(uuid)}).fetch1('subject_uuid', 'session_start_time')
            else:
                continue
            key = {
                'subject_uuid': subj_uuid,
def test_preview():
    with dj.config(display__limit=7):
        x = A().proj(a='id_a')
        s = x.preview()
        assert_equal(len(s.split('\n')), len(x) + 2)
def fix_session(session_key):
    paths = behavior_ingest.RigDataPath.fetch(as_dict=True)
    files = (behavior_ingest.BehaviorIngest * behavior_ingest.BehaviorIngest.BehaviorFile
             & session_key).fetch(as_dict=True, order_by='behavior_file asc')

    filelist = []
    for pf in [(p, f) for f in files for p in paths]:
        p, f = pf
        found = find_path(p['rig_data_path'], f['behavior_file'])
        if found:
            filelist.append(found)

    if len(filelist) != len(files):
        log.warning("behavior files missing in {} ({}/{}). skipping".format(
            session_key, len(filelist), len(files)))
        return False

    log.info('filelist: {}'.format(filelist))

    #
    # Prepare PhotoStim
    #
    photosti_duration = 0.5  # (s) Hard-coded here
    photostims = {
        4: {'photo_stim': 4, 'photostim_device': 'OBIS470',
            'brain_location_name': 'left_alm', 'duration': photosti_duration},
        5: {'photo_stim': 5, 'photostim_device': 'OBIS470',
            'brain_location_name': 'right_alm', 'duration': photosti_duration},
        6: {'photo_stim': 6, 'photostim_device': 'OBIS470',
            'brain_location_name': 'both_alm', 'duration': photosti_duration}
    }

    #
    # Extract trial data from file(s) & prepare trial loop
    #
    trials = zip()

    trial = namedtuple(  # simple structure to track per-trial vars
        'trial', ('ttype', 'stim', 'settings', 'state_times', 'state_names',
                  'state_data', 'event_data', 'event_times'))

    for f in filelist:

        if os.stat(f).st_size / 1024 < 1000:
            log.info('skipping file {f} - too small'.format(f=f))
            continue

        log.debug('loading file {}'.format(f))

        mat = spio.loadmat(f, squeeze_me=True)
        SessionData = mat['SessionData'].flatten()

        AllTrialTypes = SessionData['TrialTypes'][0]
        AllTrialSettings = SessionData['TrialSettings'][0]

        RawData = SessionData['RawData'][0].flatten()
        AllStateNames = RawData['OriginalStateNamesByNumber'][0]
        AllStateData = RawData['OriginalStateData'][0]
        AllEventData = RawData['OriginalEventData'][0]
        AllStateTimestamps = RawData['OriginalStateTimestamps'][0]
        AllEventTimestamps = RawData['OriginalEventTimestamps'][0]

        # verify trial-related data arrays are all same length
        assert (all((x.shape[0] == AllStateTimestamps.shape[0] for x in
                     (AllTrialTypes, AllTrialSettings, AllStateNames,
                      AllStateData, AllEventData, AllEventTimestamps))))

        if 'StimTrials' in SessionData.dtype.fields:
            log.debug('StimTrials detected in session - will include')
            AllStimTrials = SessionData['StimTrials'][0]
            assert (AllStimTrials.shape[0] == AllStateTimestamps.shape[0])
        else:
            log.debug('StimTrials not detected in session - will skip')
            AllStimTrials = np.array(
                [None for i in enumerate(range(AllStateTimestamps.shape[0]))])

        z = zip(AllTrialTypes, AllStimTrials, AllTrialSettings,
                AllStateTimestamps, AllStateNames, AllStateData,
                AllEventData, AllEventTimestamps)

        trials = chain(trials, z)  # concatenate the files

    trials = list(trials)

    # all files were internally invalid or size < 100k
    if not trials:
        log.warning('skipping ., no valid files')
        return False

    key = session_key
    skey = (experiment.Session & key).fetch1()

    #
    # Actually load the per-trial data
    #
    log.info('BehaviorIngest.make(): trial parsing phase')

    # lists of various records for batch-insert
    rows = {k: list() for k in ('trial', 'behavior_trial', 'trial_note',
                                'trial_event', 'corrected_trial_event',
                                'action_event', 'photostim',
                                'photostim_location', 'photostim_trial',
                                'photostim_trial_event')}

    i = -1

    for t in trials:

        #
        # Misc
        #

        t = trial(*t)  # convert list of items to a 'trial' structure
        i += 1  # increment trial counter

        log.debug('BehaviorIngest.make(): parsing trial {i}'.format(i=i))

        # convert state data names into a lookup dictionary
        #
        # names (seem to be? are?):
        #
        # Trigtrialstart, PreSamplePeriod, SamplePeriod, DelayPeriod,
        # EarlyLickDelay, EarlyLickSample, ResponseCue, GiveLeftDrop,
        # GiveRightDrop, GiveLeftDropShort, GiveRightDropShort,
        # AnswerPeriod, Reward, RewardConsumption, NoResponse,
        # TimeOut, StopLicking, StopLickingReturn, TrialEnd

        states = {k: (v + 1) for v, k in enumerate(t.state_names)}
        required_states = ('PreSamplePeriod', 'SamplePeriod', 'DelayPeriod',
                           'ResponseCue', 'StopLicking', 'TrialEnd')

        missing = list(k for k in required_states if k not in states)

        if len(missing):
            log.warning('skipping trial {i}; missing {m}'.format(i=i, m=missing))
            continue

        gui = t.settings['GUI'].flatten()

        # ProtocolType - only ingest protocol >= 3
        #
        # 1 Water-Valve-Calibration 2 Licking 3 Autoassist
        # 4 No autoassist 5 DelayEnforce 6 SampleEnforce 7 Fixed
        #
        if 'ProtocolType' not in gui.dtype.names:
            log.warning('skipping trial {i}; protocol undefined'.format(i=i))
            continue

        protocol_type = gui['ProtocolType'][0]
        if gui['ProtocolType'][0] < 3:
            log.warning('skipping trial {i}; protocol {n} < 3'.format(
                i=i, n=gui['ProtocolType'][0]))
            continue

        #
        # Top-level 'Trial' record
        #

        tkey = dict(skey)
        startindex = np.where(t.state_data == states['PreSamplePeriod'])[0]

        # should be only end of 1st StopLicking;
        # rest of data is irrelevant w/r/t separately ingested ephys
        endindex = np.where(t.state_data == states['StopLicking'])[0]

        log.debug('states\n' + str(states))
        log.debug('state_data\n' + str(t.state_data))
        log.debug('startindex\n' + str(startindex))
        log.debug('endindex\n' + str(endindex))

        if not (len(startindex) and len(endindex)):
            log.warning('skipping trial {i}: start/end index error: {s}/{e}'.format(
                i=i, s=str(startindex), e=str(endindex)))
            continue

        try:
            tkey['trial'] = i
            tkey['trial_uid'] = i  # Arseny has unique id to identify some trials
            tkey['start_time'] = t.state_times[startindex][0]
            tkey['stop_time'] = t.state_times[endindex][0]
        except IndexError:
            log.warning('skipping trial {i}: error indexing {s}/{e} into {t}'.format(
                i=i, s=str(startindex), e=str(endindex), t=str(t.state_times)))
            continue

        log.debug('BehaviorIngest.make(): Trial().insert1')  # TODO msg
        log.debug('tkey' + str(tkey))
        rows['trial'].append(tkey)

        #
        # Specific BehaviorTrial information for this trial
        #

        bkey = dict(tkey)
        bkey['task'] = 'audio delay'  # hard-coded here
        bkey['task_protocol'] = 1     # hard-coded here

        # determine trial instruction
        trial_instruction = 'left'    # hard-coded here

        if gui['Reversal'][0] == 1:
            if t.ttype == 1:
                trial_instruction = 'left'
            elif t.ttype == 0:
                trial_instruction = 'right'
        elif gui['Reversal'][0] == 2:
            if t.ttype == 1:
                trial_instruction = 'right'
            elif t.ttype == 0:
                trial_instruction = 'left'

        bkey['trial_instruction'] = trial_instruction

        # determine early lick
        early_lick = 'no early'

        if (protocol_type >= 5
                and 'EarlyLickDelay' in states
                and np.any(t.state_data == states['EarlyLickDelay'])):
            early_lick = 'early'
        if (protocol_type > 5
                and ('EarlyLickSample' in states
                     and np.any(t.state_data == states['EarlyLickSample']))):
            early_lick = 'early'

        bkey['early_lick'] = early_lick

        # determine outcome
        outcome = 'ignore'

        if ('Reward' in states
                and np.any(t.state_data == states['Reward'])):
            outcome = 'hit'
        elif ('TimeOut' in states
                and np.any(t.state_data == states['TimeOut'])):
            outcome = 'miss'
        elif ('NoResponse' in states
                and np.any(t.state_data == states['NoResponse'])):
            outcome = 'ignore'

        bkey['outcome'] = outcome
        rows['behavior_trial'].append(bkey)

        #
        # Add 'protocol' note
        #
        nkey = dict(tkey)
        nkey['trial_note_type'] = 'protocol #'
        nkey['trial_note'] = str(protocol_type)
        rows['trial_note'].append(nkey)

        #
        # Add 'autolearn' note
        #
        nkey = dict(tkey)
        nkey['trial_note_type'] = 'autolearn'
        nkey['trial_note'] = str(gui['Autolearn'][0])
        rows['trial_note'].append(nkey)

        #
        # Add 'bitcode' note
        #
        if 'randomID' in gui.dtype.names:
            nkey = dict(tkey)
            nkey['trial_note_type'] = 'bitcode'
            nkey['trial_note'] = str(gui['randomID'][0])
            rows['trial_note'].append(nkey)

        #
        # Add presample event
        #
        log.debug('BehaviorIngest.make(): presample')

        ekey = dict(tkey)
        sampleindex = np.where(t.state_data == states['SamplePeriod'])[0]

        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'presample'
        ekey['trial_event_time'] = t.state_times[startindex][0]
        ekey['duration'] = (t.state_times[sampleindex[0]]
                            - t.state_times[startindex])[0]

        if math.isnan(ekey['duration']):
            log.debug('BehaviorIngest.make(): fixing presample duration')
            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trial

        rows['trial_event'].append(ekey)

        #
        # Add other 'sample' events
        #
        log.debug('BehaviorIngest.make(): sample events')

        last_dur = None

        for s in sampleindex:  # in protocol > 6 ~-> n>1
            # todo: batch events
            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'sample'
            ekey['trial_event_time'] = t.state_times[s]
            ekey['duration'] = gui['SamplePeriod'][0]

            if math.isnan(ekey['duration']) and last_dur is None:
                log.warning('... trial {} bad duration, no last_edur'.format(i, last_dur))
                ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                rows['corrected_trial_event'].append(ekey)
            elif math.isnan(ekey['duration']) and last_dur is not None:
                log.warning('... trial {} duration using last_edur {}'.format(i, last_dur))
                ekey['duration'] = last_dur
                rows['corrected_trial_event'].append(ekey)
            else:
                last_dur = ekey['duration']  # only track 'good' values.

            rows['trial_event'].append(ekey)

        #
        # Add 'delay' events
        #
        log.debug('BehaviorIngest.make(): delay events')

        last_dur = None
        delayindex = np.where(t.state_data == states['DelayPeriod'])[0]

        for d in delayindex:  # protocol > 6 ~-> n>1
            ekey = dict(tkey)
            ekey['trial_event_id'] = len(rows['trial_event'])
            ekey['trial_event_type'] = 'delay'
            ekey['trial_event_time'] = t.state_times[d]
            ekey['duration'] = gui['DelayPeriod'][0]

            if math.isnan(ekey['duration']) and last_dur is None:
                log.warning('... {} bad duration, no last_edur'.format(i, last_dur))
                ekey['duration'] = 0.0  # FIXDUR: cross-trial check
                rows['corrected_trial_event'].append(ekey)
            elif math.isnan(ekey['duration']) and last_dur is not None:
                log.warning('... {} duration using last_edur {}'.format(i, last_dur))
                ekey['duration'] = last_dur
                rows['corrected_trial_event'].append(ekey)
            else:
                last_dur = ekey['duration']  # only track 'good' values.

            log.debug('delay event duration: {}'.format(ekey['duration']))
            rows['trial_event'].append(ekey)

        #
        # Add 'go' event
        #
        log.debug('BehaviorIngest.make(): go')

        ekey = dict(tkey)
        responseindex = np.where(t.state_data == states['ResponseCue'])[0]

        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'go'
        ekey['trial_event_time'] = t.state_times[responseindex][0]
        ekey['duration'] = gui['AnswerPeriod'][0]

        if math.isnan(ekey['duration']):
            log.debug('BehaviorIngest.make(): fixing go duration')
            ekey['duration'] = 0.0  # FIXDUR: lookup from previous trials
            rows['corrected_trial_event'].append(ekey)

        rows['trial_event'].append(ekey)

        #
        # Add 'trialEnd' events
        #
        log.debug('BehaviorIngest.make(): trialend events')

        last_dur = None
        trialendindex = np.where(t.state_data == states['TrialEnd'])[0]

        ekey = dict(tkey)
        ekey['trial_event_id'] = len(rows['trial_event'])
        ekey['trial_event_type'] = 'trialend'
        ekey['trial_event_time'] = t.state_times[trialendindex][0]
        ekey['duration'] = 0.0

        rows['trial_event'].append(ekey)

        #
        # Add lick events
        #
        lickleft = np.where(t.event_data == 69)[0]
        log.debug('... lickleft: {r}'.format(r=str(lickleft)))

        action_event_count = len(rows['action_event'])
        if len(lickleft):
            [rows['action_event'].append(
                dict(tkey, action_event_id=action_event_count + idx,
                     action_event_type='left lick',
                     action_event_time=t.event_times[l]))
             for idx, l in enumerate(lickleft)]

        lickright = np.where(t.event_data == 71)[0]
        log.debug('... lickright: {r}'.format(r=str(lickright)))

        action_event_count = len(rows['action_event'])
        if len(lickright):
            [rows['action_event'].append(
                dict(tkey, action_event_id=action_event_count + idx,
                     action_event_type='right lick',
                     action_event_time=t.event_times[r]))
             for idx, r in enumerate(lickright)]

        # Photostim Events
        #
        # TODO:
        #
        # - base stimulation parameters:
        #   - should be loaded elsewhere - where
        #   - actual ccf locations - cannot be known apriori apparently?
        #   - Photostim.Profile: what is? fix/add
        #
        # - stim data
        #   - how retrieve power from file (didn't see) or should
        #     be statically coded here?
        #   - how encode stim type 6?
        #     - we have hemisphere as boolean or
        #     - but adding an event 4 and event 5 means querying
        #       is less straightforward (e.g. sessions with 5 & 6)
        if t.stim:
            log.debug('BehaviorIngest.make(): t.stim == {}'.format(t.stim))
            rows['photostim_trial'].append(tkey)
            delay_period_idx = np.where(t.state_data == states['DelayPeriod'])[0][0]
            rows['photostim_trial_event'].append(
                dict(tkey, **photostims[t.stim],
                     photostim_event_id=len(rows['photostim_trial_event']),
                     photostim_event_time=t.state_times[delay_period_idx],
                     power=5.5))

        # end of trial loop.

    log.info('BehaviorIngest.make(): ... experiment.TrialEvent')

    fix_events = rows['trial_event']

    ref_events = (experiment.TrialEvent() & skey).fetch(
        order_by='trial, trial_event_id', as_dict=True)

    if False:
        for e in ref_events:
            log.debug('ref_events: t: {}, e: {}, event_type: {}'.format(
                e['trial'], e['trial_event_id'], e['trial_event_type']))
        for e in fix_events:
            log.debug('fix_events: t: {}, e: {}, type: {}'.format(
                e['trial'], e['trial_event_id'], e['trial_event_type']))

    log.info('deleting old events')

    with dj.config(safemode=False):
        log.info('... TrialEvent')
        (experiment.TrialEvent() & session_key).delete()

        log.info('... CorrectedTrialEvents')
        (behavior_ingest.BehaviorIngest.CorrectedTrialEvents()
         & session_key).delete_quick()

    log.info('adding new records')

    log.info('... experiment.TrialEvent')
    experiment.TrialEvent().insert(rows['trial_event'],
                                   ignore_extra_fields=True,
                                   allow_direct_insert=True,
                                   skip_duplicates=True)

    log.info('... CorrectedTrialEvents')
    behavior_ingest.BehaviorIngest.CorrectedTrialEvents().insert(
        rows['corrected_trial_event'],
        ignore_extra_fields=True,
        allow_direct_insert=True)

    return True
schema = dj.Schema(PREFIX + '_fetch_same', connection=dj.conn(**CONN_INFO))


@schema
class ProjData(dj.Manual):
    definition = """
    id : int
    ---
    resp : float
    sim : float
    big : longblob
    blah : varchar(10)
    """


with dj.config(enable_python_native_blobs=True):
    ProjData().insert([{
        'id': 0,
        'resp': 20.33,
        'sim': 45.324,
        'big': 3,
        'blah': 'yes'
    }, {
        'id': 1,
        'resp': 94.3,
        'sim': 34.23,
        'big': {'key1': np.random.randn(20, 10)},
        'blah': 'si'
    }, {
def test_suppress_dj_errors():
    """ test_suppress_dj_errors: dj errors suppressible w/o native py blobs """
    schema.schema.jobs.delete()
    with dj.config(enable_python_native_blobs=False):
        schema.ErrorClass.populate(reserve_jobs=True, suppress_errors=True)
    assert_true(len(schema.DjExceptionName()) == len(schema.schema.jobs) > 0)
def _add_spike_sites_and_depths(data, probe, npx_meta, rigpath):
    sinfo = data['sinfo']
    skey = data['skey']
    method = data['method']
    spikes = data['spikes']
    units = data['units']
    spike_sites = data['spike_sites']
    spike_depths = data['spike_depths']
    creation_time = data['creation_time']
    clustering_label = data['clustering_label']

    log.info('-- Start insertions for probe: {} - Clustering method: {} - Label: {}'.format(
        probe, method, clustering_label))

    # probe insertion key
    insertion_key, e_config_key = _gen_probe_insert(
        sinfo, probe, npx_meta, probe_insertion_exists=True)

    # remove noise clusters
    if method in ['jrclust_v3', 'jrclust_v4']:
        units, spikes, spike_sites, spike_depths = (
            v[i] for v, i in zip((units, spikes, spike_sites, spike_depths),
                                 repeat((units > 0))))

    q_electrodes = lab.ProbeType.Electrode * lab.ElectrodeConfig.Electrode & e_config_key
    site2electrode_map = {}
    for recorded_site in np.unique(spike_sites):
        shank, shank_col, shank_row, _ = npx_meta.shankmap['data'][
            recorded_site - 1]  # subtract 1 because npx_meta shankmap is 0-indexed
        site2electrode_map[recorded_site] = (
            q_electrodes & {'shank': shank + 1,  # this is a 1-indexed pipeline
                            'shank_col': shank_col + 1,
                            'shank_row': shank_row + 1}).fetch1('KEY')

    spike_sites = np.array([site2electrode_map[s]['electrode'] for s in spike_sites])
    unit_spike_sites = np.array([spike_sites[np.where(units == u)] for u in set(units)])
    unit_spike_depths = np.array([spike_depths[np.where(units == u)] for u in set(units)])

    archive_key = {**skey, 'insertion_number': probe,
                   'clustering_method': method, 'clustering_time': creation_time}

    # delete and reinsert with spike_sites and spike_depths
    for i, u in enumerate(set(units)):
        ukey = {**archive_key, 'unit': u}
        unit_data = (ephys.ArchivedClustering.Unit.proj(
            ..., '-spike_sites', '-spike_depths') & ukey).fetch1()

        with dj.config(safemode=False):
            (ephys.ArchivedClustering.Unit & ukey).delete()

        ephys.ArchivedClustering.Unit.insert1({**unit_data,
                                               'spike_sites': unit_spike_sites[i],
                                               'spike_depths': unit_spike_depths[i]})