Ejemplo n.º 1
0
 def record_dependency(jwt_payload: dict, schema_name: str, table_name: str,
                       primary_restriction: dict) -> list:
     """
     Summarize how many rows of each dependent table match a restriction.

     :param jwt_payload: Dictionary containing databaseAddress, username and password
         strings
     :type jwt_payload: dict
     :param schema_name: Name of schema
     :type schema_name: str
     :param table_name: Table name under the given schema; must be in camel case
     :type table_name: str
     :param primary_restriction: Restriction applied to every dependent table
     :type primary_restriction: dict
     :return: One ``dict`` per table yielded by ``descendants(as_objects=True)``,
         with ``schema``, ``table``, ``accessible`` (always ``True`` here) and
         ``count`` (number of rows matching the restriction) keys.
     :rtype: list
     """
     DJConnector.set_datajoint_config(jwt_payload)
     schema_module = dj.VirtualModule(schema_name, schema_name)
     restricted_cls = getattr(schema_module, table_name)

     summary = []
     for child in restricted_cls().descendants(as_objects=True):
         summary.append(dict(schema=child.database,
                             table=child.table_name,
                             accessible=True,
                             count=len(child & primary_restriction)))
     return summary
Ejemplo n.º 2
0
def virtual_module():
    """
    Create a ``filter`` schema with 100 random ``Student`` rows and yield a
    virtual module over it; the schema is dropped on teardown.

    ``dj.config['safemode']`` is switched off for the duration and is always
    switched back on afterwards, and the connection is always closed, even if
    dropping the schema fails.
    """
    dj.config['safemode'] = False
    connection = dj.conn(host=getenv('TEST_DB_SERVER'),
                         user=getenv('TEST_DB_USER'),
                         password=getenv('TEST_DB_PASS'),
                         reset=True)
    schema = dj.Schema('filter')

    @schema
    class Student(dj.Lookup):
        definition = """
        student_id: int
        ---
        student_name: varchar(50)
        student_ssn: varchar(20)
        student_enroll_date: datetime
        student_balance: float
        student_parking_lot=null : varchar(20)
        student_out_of_state: bool
        """
        # NOTE(review): round() on the int returned by randint() is a no-op, so
        # every balance is a whole number; left unchanged in case tests compare
        # balances for exact equality.
        contents = [(i, faker.name(), faker.ssn(),
                     faker.date_between_dates(date_start=date(2021, 1, 1),
                                              date_end=date(2021, 1, 31)),
                     round(randint(1000, 3000), 2),
                     choice([None, 'LotA', 'LotB', 'LotC']),
                     bool(getrandbits(1)))
                    for i in range(100)]

    try:
        yield dj.VirtualModule('filter', 'filter')
    finally:
        # Nested so a failing drop cannot leave the connection open or the
        # global safemode switched off for the rest of the session.
        try:
            schema.drop()
        finally:
            connection.close()
            dj.config['safemode'] = True
Ejemplo n.º 3
0
 def test_rename_non_dj_attribute():
     """
     A table created with plain SQL (outside DataJoint) can have its ``oldID``
     attribute renamed through ``proj``.
     """
     schema = PREFIX + "_test1"
     connection = dj.conn(**CONN_INFO)
     connection.query(
         f"CREATE TABLE {schema}.test_table (oldID int PRIMARY KEY)"
     ).fetchall()
     try:
         mySchema = dj.VirtualModule(schema, schema)
         renamed = mySchema.TestTable.proj(new_name="oldID")
         assert "oldID" not in renamed.heading.attributes.keys(), \
             "Failed to rename attribute correctly"
     finally:
         # Always drop the hand-made table; otherwise a failure here would make
         # the CREATE TABLE above fail on the next run.
         connection.query(f"DROP TABLE {schema}.test_table")
Ejemplo n.º 4
0
def test_delete_independent_without_cascade(token, client, connection, schemas_simple):
    """POST /delete_tuple with a mixed-case ``cascade=fAlSe`` removes the TableB record."""
    schema_name = f'{SCHEMA_PREFIX}group1_simple'
    table_name = 'TableB'
    restriction = dict(a_id=1, b_id=21)
    module = dj.VirtualModule('group1_simple', schema_name)

    auth_headers = dict(Authorization=f'Bearer {token}')
    payload = dict(schemaName=schema_name,
                   tableName=table_name,
                   restrictionTuple=restriction)
    response = client.post('/delete_tuple?cascade=fAlSe',
                           headers=auth_headers,
                           json=payload)

    assert response.status_code == 200
    remaining = getattr(module, table_name) & restriction
    assert len(remaining) == 0
Ejemplo n.º 5
0
def test_delete_invalid(token, client, connection, schemas_simple):
    """An empty restriction is rejected with HTTP 500 and TableB keeps its 3 rows."""
    schema_name = f'{SCHEMA_PREFIX}group1_simple'
    table_name = 'TableB'
    empty_restriction = {}
    module = dj.VirtualModule('group1_simple', schema_name)

    response = client.post(
        '/delete_tuple?cascade=TRUE',
        headers={'Authorization': f'Bearer {token}'},
        json={'schemaName': schema_name,
              'tableName': table_name,
              'restrictionTuple': empty_restriction})

    assert response.status_code == 500
    assert b'Restriction is invalid' in response.data
    table_cls = getattr(module, table_name)
    assert len(table_cls()) == 3
Ejemplo n.º 6
0
def test_delete_independent_without_cascade(token, client, connection, schemas_simple):
    """DELETE on the record endpoint with ``cascade=fAlSe`` removes the TableB record."""
    schema_name = f"{SCHEMA_PREFIX}group1_simple"
    table_name = "TableB"
    restriction = dict(a_id=1, b_id=21)

    filters = []
    for attribute, wanted in restriction.items():
        filters.append(dict(attributeName=attribute, operation="=", value=wanted))
    # Filters travel in the query string as base64-encoded JSON.
    encoded_filters = b64encode(dumps(filters).encode("utf-8")).decode("utf-8")
    query_params = dict(cascade="fAlSe", restriction=encoded_filters)

    module = dj.VirtualModule("group1_simple", schema_name)
    url = f"/schema/{schema_name}/table/{table_name}/record?{urlencode(query_params)}"
    response = client.delete(url, headers=dict(Authorization=f"Bearer {token}"))

    assert response.status_code == 200
    assert len(getattr(module, table_name) & restriction) == 0
Ejemplo n.º 7
0
    def refresh_schema(self):
        """Rebuild ``self['schemata']`` with one virtual module per database schema.

        Schemas listed in ``self['skip_schemas']`` are left out. Each module is
        created on ``self['connection']`` with ``create_tables=True`` and
        ``custom_attributes_dict`` passed as ``add_objects``.
        """
        schemata = {}
        for name in dj.list_schemas():
            if name not in self["skip_schemas"]:
                # TODO error messages
                module = dj.VirtualModule(
                    name,
                    name,
                    connection=self['connection'],
                    add_objects=custom_attributes_dict,
                    create_tables=True,
                )
                schemata[name] = module
                # accessing ``jobs`` makes sure the jobs table has been created
                module.schema.jobs
        self['schemata'] = schemata
Ejemplo n.º 8
0
    def _delete_records(
        jwt_payload: dict,
        schema_name: str,
        table_name: str,
        restriction: "list | None" = None,
        cascade: bool = False,
    ):
        """
        Delete the records of a table that match the given filters.

        :param jwt_payload: Dictionary containing databaseAddress, username, and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of schema
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param restriction: Sequence of filters as ``dict`` with ``attributeName``,
            ``operation``, ``value`` keys defined; ``None`` (the default) is treated
            as an empty sequence
        :type restriction: list, optional
        :param cascade: Allow for cascading delete, defaults to ``False``
        :type cascade: bool, optional
        :raises InvalidRestriction: If no record matches the restriction
        """
        # Avoid a shared mutable default; None means "no filters".
        if restriction is None:
            restriction = []

        _DJConnector._set_datajoint_config(jwt_payload)

        schema_virtual_module = dj.VirtualModule(schema_name, schema_name)

        # Get table object from name
        table = _DJConnector._get_table_object(schema_virtual_module,
                                               table_name)
        attributes = table.heading.attributes
        # Each filter is converted using the type of the attribute it targets;
        # a filter naming an attribute missing from the heading fails here.
        restrictions = [
            _DJConnector._filter_to_restriction(
                f, attributes[f["attributeName"]].type) for f in restriction
        ]

        # Compute restriction
        query = table & dj.AndList(restrictions)
        # Refuse to run a delete that would not match any record.
        if len(query) == 0:
            raise InvalidRestriction("Nothing to delete")

        # All checks passed, proceed to delete: delete() when cascading is
        # allowed, delete_quick() otherwise.
        if cascade:
            query.delete(safemode=False)
        else:
            query.delete_quick()
Ejemplo n.º 9
0
    def _insert_tuple(jwt_payload: dict, schema_name: str, table_name: str,
                      tuple_to_insert: dict):
        """
        Insert data into a table through the table's ``insert`` method.

        :param jwt_payload: Dictionary containing databaseAddress, username, and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of schema
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param tuple_to_insert: Data handed unchanged to ``insert``
            (NOTE(review): confirm whether callers pass one record or a sequence
            of records)
        :type tuple_to_insert: dict
        """
        _DJConnector._set_datajoint_config(jwt_payload)
        module = dj.VirtualModule(schema_name, schema_name)
        target_table = getattr(module, table_name)
        target_table.insert(tuple_to_insert)
Ejemplo n.º 10
0
    def get_table_definition(jwt_payload: dict, schema_name: str,
                             table_name: str):
        """
        Get the table definition as returned by the table's ``describe()``.

        :param jwt_payload: Dictionary containing databaseAddress, username and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of the schema containing the table
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :return: definition of the table
        :rtype: str
        """
        DJConnector.set_datajoint_config(jwt_payload)

        # NOTE(review): the virtual module is deliberately stored in this
        # function's locals() dict under the schema's own name instead of a plain
        # variable. Presumably describe() inspects the calling frame's locals to
        # name the tables a definition references, so references render as
        # ``<schema_name>.<Table>``; confirm against the DataJoint version in use.
        # Writes into locals() are no longer visible through frame.f_locals on
        # Python 3.13+ (PEP 667), so this trick may stop working there.
        local_values = locals()
        local_values[schema_name] = dj.VirtualModule(schema_name, schema_name)
        return getattr(local_values[schema_name], table_name).describe()
Ejemplo n.º 11
0
 def record_dependency(jwt_payload: dict, schema_name: str, table_name: str,
                       primary_restriction: dict) -> list:
     """
     Summarize the tables depending on restricted records, flagging the first
     one the current user cannot access.

     Accessible dependents come from ``descendants(as_objects=True)`` with a row
     count. Then a delete of the restricted records is attempted inside a
     transaction that is always cancelled afterwards. If that attempt raises
     ``AccessError``, one extra entry parsed from the error is appended with
     ``accessible=False``.

     :param jwt_payload: Dictionary containing databaseAddress, username and password
         strings
     :type jwt_payload: dict
     :param schema_name: Name of schema
     :type schema_name: str
     :param table_name: Table name under the given schema; must be in camel case
     :type table_name: str
     :param primary_restriction: Restriction to be applied to table
     :type primary_restriction: dict
     :return: Dependency summaries; accessible entries carry ``schema``,
         ``table``, ``accessible`` and ``count`` keys.
     :rtype: list
     """
     DJConnector.set_datajoint_config(jwt_payload)
     schema_module = dj.VirtualModule(schema_name, schema_name)
     restricted_cls = getattr(schema_module, table_name)

     summary = [
         dict(schema=child.database,
              table=child.table_name,
              accessible=True,
              count=len(child & primary_restriction))
         for child in restricted_cls().descendants(as_objects=True)
     ]

     conn = schema_module.schema.connection
     conn.start_transaction()
     try:
         # Probe only: transaction=False leaves the outcome to the surrounding
         # transaction, which is cancelled in the finally clause.
         (restricted_cls & primary_restriction).delete(safemode=False,
                                                       transaction=False)
     except AccessError as err:
         blocked = TABLE_INFO_REGEX.match(err.args[2]).groupdict()
         summary.append(dict(blocked, accessible=False))
     finally:
         conn.cancel_transaction()
     return summary
Ejemplo n.º 12
0
    def refresh_schema(self):
        """Rebuild ``self['schemata']``, the mapping of schema name to module.

        Entries are collected in this order:

        1. the bundled ``loris.schema`` modules when ``self['init_database']``
           is set (plus ``anatomy`` and ``subjects`` when ``self['include_fly']``);
        2. modules loaded from the ``(name, path)`` pairs in
           ``self['import_schema_module']``, which replace same-named entries;
        3. a virtual module for every database schema that is neither listed in
           ``self['skip_schemas']`` nor already present.
        """
        schemata = {}

        # direct loading if possible
        # TODO (also in app init)
        if self['init_database']:
            from loris.schema import (experimenters, core)
            schemata.update(experimenters=experimenters, core=core)

            if self['include_fly']:
                from loris.schema import anatomy, subjects
                schemata.update(anatomy=anatomy,  # move out
                                subjects=subjects)

        for name, path in self["import_schema_module"]:
            # TODO test
            file_spec = importlib.util.spec_from_file_location(name, path)
            loaded = importlib.util.module_from_spec(file_spec)
            file_spec.loader.exec_module(loaded)
            schemata[name] = loaded

        for name in dj.list_schemas():
            if name in self["skip_schemas"] or name in schemata:
                continue
            # TODO error messages
            virtual = dj.VirtualModule(
                name,
                name,
                connection=self['connection'],
                add_objects=custom_attributes_dict,
                create_tables=True,
            )
            schemata[name] = virtual
            # accessing ``jobs`` makes sure the jobs table has been created
            virtual.schema.jobs

        self['schemata'] = schemata
Ejemplo n.º 13
0
    def _record_dependency(jwt_payload: dict,
                           schema_name: str,
                           table_name: str,
                           restriction: "list | None" = None) -> list:
        """
        Return summary of dependencies associated with a restricted table.

        NOTE(review): this was documented as only showing dependencies the user
        has access to; presumably that filtering comes from ``descendants()``.

        :param jwt_payload: Dictionary containing databaseAddress, username, and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of schema
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param restriction: Sequence of filters as ``dict`` with ``attributeName``,
            ``operation``, ``value`` keys defined; ``None`` (the default) is treated
            as an empty sequence
        :type restriction: list, optional
        :return: One ``dict`` per descendant with ``schema``, ``table``,
            ``accessible`` (always ``True``) and ``count`` (rows matching the
            restriction) keys.
        :rtype: list
        """
        # Avoid a shared mutable default; None means "no filters".
        if restriction is None:
            restriction = []
        _DJConnector._set_datajoint_config(jwt_payload)
        virtual_module = dj.VirtualModule(schema_name, schema_name)
        table = getattr(virtual_module, table_name)
        attributes = table.heading.attributes
        # The filters only depend on the parent table's heading, so convert them
        # once instead of once per descendant.
        restrictions = [
            _DJConnector._filter_to_restriction(
                f, attributes[f["attributeName"]].type)
            for f in restriction
        ]

        dependencies = []
        for descendant in table().descendants(as_objects=True):
            # The filters name the parent's attributes, so each descendant is
            # joined with the parent before restricting; the parent itself is
            # restricted directly.
            if descendant.full_table_name == table.full_table_name:
                related = table
            else:
                related = descendant * table
            dependencies.append(dict(
                schema=descendant.database,
                table=descendant.table_name,
                accessible=True,
                count=len(related & dj.AndList(restrictions)),
            ))
        return dependencies
Ejemplo n.º 14
0
    def __init__(self, name, component_config, static_config, jwt_payload: dict):
        """
        Configure a component from its configuration.

        :param name: Component name
        :param component_config: Must provide ``type``, ``route`` and ``dj_query``
            (Python source defining a ``dj_query`` function whose argument names
            are schema names). May provide ``x``/``y``/``height``/``width``
            (all four together select ``fixed`` mode, otherwise ``dynamic``) and
            ``restriction`` (Python source defining a ``restriction`` callable).
        :param static_config: Optional mapping exposed read-only as
            ``static_variables``
        :param jwt_payload: Dictionary containing databaseAddress, username and
            password strings
        """
        # The exec'd snippets below use this dict (the constructor's locals) as
        # their local namespace, and their definitions are read back from it.
        lcls = locals()
        self.name = name
        if static_config:
            self.static_variables = types.MappingProxyType(static_config)
        if not all(k in component_config for k in ("x", "y", "height", "width")):
            self.mode = "dynamic"
        else:
            self.mode = "fixed"
            self.x = component_config["x"]
            self.y = component_config["y"]
            self.height = component_config["height"]
            self.width = component_config["width"]
        self.type = component_config["type"]
        self.route = component_config["route"]
        # SECURITY(review): exec() runs source text taken from component_config;
        # it must only ever come from trusted configuration, never user input.
        exec(component_config["dj_query"], globals(), lcls)
        self.dj_query = lcls["dj_query"]
        if self.attributes_route_format:
            self.attribute_route = self.attributes_route_format.format(
                route=component_config["route"]
            )
        if "restriction" in component_config:
            exec(component_config["restriction"], globals(), lcls)
            self.dj_restriction = lcls["restriction"]
        else:
            self.dj_restriction = lambda: dict()

        # Every argument name of dj_query is a schema exposed as a virtual module.
        schema_names = inspect.getfullargspec(self.dj_query).args
        # One connection shared by all virtual modules, instead of opening a
        # fresh (reset) connection per schema; none is opened without schemas.
        connection = (
            dj.conn(
                host=jwt_payload["databaseAddress"],
                user=jwt_payload["username"],
                password=jwt_payload["password"],
                reset=True,
            )
            if schema_names
            else None
        )
        self.vm_list = [
            dj.VirtualModule(s, s, connection=connection) for s in schema_names
        ]
Ejemplo n.º 15
0
    def _update_tuple(jwt_payload: dict, schema_name: str, table_name: str,
                      tuple_to_update: list):
        """
        Update existing records of a table, all within one transaction.

        :param jwt_payload: Dictionary containing databaseAddress, username, and password
            strings
        :type jwt_payload: dict
        :param schema_name: Name of schema
        :type schema_name: str
        :param table_name: Table name under the given schema; must be in camel case
        :type table_name: str
        :param tuple_to_update: Records to be updated; each element is passed to
            the table's ``update1``
        :type tuple_to_update: list
        """
        conn = _DJConnector._set_datajoint_config(jwt_payload)

        schema_virtual_module = dj.VirtualModule(schema_name, schema_name)
        # Apply every update inside a single ``conn.transaction`` block.
        with conn.transaction:
            for record in tuple_to_update:
                getattr(schema_virtual_module, table_name).update1(record)
Ejemplo n.º 16
0
def test_uppercase_schema():
    """
    Regression test: a table in ``schema_b`` can declare a foreign key to a
    table in the mixed-case schema ``Schema_A``.
    """
    # https://github.com/datajoint/datajoint-python/issues/564
    dj.conn(**CONN_INFO_ROOT, reset=True)
    schema1 = dj.Schema('Schema_A')

    @schema1
    class Subject(dj.Manual):
        definition = """
        name: varchar(32)
        """

    # Recording's definition refers to ``Schema_A.Subject``. NOTE(review):
    # presumably @schema2 resolves that name from this function's local names,
    # so keep this variable's name and scope unchanged.
    Schema_A = dj.VirtualModule('Schema_A', 'Schema_A')

    schema2 = dj.Schema('schema_b')

    @schema2
    class Recording(dj.Manual):
        definition = """
        -> Schema_A.Subject
        id: smallint
        """

    # Drop the referencing schema before the referenced one.
    # NOTE(review): the drops are not in a finally block, so if the declaration
    # above fails, both schemas are left behind for the next run.
    schema2.drop()
    schema1.drop()
Ejemplo n.º 17
0
        nx.barbell_graph(3, 1),
        nx.cycle_graph(5)
    ]
    c.insert((i, g) for i, g in enumerate(graphs))
    returned_graphs = c.fetch('conn_graph', order_by='connid')
    for g1, g2 in zip(graphs, returned_graphs):
        assert_true(isinstance(g2, nx.Graph))
        assert_equal(len(g1.edges), len(g2.edges))
        assert_true(0 == len(nx.symmetric_difference(g1, g2).edges))
    c.delete()
    dj.errors._switch_adapted_types(False)


# test with virtual module
# Virtual module over the adapted-types test schema. ``graph`` is passed in via
# add_objects, presumably so attributes using that adapted type can be resolved.
virtual_module = dj.VirtualModule(
    'virtual_module', adapted.schema_name, add_objects=dict(graph=graph))


def test_adapted_virtual():
    """
    Insert networkx graphs (plus one row with a NULL graph) through the virtual
    module's ``Connectivity`` table and fetch them back in ``connid`` order.
    """
    # NOTE(review): as visible here, the test never asserts on
    # ``returned_graphs`` and never calls ``c.delete()`` or
    # ``dj.errors._switch_adapted_types(False)``. If the body really ends here,
    # the inserted rows and the enabled switch leak into later tests. Confirm
    # this is not a truncated copy.
    dj.errors._switch_adapted_types(True)
    c = virtual_module.Connectivity()
    graphs = [
        nx.lollipop_graph(4, 2),
        nx.star_graph(5),
        nx.barbell_graph(3, 1),
        nx.cycle_graph(5)
    ]
    c.insert((i, g) for i, g in enumerate(graphs))
    c.insert1({'connid': 100})  # test work with NULLs
    returned_graphs = c.fetch('conn_graph', order_by='connid')
def test_virtual_module():
    """A virtual module over the test schema exposes ``Experiment`` as a UserTable subclass."""
    database = schema.schema.database
    explicit_connection = dj.conn(**CONN_INFO)
    module = dj.VirtualModule('module', database, connection=explicit_connection)
    assert_true(issubclass(module.Experiment, UserTable))
Ejemplo n.º 19
0
        nx.barbell_graph(3, 1),
        nx.cycle_graph(5),
    ]
    c.insert((i, g) for i, g in enumerate(graphs))
    returned_graphs = c.fetch("conn_graph", order_by="connid")
    for g1, g2 in zip(graphs, returned_graphs):
        assert_true(isinstance(g2, nx.Graph))
        assert_equal(len(g1.edges), len(g2.edges))
        assert_true(0 == len(nx.symmetric_difference(g1, g2).edges))
    c.delete()
    dj.errors._switch_adapted_types(False)


# test with virtual module
# Virtual module over the adapted-types test schema. ``graph`` is passed in via
# add_objects, presumably so attributes using that adapted type can be resolved.
virtual_module = dj.VirtualModule(
    "virtual_module", adapted.schema_name, add_objects=dict(graph=graph))


def test_adapted_virtual():
    """
    Insert networkx graphs (plus one row with a NULL graph) through the virtual
    module's ``Connectivity`` table and fetch them back in ``connid`` order.
    """
    # NOTE(review): as visible here, the test never asserts on
    # ``returned_graphs`` and never calls ``c.delete()`` or
    # ``dj.errors._switch_adapted_types(False)``. If the body really ends here,
    # the inserted rows and the enabled switch leak into later tests. Confirm
    # this is not a truncated copy.
    dj.errors._switch_adapted_types(True)
    c = virtual_module.Connectivity()
    graphs = [
        nx.lollipop_graph(4, 2),
        nx.star_graph(5),
        nx.barbell_graph(3, 1),
        nx.cycle_graph(5),
    ]
    c.insert((i, g) for i, g in enumerate(graphs))
    c.insert1({"connid": 100})  # test work with NULLs
    returned_graphs = c.fetch("conn_graph", order_by="connid")
Ejemplo n.º 20
0
def connection():
    """
    Build the ``deps`` and ``deps_secret`` test schemas and yield the test
    connection.

    An ``underprivileged`` user is granted all privileges on ``deps`` only;
    nothing here grants it access to ``deps_secret``, whose ``DiffTableB``
    references ``deps.TableA``. Teardown drops both schemas and the user, and
    always closes the connection and re-enables ``safemode``, even if a drop
    fails.
    """
    dj.config['safemode'] = False
    connection = dj.conn(host=getenv('TEST_DB_SERVER'),
                         user=getenv('TEST_DB_USER'),
                         password=getenv('TEST_DB_PASS'),
                         reset=True)
    # NOTE(review): '%%' presumably survives %-style parameter formatting in
    # query() and reaches MySQL as the '%' host wildcard.
    connection.query("""
                     CREATE USER IF NOT EXISTS 'underprivileged'@'%%'
                     IDENTIFIED BY 'datajoint';
                     """)
    connection.query(
        "GRANT ALL PRIVILEGES ON `deps`.* TO 'underprivileged'@'%%';")
    deps_secret = dj.VirtualModule('deps_secret',
                                   'deps_secret',
                                   create_tables=True)
    deps = dj.VirtualModule('deps', 'deps', create_tables=True)

    @deps.schema
    class TableA(dj.Lookup):
        definition = """
        a_id: int
        ---
        a_name: varchar(30)
        """
        contents = [(0, 'Raphael'), (1, 'Bernie')]

    @deps.schema
    class TableB(dj.Lookup):
        definition = """
        -> TableA
        b_id: int
        ---
        b_number: float
        """
        contents = [(0, 10, 22.12), (0, 11, -1.21), (1, 21, 7.77)]

    # Re-created after declaring TableA/TableB because DiffTableB's definition
    # refers to ``deps.TableA``; presumably the first instance does not expose
    # tables declared after it was built.
    deps = dj.VirtualModule('deps', 'deps', create_tables=True)

    @deps_secret.schema
    class DiffTableB(dj.Lookup):
        definition = """
        -> deps.TableA
        bs_id: int
        ---
        bs_number: float
        """
        contents = [(0, -10, -99.99), (0, -11, 287.11)]

    @deps.schema
    class TableC(dj.Lookup):
        definition = """
        -> TableB
        c_id: int
        ---
        c_int: int
        """
        contents = [(0, 10, 100, -8), (0, 11, 200, -9), (0, 11, 300, -7)]

    try:
        yield connection
    finally:
        # Nested so the connection is closed and safemode restored even if a
        # drop fails. deps_secret goes first because DiffTableB references deps.
        try:
            deps_secret.schema.drop()
            deps.schema.drop()
            connection.query("DROP USER 'underprivileged'@'%%';")
        finally:
            connection.close()
            dj.config['safemode'] = True