Example #1
 def __init__(self, protocol=False):
     self.setup = socket.gethostname()
     self.is_pi = os.uname()[4][:3] == 'arm'
     self.curr_state, self.lock, self.queue = '', False, PriorityQueue()
     self.curr_trial, self.total_reward = 0, 0
     self.session_key = dict()
     self.ping_timer, self.logger_timer = Timer(), Timer()
     self.setup_status = 'running' if protocol else 'ready'
     self.log_setup(protocol)
     fileobject = open(
         os.path.dirname(os.path.abspath(__file__)) + '/dj_local_conf.json')
     connect_info = json.loads(fileobject.read())
     background_conn = dj.Connection(connect_info['database.host'],
                                     connect_info['database.user'],
                                     connect_info['database.password'])
     self.schemata = dict()
     self.schemata['lab'] = dj.create_virtual_module(
         'beh.py', 'lab_behavior', connection=background_conn)
     self.schemata['mice'] = dj.create_virtual_module(
         'mice.py', 'lab_mice', connection=background_conn)
     self.thread_end, self.thread_lock = threading.Event(), threading.Lock()
     self.inserter_thread = threading.Thread(target=self.inserter)
     self.getter_thread = threading.Thread(target=self.getter)
     self.inserter_thread.start()
     self.getter_thread.start()
     self.logger_timer.start()  # start session time
Example #2
    def load_traces_and_frametimes(self, key):
        # -- find number of recording depths
        pipe = (fuse.Activity() & key).fetch('pipe')
        assert len(
            np.unique(pipe)) == 1, 'Selection is from different pipelines'
        pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
        k = dict(key)
        k.pop('field', None)
        ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
        frame_times = (stimulus.Sync()
                       & key).fetch1('frame_times').squeeze()[::ndepth]

        soma = pipe.MaskClassification.Type() & dict(type='soma')

        spikes = (dj.U('field', 'channel') * pipe.Activity.Trace() * StaticScan.Unit()
                  * pipe.ScanSet.UnitInfo() & soma & key)
        traces, ms_delay, trace_keys = spikes.fetch(
            'trace',
            'ms_delay',
            dj.key,
            order_by='animal_id, session, scan_idx, unit_id')
        delay = np.fromiter(ms_delay / 1000, dtype=float)  # np.float is removed in modern NumPy
        frame_times = (delay[:, None] + frame_times[None, :])
        traces = np.vstack(
            [fill_nans(tr.astype(np.float32)).squeeze() for tr in traces])
        traces, frame_times = self.adjust_trace_len(traces, frame_times)
        return traces, frame_times, trace_keys
Example #3
 def make(self, key):
     self.insert(fuse.ScanDone() & key, ignore_extra_fields=True)
     pipe = (fuse.ScanDone() & key).fetch1('pipe')
     pipe = dj.create_virtual_module(pipe, 'pipeline_' + pipe)
     self.Unit().insert(fuse.ScanDone * pipe.ScanSet.Unit * pipe.MaskClassification.Type & key
                        & dict(pipe_version=1, segmentation_method=6, spike_method=5, type='soma'),
                        ignore_extra_fields=True)
Example #4
def get_network_path(path_name):
    """
    Get network root path depending on os and path required
    Args:
        path_name (str): One of the main paths for data storage (Bezos, braininit, u19_dj)
    Returns:
        network_path (str): String with network path as mounted by the corresponding os
    """

    key = dict()
    # Check if path name to search starts with needed / at start
    if path_name[0] != '/':
        key['global_path'] = '/' + path_name
    else:
        key['global_path'] = path_name

    field_get = ['local_path']

    if is_this_spock():
        field_get = ['bucket_path']
        key['system'] = 'linux'
    elif sys.platform == "darwin":
        key['system'] = 'mac'
    elif os.name == 'nt':
        key['system'] = 'windows'
    else:
        key['system'] = 'linux'

    lab = dj.create_virtual_module(
        'lab', dj.config['custom']['database.prefix'] + 'lab')
    network_path = (lab.Path & key).fetch1(*field_get)
    return network_path
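
A minimal usage sketch (assumes the lab.Path table is populated and 'braininit' is one of the stored global paths):

root = get_network_path('braininit')  # restricted as global_path='/braininit'
print(root)  # e.g. the bucket path on spock, a local mount elsewhere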
Example #5
 def __init__(self, protocol=False):
     self.setup = socket.gethostname()
     self.is_pi = os.uname()[4][:3] == 'arm' if os.name == 'posix' else False
     self.ping_timer, self.logger_timer = Timer(), Timer()  # assumed, as in Example #1; needed before logger_timer.start() below
     self.setup_status = 'running' if protocol else 'ready'
     fileobject = open(
         os.path.dirname(os.path.abspath(__file__)) +
         '/../dj_local_conf.json')
     con_info = json.loads(fileobject.read())
     self.private_conn = dj.Connection(con_info['database.host'],
                                       con_info['database.user'],
                                       con_info['database.password'])
     self._schemata = dict()
     for schema, value in schemata.items():  # separate connection for internal communication
         self._schemata.update({
             schema:
             dj.create_virtual_module(schema,
                                      value,
                                      connection=self.private_conn)
         })
     self.thread_end, self.thread_lock = threading.Event(), threading.Lock()
     self.inserter_thread = threading.Thread(target=self.inserter)
     self.getter_thread = threading.Thread(target=self.getter)
     self.inserter_thread.start()
     self.log_setup(protocol)
     self.getter_thread.start()
     self.logger_timer.start()
Example #6
def depstick(sname, direction='reverse'):
    ''' check/print report of dependencies '''

    vm = dj.create_virtual_module(sname, sname)
    dbc = vm.schema.connection

    # Constraint CONSTRAINT_NAME in CONSTRAINT_SCHEMA TABLE_NAME refers to
    # table REFERENCED_TABLE_NAME in UNIQUE_CONSTRAINT_SCHEMA.

    if direction == 'forward':
        q = '''
        SELECT distinct(UNIQUE_CONSTRAINT_SCHEMA)
        FROM information_schema.REFERENTIAL_CONSTRAINTS
        where constraint_schema='{}';
        '''.format(sname)

    elif direction == 'reverse':
        q = '''
        SELECT distinct(CONSTRAINT_SCHEMA)
        FROM information_schema.REFERENTIAL_CONSTRAINTS
        where unique_constraint_schema='{}';
        '''.format(sname)

    else:
        raise Exception("depstick doesn't know {} direction."
                        .format(direction))

    print('-- {} {} dependencies'.format(sname, direction))
    for r in dbc.query(q):
        if r[0] != sname:
            print(r[0])
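
A usage sketch, assuming a schema named 'lab_behavior' exists on the connected server:

depstick('lab_behavior')             # schemas whose foreign keys reference lab_behavior
depstick('lab_behavior', 'forward')  # schemas that lab_behavior references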
Example #7
    def get_table_attributes(jwt_payload: dict, schema_name: str,
                             table_name: str):
        """
        Method to get primary and secondary attributes of a table
        :param jwt_payload: Dictionary containing databaseAddress, username and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of the schema containing the table
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :return: Dict of primary and secondary attributes, each with metadata:
            attribute_name, type, nullable, default, autoincrement.
        :rtype: dict
        """
        DJConnector.set_datajoint_config(jwt_payload)

        schema_virtual_module = dj.create_virtual_module(
            schema_name, schema_name)
        table_attributes = dict(primary_attributes=[], secondary_attributes=[])
        for attribute_name, attribute_info in getattr(
                schema_virtual_module, table_name).heading.attributes.items():
            if attribute_info.in_key:
                table_attributes['primary_attributes'].append(
                    (attribute_name, attribute_info.type,
                     attribute_info.nullable, attribute_info.default,
                     attribute_info.autoincrement))
            else:
                table_attributes['secondary_attributes'].append(
                    (attribute_name, attribute_info.type,
                     attribute_info.nullable, attribute_info.default,
                     attribute_info.autoincrement))

        return table_attributes
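
A hypothetical call, assuming get_table_attributes is a static method of DJConnector (as the use of set_datajoint_config above suggests) and that a 'lab_behavior' schema with a Session table exists:

attrs = DJConnector.get_table_attributes(jwt_payload, 'lab_behavior', 'Session')
# attrs['primary_attributes'] is a list of
# (attribute_name, type, nullable, default, autoincrement) tuples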
Example #8
    def _make_tuples(self, key):
        # we are not making this a direct dependency because that database is not DataJoint-compatible
        animals = dj.create_virtual_module('census', 'animal_keeping')

        a = scan_info(key['cell_id'], basedir=BASEDIR)
        a = a['Subject'] if 'Subject' in a else a['Recording']['Subject']

        if not animals.CensusSubject() & dict(name=a['Identifier']):
            if 'zucht' in a['Identifier']:
                tmp = a['Identifier'].split('leptozucht')
                tmp[-1] = '{:04d}'.format(int(tmp[-1]))
                a['Identifier'] = 'leptozucht'.join(tmp)
            else:
                tmp = a['Identifier'].split('lepto')
                tmp[-1] = '{:04d}'.format(int(tmp[-1]))
                a['Identifier'] = 'lepto'.join(tmp)
        fish_id = (animals.CensusSubject()
                   & dict(name=a['Identifier'])).fetch1('id')

        self.insert1(
            dict(key,
                 fish_id=fish_id,
                 eod_frequency=float(a['EOD Frequency'][:-2]),
                 gender=a['Gender'].lower(),
                 weight=float(a['Weight'][:-1]),
                 size=float(a['Size'][:-2])))
Example #9
 def make_vmods(tag, cfg, connection):
     return {
         k: dj.create_virtual_module('{}_{}'.format(tag, k),
                                     v,
                                     connection=connection)
         for k, v in cfg.items()
     }
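
A usage sketch with a hypothetical config that maps short tags to schema names; conn is an existing dj.Connection:

cfg = {'lab': 'lab_behavior', 'mice': 'lab_mice'}
vmods = make_vmods('main', cfg, conn)
vmods['mice']  # virtual module 'main_mice' bound to schema 'lab_mice'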
Example #10
def wheel_metrics(trials, session):

    session_duration = pd.DataFrame(trials.groupby(
        ['subject_uuid', 'session_start_time'])['trial_start_time'].max())
    session_duration = session_duration.reset_index(level=[0, 1])
    session_duration = session_duration.rename(columns={'trial_start_time': "session_duration"})

    # Data with added column for session duration
    data2 = trials.merge(session_duration, on=['subject_uuid', 'session_start_time'])

    # --Wheel data
    dj_wheel = dj.create_virtual_module('wheel_moves', 'group_shared_wheel')

    movements_summary = pd.DataFrame.from_dict(dj_wheel.WheelMoveSet().fetch(as_dict=True))
    movements_summary = movements_summary.merge(data2, on=['subject_uuid', 'session_start_time'])
    movements_summary['disp_norm'] = np.abs(movements_summary['total_displacement'] /
                                            movements_summary['total_distance'])
    movements_summary['moves_time'] = (movements_summary['total_distance'] /
                                       movements_summary['session_duration'])

    # -- df
    # data should have same mice as data2, but movements_summary might have less
    mice = movements_summary['subject_uuid'].unique()
    df = pd.DataFrame(columns=['subject_uuid', 'disp_norm', 'moves_time'], index=range(len(mice)))
    for m, mouse in enumerate(mice):

        mouse_data = movements_summary.loc[movements_summary.subject_uuid == mouse]

        # use .loc to avoid pandas chained-assignment warnings
        df.loc[m, 'subject_uuid'] = mouse
        df.loc[m, 'disp_norm'] = np.nanmean(
            mouse_data.loc[mouse_data['training_day'] <= session, 'disp_norm'])
        df.loc[m, 'moves_time'] = np.nanmean(
            mouse_data.loc[mouse_data['training_day'] <= session, 'moves_time'])

    return df
Example #11
    def get_virtual_module(full_table_name, context=None):

        if not context:
            context = inspect.currentframe().f_back.f_locals
        schema_name = re.match(r'`(.*)`\.', full_table_name).group(1)
        vmod = dj.create_virtual_module(schema_name, schema_name)
        context[vmod.__name__] = vmod
        return vmod.__name__
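
A hypothetical call; note that writing into f_back.f_locals is only reliable when the caller is at module scope:

name = get_virtual_module('`pipeline_anatomy`.`__area_membership`')
# name == 'pipeline_anatomy', now bound in the calling namespace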
Example #12
 def load_frame_times(self, key):
     pipe = (fuse.Activity() & key).fetch('pipe')
     assert len(np.unique(pipe)) == 1, 'Selection is from different pipelines'
     pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
     k = dict(key)
     k.pop('field', None)
     ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
     return (stimulus.Sync() & key).fetch1('frame_times').squeeze()[::ndepth]
Example #13
def erd(*args):
    report = dj.create_virtual_module('report', get_schema_name('report'))
    mods = (ephys, lab, experiment, tracking, psth, ccf, histology, report,
            publication)
    for mod in mods:
        modname = str().join(mod.__name__.split('.')[1:])
        fname = os.path.join('images', '{}.png'.format(modname))
        print('saving', fname)
        dj.ERD(mod, context={modname: mod}).save(fname)
Example #14
def datajoint_dot():
    from bokeh.models import Div
    from bokeh.layouts import layout
    from bokeh.models.widgets import Panel

    subject = dj.create_virtual_module('subject', 'u19_subject')
    action = dj.create_virtual_module('action', 'u19_action')
    acquisition = dj.create_virtual_module('acquisition', 'u19_acquisition')

    try:
        svg = (dj.Diagram(subject) + dj.Diagram(action) +
               dj.Diagram(acquisition)).make_dot().create_svg()
        div = Div(text='<object data={0}'.format(svg.decode('utf-8')))
        # for some reason Div can handle incomplete tags; completing the tag produces an artifact
    except Exception:
        print('Could not get diagram; did you install graphviz and pydotplus?')
        div = Div(text='installation not complete')
    return Panel(child=layout([div]), title='Overview')
Example #15
 def load_behavior_timing(self, key):
     log.info('Loading behavior frametimes')
     # -- find number of recording depths
     pipe = (fuse.Activity() & key).fetch('pipe')
     assert len(np.unique(pipe)) == 1, 'Selection is from different pipelines'
     pipe = dj.create_virtual_module(pipe[0], 'pipeline_' + pipe[0])
     k = dict(key)
     k.pop('field', None)
     ndepth = len(dj.U('z') & (pipe.ScanInfo.Field() & k))
     return (stimulus.BehaviorSync() & key).fetch1('frame_times').squeeze()[0::ndepth]
Example #16
def visualized_3d(results, matchings, base_layout, eID, probe):
    ephys = dj.create_virtual_module('ephys', 'ibl_ephys')
    histology = dj.create_virtual_module('histology', 'ibl_histology')
    key_session = [{'session_uuid': UUID(eID)}]
    key = (acquisition.Session & ephys.DefaultCluster & key_session).fetch(
        'KEY', limit=1)

    if probe != 'both':
        key[0]['probe_idx'] = probe
    clusters = (ephys.DefaultCluster & key
                & histology.ChannelBrainLocationTemp).fetch('KEY')
    cluster_coords = []
    for key in tqdm(clusters):
        channel_raw_inds, channel_local_coordinates = \
            (ephys.ChannelGroup & key).fetch1(
                'channel_raw_inds', 'channel_local_coordinates')
        channel = (ephys.DefaultCluster & key).fetch1('cluster_channel')
        if channel in channel_raw_inds:
            channel_coords = (np.squeeze(
                channel_local_coordinates[channel_raw_inds == channel]))

            # get the Location with highest provenance
            q = histology.ChannelBrainLocationTemp & key & \
                dict(channel_lateral=channel_coords[0],
                     channel_axial=channel_coords[1]) & 'provenance=70'
            if q:
                cluster_coords.append(
                    q.fetch('channel_x', 'channel_y', 'channel_z'))
            else:
                cluster_coords.append([0, 0, 0])
        else:
            cluster_coords.append([0, 0, 0])
    matchings.insert(base_layout, {})
    allegiance = belong_dictionary([i["partition"] for i in results],
                                   matchings, base_layout)
    location = results[0]["locations"]
    n_clusters = len(cluster_coords)  # avoid shadowing the cluster keys fetched above
    final = []
    for i in range(n_clusters):
        final.append(cluster_coords[i] + [location[i]] + allegiance[i])
    return final
Example #17
def configure_minnie(return_virtual_module=False,
                     create_if_missing=False,
                     host=None,
                     cache_path=None):
    verify_paths(create_if_missing=create_if_missing)
    set_configurations(host=host, cache_path=cache_path)

    if return_virtual_module:
        import datajoint as dj
        return dj.create_virtual_module('minnie',
                                        schema_name_m65,
                                        add_objects=adapter_objects)
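
A usage sketch; schema_name_m65 and adapter_objects are assumed to be defined in the surrounding module:

minnie = configure_minnie(return_virtual_module=True, create_if_missing=True)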
Example #18
def surgery_notification():
    # Sends notifications to the specified Slack channel, the surgeon, and the lab manager about any checkups that need to be done
    num_to_word = {
        1: 'one',
        2: 'two',
        3: 'three'
    }  # Used to figure out which column to look up for checkup date

    # Define all Slack notification variables
    slack_notification_channel = "#surgery_reminders"
    slack_manager = "camila"
    slacktable = dj.create_virtual_module('pipeline_notification',
                                          'pipeline_notification')
    domain, api_key = slacktable.SlackConnection.fetch1('domain', 'api_key')
    slack = Slacker(api_key, timeout=60)

    # Only fetch surgeries done 1 to 3 days ago
    lessthan_date_res = (datetime.today()).strftime("%Y-%m-%d")
    greaterthan_date_res = (datetime.today() -
                            timedelta(days=4)).strftime("%Y-%m-%d")
    restriction = 'surgery_outcome = "Survival" and date < "{}" and date > "{}"'.format(
        lessthan_date_res, greaterthan_date_res)
    surgery_data = (experiment.Surgery
                    & restriction).fetch(order_by='date DESC')

    for entry in surgery_data:
        status = (experiment.SurgeryStatus
                  & entry).fetch(order_by="timestamp DESC")[0]
        day_key = "day_" + num_to_word[(datetime.today().date() -
                                        entry['date']).days]

        edit_url = "<{}|Update Status Here>".format(
            url_for('main.surgery_update',
                    _external=True,
                    animal_id=entry['animal_id'],
                    surgery_id=entry['surgery_id']))
        if status['euthanized'] == 0 and status[day_key] == 0:
            manager_message = "{} needs to check animal {} in room {} for surgery on {}. {}".format(
                entry['username'].title(), entry['animal_id'],
                entry['mouse_room'], entry['date'], edit_url)
            ch_message = "<!channel> Reminder: " + manager_message
            slack.chat.post_message("@" + slack_manager, manager_message)
            slack.chat.post_message(slack_notification_channel, ch_message)
            if len(slacktable.SlackUser & entry) > 0:
                slackname = (slacktable.SlackUser & entry).fetch('slack_user')
                pm_message = "Don't forget to check on animal {} today! {}".format(
                    entry['animal_id'], edit_url)
                slack.chat.post_message("@" + slackname,
                                        pm_message,
                                        as_user=True)

    return '', http.HTTPStatus.NO_CONTENT
Example #19
    def connect_to_datajoint(self):
        if self.is_connected:
            return True

        for key, value in self.config.items():
            dj.config[key] = value
        self.connection = dj.conn()
        self.is_connected = self.connection.is_connected
        if self.is_connected:
            for schema in dj.list_schemas():
                setattr(self, schema,
                        dj.create_virtual_module(f'{schema}.py', schema))
        return self.is_connected
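
Usage might look like this, with a hypothetical wrapper class Client that stores dj settings in self.config and initializes is_connected to False:

client = Client(config={'database.host': 'localhost',
                        'database.user': 'user',
                        'database.password': 'secret'})
if client.connect_to_datajoint():
    client.lab_behavior  # every schema listed on the server becomes an attribute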
Example #20
    def delete_tuple(jwt_payload: dict,
                     schema_name: str,
                     table_name: str,
                     tuple_to_restrict_by: dict,
                     cascade: bool = False):
        """
        Delete a specific record based on the restriction given (Can only delete 1 at a time)
        :param jwt_payload: Dictionary containing databaseAddress, username and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of the schema containing the table
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param tuple_to_restrict_by: Record to restrict the table by to delete
        :type tuple_to_restrict_by: dict
        :param cascade: Allow for cascading delete, defaults to False
        :type cascade: bool
        """
        DJConnector.set_datajoint_config(jwt_payload)

        schema_virtual_module = dj.create_virtual_module(
            schema_name, schema_name)
        # Get all the table attributes and create a set
        table_attributes = set(
            getattr(schema_virtual_module, table_name).heading.primary_key +
            getattr(schema_virtual_module,
                    table_name).heading.secondary_attributes)

        # Check to see if the restriction has at least one matching attribute, if not raise an
        # error
        if len(table_attributes & tuple_to_restrict_by.keys()) == 0:
            raise InvalidRestriction(
                'Restriction is invalid: None of the attributes match')

        # Compute restriction
        tuple_to_delete = getattr(schema_virtual_module,
                                  table_name) & tuple_to_restrict_by

        # Check if there is only 1 tuple to delete otherwise raise error
        if len(tuple_to_delete) > 1:
            raise InvalidDeleteRequest(
                """Cannot delete more than 1 tuple at a time.
                            Please update the restriction accordingly""")
        elif len(tuple_to_delete) == 0:
            raise InvalidDeleteRequest('Nothing to delete')

        # All checks pass, so proceed to delete
        if cascade:
            tuple_to_delete.delete(safemode=False)
        else:
            tuple_to_delete.delete_quick()
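
A hypothetical call, assuming the same DJConnector setup as above; this deletes exactly one matching record or raises:

DJConnector.delete_tuple(jwt_payload, 'lab_behavior', 'Session',
                         {'session_id': 42}, cascade=True)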
Example #21
def set_schema():
	if os.environ['FLASK_MODE'] == 'TEST':
		print("FLASK_MODE is TEST")
		from schemas import condis
		dj.config['database.host'] = 'dbtest'
		dj.config['database.user'] = os.environ['DB_USER']
		dj.config['database.password'] = os.environ['DB_PASS']
		dj.config['database.port'] = 3306 # inside the db container
		# db = dj.create_virtual_module('test_db','test_db',create_schema=True) # creates the schema if it does not already exist. Can't add tables from within the app because create_schema=False
		db = dj.create_virtual_module('test_db','test_db') # schema must already exist; can't add tables from within the app because create_schema=False
	elif os.environ['FLASK_MODE'] == 'DEV':
		dj.config['database.host'] = 'dbdev'
		dj.config['database.user'] = os.environ['DB_USER']
		dj.config['database.password'] = os.environ['DB_PASS']
		dj.config['database.port'] = 3306 # inside the db container
		db = dj.create_virtual_module('dev_db','dev_db') # schema must already exist; can't add tables from within the app because create_schema=False
	elif os.environ['FLASK_MODE'] == 'PROD':
		dj.config['database.host'] = 'db'
		dj.config['database.user'] = os.environ['DB_USER']
		dj.config['database.password'] = os.environ['DB_PASS']
		dj.config['database.port'] = 3306 # inside the db container
		db = dj.create_virtual_module('prod_db','prod_db') 
	return db
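
A smoke-test sketch for set_schema, assuming DB_USER/DB_PASS are set and the dev database is reachable:

os.environ['FLASK_MODE'] = 'DEV'
db = set_schema()                  # virtual module bound to 'dev_db'
print(dj.config['database.host'])  # 'dbdev'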
Example #22
    def test_convert():
        # Configure stores
        default_store = 'external'  # naming the unnamed external store
        dj.config['stores'] = {
            default_store:
            dict(protocol='s3',
                 endpoint=S3_CONN_INFO['endpoint'],
                 bucket='migrate-test',
                 location='store',
                 access_key=S3_CONN_INFO['access_key'],
                 secret_key=S3_CONN_INFO['secret_key']),
            'shared':
            dict(protocol='s3',
                 endpoint=S3_CONN_INFO['endpoint'],
                 bucket='migrate-test',
                 location='maps',
                 access_key=S3_CONN_INFO['access_key'],
                 secret_key=S3_CONN_INFO['secret_key']),
            'local':
            dict(protocol='file',
                 location=str(
                     Path(os.path.expanduser('~'), 'temp', 'migrate-test')))
        }
        dj.config['cache'] = str(
            Path(os.path.expanduser('~'), 'temp', 'dj-cache'))

        dj.config['database.password'] = CONN_INFO['password']
        dj.config['database.user'] = CONN_INFO['user']
        dj.config['database.host'] = CONN_INFO['host']
        schema = dj.Schema('djtest_blob_migrate')

        # Test if migration throws unexpected exceptions
        _migrate_dj011_blob(schema, default_store)

        # Test Fetch
        test_mod = dj.create_virtual_module('test_mod', 'djtest_blob_migrate')
        r1 = test_mod.A.fetch('blob_share', order_by='id')
        assert_equal(r1[1][1], 2)

        # Test Insert
        test_mod.A.insert1({
            'id': 3,
            'blob_external': [9, 8, 7, 6],
            'blob_share': {
                'number': 5
            }
        })
        r2 = (test_mod.A & 'id=3').fetch1()
        assert_equal(r2['blob_share']['number'], 5)
Example #23
def fsck(sname):
    '''
    check routine.

    XXX: could be faster/more granular if external checks done at lower level;
    this is skipped for code simplicity / modularity.
    '''

    mod = dj.create_virtual_module(sname, sname)

    schema = mod.schema

    for rel in schema_iterator(schema):

        check_table(schema, rel)
Example #24
 def area(self, unit_keys):
     anatomy = dj.create_virtual_module("anatomy", "pipeline_anatomy")
     unit_area_df = pd.DataFrame(
         ((anatomy.AreaMembership & unit_keys).fetch(
             "animal_id",
             "session",
             "scan_idx",
             "brain_area",
             "unit_id",
             as_dict=True,
         )))
     unit_df = pd.DataFrame(unit_keys)
     unit_df = unit_df.merge(unit_area_df, how="left")
     assert unit_df.area.notnull().all(), "Missing area for some units!"
     return unit_df.area.values
Example #25
 def layer(self, unit_keys):
     anatomy = dj.create_virtual_module("anatomy", "pipeline_anatomy")
     unit_layer_df = pd.DataFrame(
         ((anatomy.LayerMembership & unit_keys).fetch(
             "animal_id",
             "session",
             "scan_idx",
             "layer",
             "unit_id",
             as_dict=True,
         )))
     unit_df = pd.DataFrame(unit_keys)
     unit_df = unit_df.merge(unit_layer_df, how="left")
     assert unit_df.layer.notnull().all(), "Missing layer for some units!"
     return unit_df.layer.values
Example #26
    def test_convert():
        # Configure stores
        default_store = "external"  # naming the unnamed external store
        dj.config["stores"] = {
            default_store: dict(
                protocol="s3",
                endpoint=S3_CONN_INFO["endpoint"],
                bucket=S3_MIGRATE_BUCKET,
                location="store",
                access_key=S3_CONN_INFO["access_key"],
                secret_key=S3_CONN_INFO["secret_key"],
            ),
            "shared": dict(
                protocol="s3",
                endpoint=S3_CONN_INFO["endpoint"],
                bucket=S3_MIGRATE_BUCKET,
                location="maps",
                access_key=S3_CONN_INFO["access_key"],
                secret_key=S3_CONN_INFO["secret_key"],
            ),
            "local": dict(
                protocol="file",
                location=str(Path(os.path.expanduser("~"), "temp", S3_MIGRATE_BUCKET)),
            ),
        }
        dj.config["cache"] = str(Path(os.path.expanduser("~"), "temp", "dj-cache"))

        dj.config["database.password"] = CONN_INFO["password"]
        dj.config["database.user"] = CONN_INFO["user"]
        dj.config["database.host"] = CONN_INFO["host"]
        schema = dj.Schema("djtest_blob_migrate")

        # Test if migration throws unexpected exceptions
        _migrate_dj011_blob(schema, default_store)

        # Test Fetch
        test_mod = dj.create_virtual_module("test_mod", "djtest_blob_migrate")
        r1 = test_mod.A.fetch("blob_share", order_by="id")
        assert_equal(r1[1][1], 2)

        # Test Insert
        test_mod.A.insert1(
            {"id": 3, "blob_external": [9, 8, 7, 6], "blob_share": {"number": 5}}
        )
        r2 = (test_mod.A & "id=3").fetch1()
        assert_equal(r2["blob_share"]["number"], 5)
Example #27
    def update_tuple(jwt_payload: dict, schema_name: str, table_name: str,
                     tuple_to_update: dict):
        """
        Update record as tuple into table
        :param jwt_payload: Dictionary containing databaseAddress, username and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of the schema containing the table
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param tuple_to_update: Record to be updated
        :type tuple_to_update: dict
        """
        DJConnector.set_datajoint_config(jwt_payload)

        schema_virtual_module = dj.create_virtual_module(
            schema_name, schema_name)
        getattr(schema_virtual_module, table_name).update1(tuple_to_update)
Example #28
    def get_table_definition(jwt_payload: dict, schema_name: str,
                             table_name: str):
        """
        Get the table definition
        :param jwt_payload: Dictionary containing databaseAddress, username and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of the schema containing the table
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :return: definition of the table
        :rtype: str
        """
        DJConnector.set_datajoint_config(jwt_payload)

        schema_virtual_module = dj.create_virtual_module(
            schema_name, schema_name)
        return getattr(schema_virtual_module, table_name).describe()
Example #29
    def make(self, key):
        # Pull the Nidaq file/record
        session_dir = pathlib.Path(get_session_directory(key))
        nidq_bin_full_path = list(session_dir.glob('*nidq.bin*'))[0]
        # And get the datajoint record
        behavior = dj.create_virtual_module('behavior', 'u19_behavior')
        thissession = behavior.TowersBlock().Trial() & key
        behavior_time, iterstart = thissession.fetch('trial_time', 'vi_start')

        # 1: load meta data, and the content of the NIDAQ file. Its content is digital.
        nidq_meta          = readSGLX.readMeta(nidq_bin_full_path)
        nidq_sampling_rate = readSGLX.SampRate(nidq_meta)
        digital_array      = ephys_utils.spice_glx_utility.load_spice_glx_digital_file(nidq_bin_full_path, nidq_meta)

        # Synchronize between pulses and get iteration # vector for each sample
        mode = 'counter_bit0'
        iteration_dict = ephys_utils.get_iteration_sample_vector_from_digital_lines_pulses(digital_array[1,:], digital_array[2,:], nidq_sampling_rate, behavior_time.shape[0], mode)
        # Check # of trials and iterations match
        status = ephys_utils.assert_iteration_samples_count(iteration_dict['iter_start_idx'], behavior_time)

        # They didn't match; try the counter method (if available)
        if (not status) and (digital_array.shape[0] > 3):
            [framenumber_in_trial, trialnumber] = ephys_utils.behavior_sync_frame_counter_method(digital_array, behavior_time, thissession, nidq_sampling_rate, 3, 5)
            iteration_dict['framenumber_vector_samples'] = framenumber_in_trial
            iteration_dict['trialnumber_vector_samples'] = trialnumber


        final_key = dict(key, nidq_sampling_rate=nidq_sampling_rate,
                         iteration_index_nidq=iteration_dict['framenumber_vector_samples'],
                         trial_index_nidq=iteration_dict['trialnumber_vector_samples'])

        print(final_key)

        ephys.BehaviorSync.insert1(final_key, allow_direct_insert=True)

        self.insert_imec_sampling_rate(key, session_dir)
Example #30
import hashlib
import os.path as op
import pathlib
import pickle
from shutil import copyfile
from warnings import warn

import datajoint as dj
from tqdm import tqdm

from .. import logger as log

fuse = dj.create_virtual_module('fuse', 'pipeline_fuse')
stimulus = dj.create_virtual_module('stimulus', 'pipeline_stimulus')


def list_hash(values):
    """
    Returns MD5 digest hash values for a list of values
    """
    hashed = hashlib.md5()
    for v in values:
        hashed.update(str(v).encode())
    return hashed.hexdigest()
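
# For example, list_hash feeds str(v) for each value into a single MD5,
# so list_hash([1, 'a']) == hashlib.md5(b'1a').hexdigest().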


def key_hash(key):
    """
    32-character hash used for lookup of primary keys of jobs
    """
    hashed = hashlib.md5()
    for k, v in sorted(key.items()):
        hashed.update(str(v).encode())
    return hashed.hexdigest()


def test_virtual_module():
    module = dj.create_virtual_module(
        'module', schema.schema.database, connection=dj.conn(**CONN_INFO))
    assert_true(issubclass(module.Experiment, UserTable))