def __init__(self, protocol=False):
    """Initialise setup state, timers, a private DataJoint connection and worker threads.

    Args:
        protocol: when truthy, ``setup_status`` starts as ``'running'``
            (otherwise ``'ready'``); it is also passed to ``self.log_setup``.
    """
    self.setup = socket.gethostname()
    # os.uname() is unavailable on non-POSIX systems (Windows); such hosts are
    # never treated as a Pi.
    self.is_pi = os.uname()[4][:3] == 'arm' if os.name == 'posix' else False
    self.curr_state, self.lock, self.queue = '', False, PriorityQueue()
    self.curr_trial, self.total_reward = 0, 0
    self.session_key = dict()
    self.ping_timer, self.logger_timer = Timer(), Timer()
    self.setup_status = 'running' if protocol else 'ready'
    self.log_setup(protocol)

    # Credentials live next to this module; the context manager closes the
    # file handle, which was previously leaked.
    conf_path = os.path.dirname(os.path.abspath(__file__)) + '/dj_local_conf.json'
    with open(conf_path) as fileobject:
        connect_info = json.load(fileobject)
    # Separate connection for the schema modules below (presumably used by
    # the inserter/getter threads).
    background_conn = dj.Connection(connect_info['database.host'],
                                    connect_info['database.user'],
                                    connect_info['database.password'])
    self.schemata = dict()
    self.schemata['lab'] = dj.create_virtual_module(
        'beh.py', 'lab_behavior', connection=background_conn)
    self.schemata['mice'] = dj.create_virtual_module(
        'mice.py', 'lab_mice', connection=background_conn)

    self.thread_end, self.thread_lock = threading.Event(), threading.Lock()
    self.inserter_thread = threading.Thread(target=self.inserter)
    self.getter_thread = threading.Thread(target=self.getter)
    self.inserter_thread.start()
    self.getter_thread.start()
    self.logger_timer.start()  # start session time
def load_traces_and_frametimes(self, key):
    """Load soma spike traces for *key* together with per-unit frame times.

    Returns:
        traces: 2-D array stacked with one row per unit (float32 before
            ``adjust_trace_len``), with NaNs filled by ``fill_nans``.
        frame_times: per-unit frame times, ``ms_delay / 1000`` added to the
            shared frame-time vector (one row per unit).
        trace_keys: primary keys of the traces, ordered by
            animal_id, session, scan_idx, unit_id.
    """
    # -- find number of recording depths
    pipe = (fuse.Activity() & key).fetch('pipe')
    assert len(np.unique(pipe)) == 1, 'Selection is from different pipelines'
    pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
    k = dict(key)
    k.pop('field', None)
    ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
    # Keep every ndepth-th sync time (presumably one per volume across depths).
    frame_times = (stimulus.Sync() & key).fetch1('frame_times').squeeze()[::ndepth]

    soma = pipe.MaskClassification.Type() & dict(type='soma')
    spikes = (dj.U('field', 'channel') * pipe.Activity.Trace() * StaticScan.Unit()
              * pipe.ScanSet.UnitInfo() & soma & key)
    traces, ms_delay, trace_keys = spikes.fetch(
        'trace', 'ms_delay', dj.key,
        order_by='animal_id, session, scan_idx, unit_id')
    # np.float was removed in NumPy 1.24; float64 is what it aliased.
    delay = np.fromiter(ms_delay / 1000, dtype=np.float64)
    # Broadcast: (n_units, 1) + (1, n_frames) -> per-unit frame times.
    frame_times = delay[:, None] + frame_times[None, :]
    traces = np.vstack(
        [fill_nans(tr.astype(np.float32)).squeeze() for tr in traces])
    traces, frame_times = self.adjust_trace_len(traces, frame_times)
    return traces, frame_times, trace_keys
def make(self, key):
    """Insert the finished scan for *key* and its soma units.

    The master row is copied from ``fuse.ScanDone``. The ``Unit`` part is
    filled from the scan's pipeline, restricted to pipe_version=1,
    segmentation_method=6, spike_method=5 and masks classified as 'soma'.
    """
    finished = fuse.ScanDone() & key
    self.insert(finished, ignore_extra_fields=True)

    pipe_name = finished.fetch1('pipe')
    pipeline = dj.create_virtual_module(pipe_name, 'pipeline_' + pipe_name)
    unit_filter = dict(pipe_version=1, segmentation_method=6,
                       spike_method=5, type='soma')
    soma_units = (fuse.ScanDone * pipeline.ScanSet.Unit
                  * pipeline.MaskClassification.Type & key & unit_filter)
    self.Unit().insert(soma_units, ignore_extra_fields=True)
def get_network_path(path_name):
    """
    Get network root path depending on os and path required

    Args:
        path_name (str): One of the main paths for data storage
            (Bezos, braininit, u19_dj); a leading '/' is added if missing.

    Returns:
        network_path: (str): String with network path as mounted by the
            corresponding os (``bucket_path`` on spock, ``local_path`` otherwise)
    """
    key = dict()
    # lab.Path global paths start with '/'; startswith also handles an empty
    # string, where indexing path_name[0] raised IndexError.
    if path_name.startswith('/'):
        key['global_path'] = path_name
    else:
        key['global_path'] = '/' + path_name

    field_get = ['local_path']
    if is_this_spock():
        field_get = ['bucket_path']
        key['system'] = 'linux'
    elif sys.platform == "darwin":
        key['system'] = 'mac'
    elif os.name == 'nt':
        key['system'] = 'windows'
    else:
        key['system'] = 'linux'

    lab = dj.create_virtual_module(
        'lab', dj.config['custom']['database.prefix'] + 'lab')
    network_path = (lab.Path & key).fetch1(*field_get)
    return network_path
def __init__(self, protocol=False):
    """Initialise host info, a private DataJoint connection and the worker threads.

    Args:
        protocol: when truthy, ``setup_status`` starts as ``'running'``
            (otherwise ``'ready'``); it is also passed to ``self.log_setup``.
    """
    self.setup = socket.gethostname()
    self.is_pi = os.uname()[4][:3] == 'arm' if os.name == 'posix' else False
    self.setup_status = 'running' if protocol else 'ready'
    # Close the config file after reading (it was previously left open).
    conf_path = os.path.dirname(os.path.abspath(__file__)) + '/../dj_local_conf.json'
    with open(conf_path) as fileobject:
        con_info = json.load(fileobject)
    self.private_conn = dj.Connection(con_info['database.host'],
                                      con_info['database.user'],
                                      con_info['database.password'])
    # separate connection for internal communication
    # NOTE(review): self._schemata is not created in this method; if it is a
    # class-level dict it is shared by every instance - confirm.
    for schema, value in schemata.items():
        self._schemata.update({
            schema: dj.create_virtual_module(schema, value,
                                             connection=self.private_conn)
        })
    self.thread_end, self.thread_lock = threading.Event(), threading.Lock()
    self.inserter_thread = threading.Thread(target=self.inserter)
    self.getter_thread = threading.Thread(target=self.getter)
    self.inserter_thread.start()
    self.log_setup(protocol)
    self.getter_thread.start()
    self.logger_timer.start()
def depstick(sname, direction='reverse'):
    '''
    check/print report of cross-schema foreign-key dependencies.

    direction='forward': schemas referenced by constraints defined in sname.
    direction='reverse': schemas whose constraints reference sname.
    sname itself is never printed. Raises ValueError for any other direction.
    '''
    vm = dj.create_virtual_module(sname, sname)
    dbc = vm.schema.connection

    # Constraint CONSTRAINT_NAME in CONSTRAINT_SCHEMA TABLE_NAME refers to
    # table REFERENCED_TABLE_NAME in UNIQUE_CONSTRAINT_SCHEMA.
    # The schema name is passed as a query parameter rather than formatted
    # into the SQL, so it is escaped by the driver.
    if direction == 'forward':
        q = '''
        SELECT distinct(UNIQUE_CONSTRAINT_SCHEMA)
        FROM information_schema.REFERENTIAL_CONSTRAINTS
        where constraint_schema=%s;
        '''
    elif direction == 'reverse':
        q = '''
        SELECT distinct(CONSTRAINT_SCHEMA)
        FROM information_schema.REFERENTIAL_CONSTRAINTS
        where unique_constraint_schema=%s;
        '''
    else:
        raise ValueError("depstick doesn't know {} direction."
                         .format(direction))

    print('-- {} {} dependencies'.format(sname, direction))

    for r in dbc.query(q, args=(sname,)):
        if r[0] != sname:
            print(r[0])
def get_table_attributes(jwt_payload: dict, schema_name: str, table_name: str):
    """
    List the primary and secondary attributes of a table with their metadata.

    :param jwt_payload: Dictionary containing databaseAddress, username and password
        strings
    :type jwt_payload: dict
    :param schema_name: Name of the schema containing the table
    :type schema_name: str
    :param table_name: Table name under the given schema; must be in camel case
    :type table_name: str
    :return: Dict with keys ``primary_attributes`` and ``secondary_attributes``; each
        is a list of (attribute_name, type, nullable, default, autoincrement) tuples
        in heading order.
    :rtype: dict
    """
    DJConnector.set_datajoint_config(jwt_payload)
    vmod = dj.create_virtual_module(schema_name, schema_name)

    primary, secondary = [], []
    heading_items = getattr(vmod, table_name).heading.attributes.items()
    for name, info in heading_items:
        row = (name, info.type, info.nullable, info.default, info.autoincrement)
        # Key attributes go to the primary list, everything else to secondary.
        (primary if info.in_key else secondary).append(row)

    return dict(primary_attributes=primary, secondary_attributes=secondary)
def _make_tuples(self, key):
    """Insert the fish metadata (census id, EOD frequency, gender, weight, size) for *key*.

    The subject info comes from ``scan_info(key['cell_id'])``. If its
    identifier is not found in the census, the trailing number is
    zero-padded to 4 digits after a 'leptozucht' or 'lepto' prefix, and the
    lookup is retried.
    """
    # we are not making this a direct dependency because that database is not Datajoint compatible
    census = dj.create_virtual_module('census', 'animal_keeping')
    info = scan_info(key['cell_id'], basedir=BASEDIR)
    subject = info['Subject'] if 'Subject' in info else info['Recording']['Subject']

    if not (census.CensusSubject() & dict(name=subject['Identifier'])):
        # Presumably the census stores zero-padded numbers, e.g. 'lepto0042'.
        prefix = 'leptozucht' if 'zucht' in subject['Identifier'] else 'lepto'
        pieces = subject['Identifier'].split(prefix)
        pieces[-1] = '{:04d}'.format(int(pieces[-1]))
        subject['Identifier'] = prefix.join(pieces)

    fish_id = (census.CensusSubject() & dict(name=subject['Identifier'])).fetch1('id')
    # The numeric fields carry unit suffixes, stripped by the slices below.
    self.insert1(dict(
        key,
        fish_id=fish_id,
        eod_frequency=float(subject['EOD Frequency'][:-2]),
        gender=subject['Gender'].lower(),
        weight=float(subject['Weight'][:-1]),
        size=float(subject['Size'][:-2]),
    ))
def make_vmods(tag, cfg, connection):
    """Create a virtual module for every ``{alias: schema_name}`` pair in *cfg*.

    Each module is named ``'<tag>_<alias>'`` and bound to *connection*.
    Returns a dict mapping alias -> module.
    """
    vmods = {}
    for alias, schema_name in cfg.items():
        module_name = '{}_{}'.format(tag, alias)
        vmods[alias] = dj.create_virtual_module(module_name, schema_name,
                                                connection=connection)
    return vmods
def wheel_metrics(trials, session):
    """Compute per-mouse mean wheel metrics over training days up to *session*.

    Args:
        trials: DataFrame with at least 'subject_uuid', 'session_start_time'
            and 'trial_start_time'. After it is merged with the wheel table,
            a 'training_day' column must be present (presumably from trials).
        session: training-day cutoff; rows with ``training_day <= session``
            are averaged.

    Returns:
        DataFrame with one row per mouse present in the wheel data and the
        columns 'subject_uuid', 'disp_norm' (mean |displacement| / distance)
        and 'moves_time' (mean distance / session duration).
    """
    # Session duration = start time of the last trial in each session.
    session_duration = pd.DataFrame(trials.groupby(
        ['subject_uuid', 'session_start_time'])['trial_start_time'].max())
    session_duration = session_duration.reset_index(level=[0, 1])
    session_duration = session_duration.rename(
        columns={'trial_start_time': "session_duration"})
    # Data with added column for session duration
    data2 = trials.merge(session_duration,
                         on=['subject_uuid', 'session_start_time'])

    # --Wheel data
    dj_wheel = dj.create_virtual_module('wheel_moves', 'group_shared_wheel')
    movements_summary = pd.DataFrame.from_dict(
        dj_wheel.WheelMoveSet().fetch(as_dict=True))
    movements_summary = movements_summary.merge(
        data2, on=['subject_uuid', 'session_start_time'])
    movements_summary['disp_norm'] = np.abs(
        movements_summary['total_displacement'] / movements_summary['total_distance'])
    movements_summary['moves_time'] = (
        movements_summary['total_distance'] / movements_summary['session_duration'])

    # -- df
    # data should have same mice as data2, but movements_summary might have less
    mice = movements_summary['subject_uuid'].unique()
    df = pd.DataFrame(columns=['subject_uuid', 'disp_norm', 'moves_time'],
                      index=range(len(mice)))
    for m, mouse in enumerate(mice):
        mouse_data = movements_summary.loc[movements_summary.subject_uuid == mouse]
        up_to_session = mouse_data['training_day'] <= session
        # Assign through .loc: chained df[col][m] = ... writes to a temporary
        # copy under pandas copy-on-write and the value is lost.
        df.loc[m, 'subject_uuid'] = mouse
        df.loc[m, 'disp_norm'] = np.nanmean(
            mouse_data.loc[up_to_session, 'disp_norm'])
        df.loc[m, 'moves_time'] = np.nanmean(
            mouse_data.loc[up_to_session, 'moves_time'])
    return df
def get_virtual_module(full_table_name, context=None):
    """Create a virtual module for the schema of *full_table_name* and register it.

    Args:
        full_table_name: quoted table name such as ``'`schema`.`table`'``.
        context: mapping that receives the module under its ``__name__``.
            When omitted (None), the caller's ``f_locals`` is used. An
            explicitly passed empty dict is honoured.

    Returns:
        The module's ``__name__`` (equal to the schema name).

    Raises:
        ValueError: if *full_table_name* does not start with a
            backtick-quoted schema followed by '.'.
    """
    if context is None:
        # NOTE(review): writes to a function frame's f_locals may not persist;
        # this is reliable when the caller is at module level.
        context = inspect.currentframe().f_back.f_locals
    # Raw string: '\.' in a plain literal is an invalid escape sequence.
    match = re.match(r'`(.*)`\.', full_table_name)
    if match is None:
        raise ValueError(
            'Expected a full table name like `schema`.`table`, got {!r}'
            .format(full_table_name))
    schema_name = match.group(1)
    vmod = dj.create_virtual_module(schema_name, schema_name)
    context[vmod.__name__] = vmod
    return vmod.__name__
def load_frame_times(self, key):
    """Return the stimulus sync frame times for *key*, keeping every n-th entry.

    n is the number of distinct recording depths ('z') of the scan, looked up
    in the scan's pipeline. All Activity rows for *key* must come from one
    pipeline.
    """
    pipes = (fuse.Activity() & key).fetch('pipe')
    assert len(np.unique(pipes)) == 1, 'Selection is from different pipelines'
    pipe_name = pipes[0]
    pipeline = dj.create_virtual_module(pipe_name, 'pipeline_' + pipe_name)

    # Depth count is per scan, so drop the field restriction.
    scan_key = {attr: val for attr, val in dict(key).items() if attr != 'field'}
    n_depths = len(dj.U('z') & (pipeline.ScanInfo.Field() & scan_key))

    sync_times = (stimulus.Sync() & key).fetch1('frame_times')
    return sync_times.squeeze()[::n_depths]
def erd(*args):
    """Save an ERD image of each pipeline module to ``images/<modname>.png``.

    Positional arguments are accepted and ignored (presumably so this can be
    used as a command callback - confirm).
    """
    report = dj.create_virtual_module('report', get_schema_name('report'))
    mods = (ephys, lab, experiment, tracking, psth, ccf, histology, report,
            publication)
    for mod in mods:
        parts = mod.__name__.split('.')
        # Drop the package prefix ('pkg.ephys' -> 'ephys'). A top-level name
        # such as the virtual module 'report' has no prefix; fall back to it
        # instead of producing '' (which saved 'images/.png').
        modname = str().join(parts[1:]) or parts[0]
        fname = os.path.join('images', '{}.png'.format(modname))
        print('saving', fname)
        dj.ERD(mod, context={modname: mod}).save(fname)
def datajoint_dot():
    """Build a bokeh 'Overview' Panel showing the subject/action/acquisition diagram.

    If the diagram cannot be rendered (e.g. graphviz or pydotplus missing),
    a placeholder Div is shown instead.
    """
    from bokeh.models import Div
    from bokeh.layouts import layout
    from bokeh.models.widgets import Panel

    subject = dj.create_virtual_module('subject', 'u19_subject')
    action = dj.create_virtual_module('action', 'u19_action')
    acquisition = dj.create_virtual_module('acquisition', 'u19_acquisition')
    try:
        svg = (dj.Diagram(subject) + dj.Diagram(action) +
               dj.Diagram(acquisition)).make_dot().create_svg()
        # The <object> tag is deliberately left incomplete: Div handles it,
        # and completing the tag produced rendering artifacts.
        div = Div(text='<object data={0}'.format(svg.decode('utf-8')))
    except Exception:
        # Best-effort fallback; unlike a bare except this does not swallow
        # KeyboardInterrupt/SystemExit.
        print(
            'Could not get diagram, did you install graphviz and pydotplus??')
        div = Div(text='installation not complete')
    return Panel(child=layout([div]), title='Overview')
def load_behavior_timing(self, key):
    """Return the behavior sync frame times for *key*, keeping every n-th entry.

    n is the number of distinct recording depths ('z') of the scan. All
    Activity rows for *key* must come from one pipeline.
    """
    log.info('Loading behavior frametimes')
    # -- find number of recording depths
    pipes = (fuse.Activity() & key).fetch('pipe')
    assert len(np.unique(pipes)) == 1, 'Selection is from different pipelines'
    pipe_name = pipes[0]
    pipeline = dj.create_virtual_module(pipe_name, 'pipeline_' + pipe_name)

    scan_key = {attr: val for attr, val in dict(key).items() if attr != 'field'}
    n_depths = len(dj.U('z') & (pipeline.ScanInfo.Field() & scan_key))

    behavior_times = (stimulus.BehaviorSync() & key).fetch1('frame_times')
    return behavior_times.squeeze()[::n_depths]
def visualized_3d(results, matchings, base_layout, eID, probe):
    """Assemble per-cluster rows of [coords..., location, allegiance...] for one session.

    Coordinates come from ``histology.ChannelBrainLocationTemp`` with
    provenance=70 at the cluster's channel position, or [0, 0, 0] when
    unavailable. Note that ``matchings`` is modified in place (an empty dict
    is inserted at ``base_layout``).
    """
    ephys = dj.create_virtual_module('ephys', 'ibl_ephys')
    histology = dj.create_virtual_module('histology', 'ibl_histology')

    session_restriction = [{'session_uuid': UUID(eID)}]
    session_keys = (acquisition.Session & ephys.DefaultCluster
                    & session_restriction).fetch('KEY', limit=1)
    if probe != 'both':
        session_keys[0]['probe_idx'] = probe
    cluster_keys = (ephys.DefaultCluster & session_keys
                    & histology.ChannelBrainLocationTemp).fetch('KEY')

    cluster_coords = []
    for cluster_key in tqdm(cluster_keys):
        raw_inds, local_coords = (ephys.ChannelGroup & cluster_key).fetch1(
            'channel_raw_inds', 'channel_local_coordinates')
        channel = (ephys.DefaultCluster & cluster_key).fetch1('cluster_channel')
        if channel not in raw_inds:
            cluster_coords.append([0, 0, 0])
            continue
        position = np.squeeze(local_coords[raw_inds == channel])
        # get the Location with highest provenance
        location_query = (histology.ChannelBrainLocationTemp & cluster_key
                          & dict(channel_lateral=position[0],
                                 channel_axial=position[1])
                          & 'provenance=70')
        if location_query:
            cluster_coords.append(
                location_query.fetch('channel_x', 'channel_y', 'channel_z'))
        else:
            cluster_coords.append([0, 0, 0])

    matchings.insert(base_layout, {})
    allegiance = belong_dictionary([r["partition"] for r in results],
                                   matchings, base_layout)
    location = results[0]["locations"]
    return [cluster_coords[i] + [location[i]] + allegiance[i]
            for i in range(len(cluster_coords))]
def configure_minnie(return_virtual_module=False, create_if_missing=False,
                     host=None, cache_path=None):
    """Verify data paths, apply configuration and optionally return the schema module.

    Args:
        return_virtual_module: if true, return a 'minnie' virtual module for
            ``schema_name_m65`` (with ``adapter_objects``).
        create_if_missing: forwarded to ``verify_paths``.
        host, cache_path: forwarded to ``set_configurations``.

    Returns:
        The virtual module, or None when ``return_virtual_module`` is false.
    """
    verify_paths(create_if_missing=create_if_missing)
    set_configurations(host=host, cache_path=cache_path)
    if not return_virtual_module:
        return None
    import datajoint as dj
    return dj.create_virtual_module('minnie', schema_name_m65,
                                    add_objects=adapter_objects)
def surgery_notification():
    """Post Slack reminders for outstanding post-surgery check-ups.

    Selects survival surgeries dated strictly between 4 days ago and today,
    i.e. 1-3 days old. For each one, the latest ``SurgeryStatus`` row is
    checked. If the animal is not euthanized and that day's check column
    (``day_one``/``day_two``/``day_three``) is 0, a reminder is posted to the
    lab manager and the reminder channel. It is also sent to every matching
    ``SlackUser`` entry.

    Returns:
        ('', HTTPStatus.NO_CONTENT)
    """
    # Sends notifications to specified slack channel, surgeon, and lab manager about any checkups that need to be done

    # Used to figure out which column to look up for checkup date
    num_to_word = {1: 'one', 2: 'two', 3: 'three'}

    # Define all Slack notification variables
    slack_notification_channel = "#surgery_reminders"
    slack_manager = "camila"
    slacktable = dj.create_virtual_module('pipeline_notification', 'pipeline_notification')
    api_key = slacktable.SlackConnection.fetch1('api_key')
    slack = Slacker(api_key, timeout=60)

    # Read the clock once so the date window and the day lookup agree even if
    # this runs across midnight (otherwise num_to_word could raise KeyError).
    now = datetime.today()
    today = now.date()

    # Only fetch surgeries done 1 to 3 days ago
    lessthan_date_res = now.strftime("%Y-%m-%d")
    greaterthan_date_res = (now - timedelta(days=4)).strftime("%Y-%m-%d")
    restriction = 'surgery_outcome = "Survival" and date < "{}" and date > "{}"'.format(
        lessthan_date_res, greaterthan_date_res)
    surgery_data = (experiment.Surgery & restriction).fetch(order_by='date DESC')

    for entry in surgery_data:
        status = (experiment.SurgeryStatus & entry).fetch(order_by="timestamp DESC")[0]
        day_key = "day_" + num_to_word[(today - entry['date']).days]
        edit_url = "<{}|Update Status Here>".format(
            url_for('main.surgery_update', _external=True,
                    animal_id=entry['animal_id'], surgery_id=entry['surgery_id']))
        if status['euthanized'] == 0 and status[day_key] == 0:
            manager_message = "{} needs to check animal {} in room {} for surgery on {}. {}".format(
                entry['username'].title(), entry['animal_id'], entry['mouse_room'],
                entry['date'], edit_url)
            ch_message = "<!channel> Reminder: " + manager_message
            slack.chat.post_message("@" + slack_manager, manager_message)
            slack.chat.post_message(slack_notification_channel, ch_message)

            # fetch() returns an array of names; "@" + array does not give a
            # recipient string, so message each user individually.
            pm_message = "Don't forget to check on animal {} today! \n{}".format(
                entry['animal_id'], edit_url)
            for slackname in (slacktable.SlackUser & entry).fetch('slack_user'):
                slack.chat.post_message("@" + slackname, pm_message, as_user=True)

    return '', http.HTTPStatus.NO_CONTENT
def connect_to_datajoint(self):
    """Apply ``self.config`` to DataJoint, connect, and expose schemas as attributes.

    After a successful connection, every schema from ``dj.list_schemas()``
    is attached to ``self`` as a virtual module named after the schema.

    Returns:
        bool: the connection state (True immediately if already connected).
    """
    if self.is_connected:
        return True

    for option, setting in self.config.items():
        dj.config[option] = setting
    self.connection = dj.conn()
    self.is_connected = self.connection.is_connected
    if not self.is_connected:
        return self.is_connected

    # NOTE(review): a schema whose name matches an existing attribute
    # (e.g. 'config') would overwrite it - confirm this cannot happen.
    for schema_name in dj.list_schemas():
        module = dj.create_virtual_module(f'{schema_name}.py', schema_name)
        setattr(self, schema_name, module)
    return self.is_connected
def delete_tuple(jwt_payload: dict, schema_name: str, table_name: str,
                 tuple_to_restrict_by: dict, cascade: bool = False):
    """
    Delete a specific record based on the restriction given (Can only delete 1 at a time)

    :param jwt_payload: Dictionary containing databaseAddress, username and password
        strings
    :type jwt_payload: dict
    :param schema_name: Name of the schema containing the table
    :type schema_name: str
    :param table_name: Table name under the given schema; must be in camel case
    :type table_name: str
    :param tuple_to_restrict_by: Record to restrict the table by to delete
    :type tuple_to_restrict_by: dict
    :param cascade: Allow for cascading delete, defaults to False
    :type cascade: bool
    :raises InvalidRestriction: if no key of the restriction is a table attribute
    :raises InvalidDeleteRequest: if the restriction matches zero or more than one record
    """
    DJConnector.set_datajoint_config(jwt_payload)
    schema_virtual_module = dj.create_virtual_module(schema_name, schema_name)
    table = getattr(schema_virtual_module, table_name)

    # Get all the table attributes and create a set
    table_attributes = set(table.heading.primary_key + table.heading.secondary_attributes)

    # Check to see if the restriction has at least one matching attribute, if not raise an
    # error
    # NOTE(review): keys that are not table attributes appear to be ignored by
    # the restriction, so they do not narrow the match - confirm.
    if len(table_attributes & tuple_to_restrict_by.keys()) == 0:
        raise InvalidRestriction('Restriction is invalid: None of the attributes match')

    # Compute restriction
    tuple_to_delete = table & tuple_to_restrict_by

    # Check if there is only 1 tuple to delete otherwise raise error.
    # len() on a query runs a COUNT, so evaluate it only once.
    match_count = len(tuple_to_delete)
    if match_count > 1:
        raise InvalidDeleteRequest(
            """Cannot delete more than 1 tuple at a time. Please update the restriction accordingly""")
    elif match_count == 0:
        raise InvalidDeleteRequest('Nothing to delete')

    # All check pass thus proceed to delete
    if cascade:
        tuple_to_delete.delete(safemode=False)
    else:
        tuple_to_delete.delete_quick()
def set_schema():
    """Configure the DataJoint connection for FLASK_MODE and return its schema module.

    FLASK_MODE selects host and schema:
      TEST -> host 'dbtest', schema 'test_db' (also imports schemas.condis)
      DEV  -> host 'dbdev',  schema 'dev_db'
      PROD -> host 'db',     schema 'prod_db'
    Credentials come from DB_USER / DB_PASS; port is 3306 inside the db container.

    Raises:
        KeyError: if FLASK_MODE, DB_USER or DB_PASS is not set.
        ValueError: if FLASK_MODE is not TEST, DEV or PROD (previously this
            fell through to an UnboundLocalError on ``db``).
    """
    targets = {
        'TEST': ('dbtest', 'test_db'),
        'DEV': ('dbdev', 'dev_db'),
        'PROD': ('db', 'prod_db'),
    }
    mode = os.environ['FLASK_MODE']
    if mode not in targets:
        raise ValueError(
            'Unknown FLASK_MODE {!r}; expected one of TEST, DEV, PROD'.format(mode))

    if mode == 'TEST':
        print("FLASK_MODE is TEST")
        from schemas import condis  # noqa: F401 - imported for its side effects

    host, schema_name = targets[mode]
    dj.config['database.host'] = host
    dj.config['database.user'] = os.environ['DB_USER']
    dj.config['database.password'] = os.environ['DB_PASS']
    dj.config['database.port'] = 3306  # inside the db container
    # create_schema is not requested, so the schema is presumably expected to
    # exist already; tables can't be added from within the app.
    return dj.create_virtual_module(schema_name, schema_name)
def test_convert():
    """Migrate the djtest_blob_migrate schema's blobs, then check fetch and insert."""
    # Configure stores
    default_store = 'external'  # naming the unnamed external store

    def s3_store(location):
        # Both S3 stores share endpoint, bucket and credentials.
        return dict(protocol='s3',
                    endpoint=S3_CONN_INFO['endpoint'],
                    bucket='migrate-test',
                    location=location,
                    access_key=S3_CONN_INFO['access_key'],
                    secret_key=S3_CONN_INFO['secret_key'])

    temp_dir = Path(os.path.expanduser('~'), 'temp')
    dj.config['stores'] = {
        default_store: s3_store('store'),
        'shared': s3_store('maps'),
        'local': dict(protocol='file', location=str(temp_dir / 'migrate-test')),
    }
    dj.config['cache'] = str(temp_dir / 'dj-cache')
    for field in ('password', 'user', 'host'):
        dj.config['database.' + field] = CONN_INFO[field]

    schema = dj.Schema('djtest_blob_migrate')

    # Test if migration throws unexpected exceptions
    _migrate_dj011_blob(schema, default_store)

    # Test Fetch
    test_mod = dj.create_virtual_module('test_mod', 'djtest_blob_migrate')
    shared_blobs = test_mod.A.fetch('blob_share', order_by='id')
    assert_equal(shared_blobs[1][1], 2)

    # Test Insert
    new_row = {'id': 3, 'blob_external': [9, 8, 7, 6], 'blob_share': {'number': 5}}
    test_mod.A.insert1(new_row)
    fetched = (test_mod.A & 'id=3').fetch1()
    assert_equal(fetched['blob_share']['number'], 5)
def fsck(sname):
    '''
    check routine: run check_table on every relation of schema sname.

    XXX: could be faster/more granular if external checks done at lower level;
    this is skipped for code simplicity / modularity.
    '''
    schema = dj.create_virtual_module(sname, sname).schema
    for relation in schema_iterator(schema):
        check_table(schema, relation)
def area(self, unit_keys):
    """Return the brain area of each unit in *unit_keys*.

    Args:
        unit_keys: unit key(s) usable both as a DataJoint restriction and to
            build a DataFrame (with animal_id, session, scan_idx, unit_id).

    Returns:
        numpy array of ``brain_area`` values in the row order of the
        left-merged unit_keys frame.

    Raises:
        AssertionError: if any unit has no AreaMembership entry.
    """
    anatomy = dj.create_virtual_module("anatomy", "pipeline_anatomy")
    columns = ["animal_id", "session", "scan_idx", "brain_area", "unit_id"]
    # Explicit columns keep the merge valid when nothing matches (an empty
    # list would otherwise give a column-less frame and a MergeError).
    unit_area_df = pd.DataFrame(
        (anatomy.AreaMembership & unit_keys).fetch(*columns, as_dict=True),
        columns=columns,
    )
    unit_df = pd.DataFrame(unit_keys)
    unit_df = unit_df.merge(unit_area_df, how="left")
    # The fetched attribute is 'brain_area'; there is no 'area' column, so
    # the previous unit_df.area always raised AttributeError.
    assert unit_df["brain_area"].notnull().all(), "Missing area for some units!"
    return unit_df["brain_area"].values
def layer(self, unit_keys):
    """Return the cortical layer of each unit in *unit_keys*.

    Args:
        unit_keys: unit key(s) usable both as a DataJoint restriction and to
            build a DataFrame (with animal_id, session, scan_idx, unit_id).

    Returns:
        numpy array of ``layer`` values in the row order of the left-merged
        unit_keys frame.

    Raises:
        AssertionError: if any unit has no LayerMembership entry.
    """
    anatomy = dj.create_virtual_module("anatomy", "pipeline_anatomy")
    columns = ["animal_id", "session", "scan_idx", "layer", "unit_id"]
    # Explicit columns keep the merge valid when nothing matches, so the
    # assertion below reports the problem instead of pandas raising MergeError.
    unit_layer_df = pd.DataFrame(
        (anatomy.LayerMembership & unit_keys).fetch(*columns, as_dict=True),
        columns=columns,
    )
    unit_df = pd.DataFrame(unit_keys)
    unit_df = unit_df.merge(unit_layer_df, how="left")
    assert unit_df["layer"].notnull().all(), "Missing layer for some units!"
    return unit_df["layer"].values
def test_convert():
    """Migrate the djtest_blob_migrate schema's blobs, then check fetch and insert."""
    # Configure stores
    default_store = "external"  # naming the unnamed external store

    def s3_store(location):
        # Both S3 stores share endpoint, bucket and credentials.
        return dict(
            protocol="s3",
            endpoint=S3_CONN_INFO["endpoint"],
            bucket=S3_MIGRATE_BUCKET,
            location=location,
            access_key=S3_CONN_INFO["access_key"],
            secret_key=S3_CONN_INFO["secret_key"],
        )

    temp_dir = Path(os.path.expanduser("~"), "temp")
    dj.config["stores"] = {
        default_store: s3_store("store"),
        "shared": s3_store("maps"),
        "local": dict(protocol="file", location=str(temp_dir / S3_MIGRATE_BUCKET)),
    }
    dj.config["cache"] = str(temp_dir / "dj-cache")
    for field in ("password", "user", "host"):
        dj.config["database." + field] = CONN_INFO[field]

    schema = dj.Schema("djtest_blob_migrate")

    # Test if migration throws unexpected exceptions
    _migrate_dj011_blob(schema, default_store)

    # Test Fetch
    test_mod = dj.create_virtual_module("test_mod", "djtest_blob_migrate")
    shared_blobs = test_mod.A.fetch("blob_share", order_by="id")
    assert_equal(shared_blobs[1][1], 2)

    # Test Insert
    new_row = {"id": 3, "blob_external": [9, 8, 7, 6], "blob_share": {"number": 5}}
    test_mod.A.insert1(new_row)
    fetched = (test_mod.A & "id=3").fetch1()
    assert_equal(fetched["blob_share"]["number"], 5)
def update_tuple(jwt_payload: dict, schema_name: str, table_name: str,
                 tuple_to_update: dict):
    """
    Update an existing record of a table via ``update1``.

    :param jwt_payload: Dictionary containing databaseAddress, username and password
        strings
    :type jwt_payload: dict
    :param schema_name: Name of the schema containing the table
    :type schema_name: str
    :param table_name: Table name under the given schema; must be in camel case
    :type table_name: str
    :param tuple_to_update: Record to be updated
    :type tuple_to_update: dict
    """
    DJConnector.set_datajoint_config(jwt_payload)
    vmod = dj.create_virtual_module(schema_name, schema_name)
    table = getattr(vmod, table_name)
    table.update1(tuple_to_update)
def get_table_definition(jwt_payload: dict, schema_name: str, table_name: str):
    """
    Return the definition of a table as produced by ``describe()``.

    :param jwt_payload: Dictionary containing databaseAddress, username and password
        strings
    :type jwt_payload: dict
    :param schema_name: Name of the schema containing the table
    :type schema_name: str
    :param table_name: Table name under the given schema; must be in camel case
    :type table_name: str
    :return: definition of the table
    :rtype: str
    """
    DJConnector.set_datajoint_config(jwt_payload)
    vmod = dj.create_virtual_module(schema_name, schema_name)
    table = getattr(vmod, table_name)
    return table.describe()
def make(self, key):
    """Populate ``ephys.BehaviorSync`` for *key* from the session's NIDAQ file.

    Reads the digital lines of the session's ``*nidq.bin*`` file and derives
    per-sample iteration and trial vectors. If their count does not match the
    behavior trials and more than 3 digital lines exist, it falls back to the
    frame-counter method. Also records the imec sampling rate.

    Raises:
        FileNotFoundError: if the session directory has no ``*nidq.bin*`` file.
    """
    # Pull the Nidaq file/record
    session_dir = pathlib.Path(get_session_directory(key))
    # Sorted so the choice is deterministic if several files match.
    nidq_files = sorted(session_dir.glob('*nidq.bin*'))
    if not nidq_files:
        raise FileNotFoundError(
            'No *nidq.bin* file found in {}'.format(session_dir))
    nidq_bin_full_path = nidq_files[0]

    # And get the datajoint record
    behavior = dj.create_virtual_module('behavior', 'u19_behavior')
    thissession = behavior.TowersBlock().Trial() & key
    behavior_time = thissession.fetch('trial_time')

    # 1: load meta data, and the content of the NIDAQ file. Its content is digital.
    nidq_meta = readSGLX.readMeta(nidq_bin_full_path)
    nidq_sampling_rate = readSGLX.SampRate(nidq_meta)
    digital_array = ephys_utils.spice_glx_utility.load_spice_glx_digital_file(
        nidq_bin_full_path, nidq_meta)

    # Synchronize between pulses and get iteration # vector for each sample
    mode = 'counter_bit0'
    iteration_dict = ephys_utils.get_iteration_sample_vector_from_digital_lines_pulses(
        digital_array[1, :], digital_array[2, :], nidq_sampling_rate,
        behavior_time.shape[0], mode)
    # Check # of trials and iterations match
    status = ephys_utils.assert_iteration_samples_count(
        iteration_dict['iter_start_idx'], behavior_time)

    # They didn't match, try counter method (if available)
    # NOTE(review): if they don't match and no counter lines exist, the
    # unvalidated vectors are still inserted below - confirm this is intended.
    if (not status) and (digital_array.shape[0] > 3):
        framenumber_in_trial, trialnumber = ephys_utils.behavior_sync_frame_counter_method(
            digital_array, behavior_time, thissession, nidq_sampling_rate, 3, 5)
        iteration_dict['framenumber_vector_samples'] = framenumber_in_trial
        iteration_dict['trialnumber_vector_samples'] = trialnumber

    final_key = dict(key,
                     nidq_sampling_rate=nidq_sampling_rate,
                     iteration_index_nidq=iteration_dict['framenumber_vector_samples'],
                     trial_index_nidq=iteration_dict['trialnumber_vector_samples'])

    print(final_key)

    ephys.BehaviorSync.insert1(final_key, allow_direct_insert=True)

    self.insert_imec_sampling_rate(key, session_dir)
import hashlib
import os.path as op
import pathlib
import pickle
from shutil import copyfile
from warnings import warn

import datajoint as dj
from tqdm import tqdm

from .. import logger as log

# Virtual modules for the fuse and stimulus pipeline schemas, created at
# import time (this presumably requires a database connection on import).
fuse = dj.create_virtual_module('fuse', 'pipeline_fuse')
stimulus = dj.create_virtual_module('stimulus', 'pipeline_stimulus')


def list_hash(values):
    """
    Returns MD5 digest hash values for a list of values

    Each value is converted with str(), encoded and fed to a single MD5
    object in order; the hex digest string is returned.
    """
    hashed = hashlib.md5()
    for v in values:
        hashed.update(str(v).encode())
    return hashed.hexdigest()


def key_hash(key):
    """
    32-byte hash used for lookup of primary keys of jobs
    """
    # NOTE(review): as written this only creates an MD5 object and returns
    # None; the body looks incomplete - confirm against the full source.
    hashed = hashlib.md5()
def test_virtual_module():
    """A virtual module of the test schema exposes its tables as UserTable subclasses."""
    database = schema.schema.database
    connection = dj.conn(**CONN_INFO)
    module = dj.create_virtual_module('module', database, connection=connection)
    assert_true(issubclass(module.Experiment, UserTable))