예제 #1
0
class MySQLExplorerTest(CommonExplorerTestCases.CommonExplorerTests):
    """Run the shared explorer test cases against a MySQL database.

    The tests connect to the ``piidb`` database on ``127.0.0.1``. The
    ``create_tables`` fixture loads ``pii_data_script`` before the tests and
    registers a finalizer that executes ``pii_db_drop`` afterwards.
    """

    # Statements executed by the fixture finalizer to remove the test tables.
    pii_db_drop = """
        DROP TABLE full_pii;
        DROP TABLE partial_pii;
        DROP TABLE no_pii;
    """

    @staticmethod
    def execute_script(cursor, script):
        """Execute each non-blank ``;``-separated statement of *script*.

        Args:
            cursor: Database cursor used to run the statements.
            script (str): One or more SQL statements separated by ``;``.
        """
        for query in script.split(";"):
            # Splitting on ";" leaves whitespace-only fragments; skip them.
            if query.strip():
                cursor.execute(query)

    @pytest.fixture(scope="class")
    def create_tables(self, request):
        """Create the test tables and schedule their removal.

        Args:
            request: pytest request object, used to register the finalizer.
        """
        self.conn = pymysql.connect(
            host="127.0.0.1",
            user="******",
            password="******",
            database="piidb",
        )

        # The ``with`` block closes the cursor on exit, so no explicit
        # ``cursor.close()`` is needed.
        with self.conn.cursor() as cursor:
            self.execute_script(cursor, pii_data_script)
            cursor.execute("commit")

        def drop_tables():
            # Close the connection even when a DROP statement fails, so a
            # broken teardown does not leak the connection.
            try:
                with self.conn.cursor() as cursor:
                    self.execute_script(cursor, self.pii_db_drop)
                    logging.info("Executed drop script")
            finally:
                self.conn.close()

        request.addfinalizer(drop_tables)

    def setUp(self):
        """Build a fresh ``MySQLExplorer`` with no schema/table filters."""
        self.explorer = MySQLExplorer(
            Namespace(
                host="127.0.0.1",
                user="******",
                password="******",
                database="piidb",
                include_schema=(),
                exclude_schema=(),
                include_table=(),
                exclude_table=(),
                catalog=None,
            ))

    def tearDown(self):
        """Close the connection held by the explorer under test."""
        self.explorer.get_connection().close()

    def test_schema(self):
        """The explorer reports exactly one schema, named ``piidb``."""
        names = [sch.get_name() for sch in self.explorer.get_schemas()]
        self.assertEqual(["piidb"], names)
        # Test methods must return None; the schema name is exposed through
        # get_test_schema() instead.

    def get_test_schema(self):
        """Return the name of the schema the tests run against."""
        return "piidb"
예제 #2
0
 def setUp(self):
     """Create the ``MySQLExplorer`` under test for the local ``piidb``."""
     connection_args = Namespace(
         host="127.0.0.1",
         user="******",
         password="******",
         database="piidb",
         catalog=None,
     )
     self.explorer = MySQLExplorer(connection_args)
예제 #3
0
 def setUp(self):
     """Create the ``MySQLExplorer`` under test with empty filters."""
     connection_args = Namespace(
         host="127.0.0.1",
         user="******",
         password="******",
         database="piidb",
         include_schema=(),
         exclude_schema=(),
         include_table=(),
         exclude_table=(),
         catalog=None,
     )
     self.explorer = MySQLExplorer(connection_args)
예제 #4
0
class MySQLDataTypeTest(CommonDataTypeTestCases.CommonDataTypeTests):
    """Run the shared data-type test cases against a MySQL database.

    The tests connect to the ``piidb`` database on ``127.0.0.1``. The
    ``create_tables`` fixture loads ``char_data_types`` before the tests and
    registers a finalizer that executes ``char_db_drop`` afterwards.
    """

    # Statements executed by the fixture finalizer to remove the test tables.
    char_db_drop = """
        DROP TABLE char_columns;
        DROP TABLE no_char_columns;
        DROP TABLE some_char_columns;
    """

    @staticmethod
    def execute_script(cursor, script):
        """Execute each non-blank ``;``-separated statement of *script*.

        Args:
            cursor: Database cursor used to run the statements.
            script (str): One or more SQL statements separated by ``;``.
        """
        for query in script.split(";"):
            # Splitting on ";" leaves whitespace-only fragments; skip them.
            if query.strip():
                cursor.execute(query)

    @pytest.fixture(scope="class")
    def create_tables(self, request):
        """Create the test tables and schedule their removal.

        Args:
            request: pytest request object, used to register the finalizer.
        """
        self.conn = pymysql.connect(host="127.0.0.1",
                                    user="******",
                                    password="******",
                                    database="piidb")

        # The ``with`` block closes the cursor on exit, so no explicit
        # ``cursor.close()`` is needed.
        with self.conn.cursor() as cursor:
            self.execute_script(cursor, char_data_types)
            cursor.execute("commit")

        def drop_tables():
            # Close the connection even when a DROP statement fails, so a
            # broken teardown does not leak the connection.
            try:
                with self.conn.cursor() as drop_cursor:
                    self.execute_script(drop_cursor, self.char_db_drop)
                    logging.info("Executed drop script")
            finally:
                self.conn.close()

        request.addfinalizer(drop_tables)

    def setUp(self):
        """Build a fresh ``MySQLExplorer`` with no schema/table filters."""
        self.explorer = MySQLExplorer(
            Namespace(
                host="127.0.0.1",
                user="******",
                password="******",
                database="piidb",
                include_schema=(),
                exclude_schema=(),
                include_table=(),
                exclude_table=(),
                catalog=None,
            ))

    def tearDown(self):
        """Close the connection held by the explorer under test."""
        self.explorer.get_connection().close()

    def get_test_schema(self):
        """Return the name of the schema the tests run against."""
        return "piidb"
예제 #5
0
 def test_mysql(self):
     """Check the select query built for the schema's first table."""
     first_table = self.schema.get_children()[0]
     first_table_columns = self.schema.get_children()[0].get_children()
     actual_query = MySQLExplorer._get_select_query(
         self.schema, first_table, first_table_columns)
     self.assertEqual("select c1,c2 from testSchema.t1", actual_query)
예제 #6
0
 def test_mysql(self):
     """Check the sample query built for the schema's first table."""
     first_table = self.schema.get_children()[0]
     first_table_columns = self.schema.get_children()[0].get_children()
     actual_query = MySQLExplorer._get_sample_query(
         self.schema, first_table, first_table_columns)
     self.assertEqual(
         'select "c1","c2" from testSchema.t1 limit 10', actual_query)
예제 #7
0
 def explorer(self):
     """Return a new ``MySQLExplorer`` built from ``self.namespace``."""
     mysql_explorer = MySQLExplorer(self.namespace)
     return mysql_explorer
예제 #8
0
def scan_database(
        connection: Any,
        connection_type: str,
        scan_type: str = "shallow",
        include_schema: Tuple = (),
        exclude_schema: Tuple = (),
        include_table: Tuple = (),
        exclude_table: Tuple = (),
) -> Dict[Any, Any]:
    """Scan a database through an already-open connection.

    An explorer matching ``connection_type`` is created with placeholder
    connection settings, the supplied ``connection`` is attached to it, and
    the scan is delegated to ``_scan_db``.

    Args:
        connection: Connection object to a database.
        connection_type: Database type. Handled values are sqlite, athena,
            mysql, postgres, redshift and oracle.
        scan_type: Choose deep (scan data) or shallow (scan column names
            only).
        include_schema: Patterns of schemas to scan. When empty, all
            non-system schemas in the target database are scanned. Each
            pattern is interpreted as a regular expression, so several
            schemas can be selected with one pattern.
        exclude_schema: Patterns of schemas to skip, interpreted by the same
            rules as ``include_schema``. When both are given, only schemas
            that match at least one ``include_schema`` pattern and no
            ``exclude_schema`` pattern are scanned. When only
            ``exclude_schema`` is given, matching schemas are excluded.
        include_table: Patterns of tables to scan. Similar in behaviour to
            ``include_schema``.
        exclude_table: Patterns of tables to skip. Similar in behaviour to
            ``exclude_schema``.

    Returns:
        dict: A dictionary of schemata, tables and columns.

    Raises:
        AttributeError: If ``connection_type`` is not one of the handled
            values.
    """
    # Filter options that every explorer namespace carries.
    pattern_filters = {
        "include_schema": include_schema,
        "exclude_schema": exclude_schema,
        "include_table": include_table,
        "exclude_table": exclude_table,
    }

    explorer: Explorer
    if connection_type == "sqlite":
        explorer = SqliteExplorer(
            Namespace(
                path=None,
                scan_type=scan_type,
                list_all=None,
                catalog=None,
                **pattern_filters,
            ))
    elif connection_type == "athena":
        explorer = AthenaExplorer(
            Namespace(
                access_key=None,
                secret_key=None,
                staging_dir=None,
                region=None,
                scan_type=scan_type,
                list_all=None,
                **pattern_filters,
                catalog=None,
            ))
    # NOTE: Snowflake support is disabled here, so "snowflake" currently
    # falls through to the "Unknown connection type" error below.
    elif connection_type in ("mysql", "postgres", "redshift", "oracle"):
        # These four explorers share one set of namespace options.
        server_args = Namespace(
            host=None,
            port=None,
            user=None,
            password=None,
            database=None,
            connection_type=connection_type,
            scan_type=scan_type,
            list_all=None,
            catalog=None,
            **pattern_filters,
        )
        if connection_type == "mysql":
            explorer = MySQLExplorer(server_args)
        elif connection_type == "postgres":
            explorer = PostgreSQLExplorer(server_args)
        elif connection_type == "redshift":
            explorer = RedshiftExplorer(server_args)
        else:
            # Only "oracle" remains after the membership test above.
            explorer = OracleExplorer(server_args)
    else:
        raise AttributeError(
            "Unknown connection type: {}".format(connection_type))

    # Use the caller's connection rather than opening a new one.
    explorer.connection = connection
    return _scan_db(explorer, scan_type)