Example #1
class TemplatesLib(object):
    """Manage all the templates library interaction"""
    @inject(loader=AssistedBuilder(FileSystemLoader),
            environment=AssistedBuilder(Environment))
    def __init__(self, loader, environment):
        """
        Initialize the class
        :param loader: FileSystemLoader
        :param environment: Environment
        """
        super(TemplatesLib, self).__init__()

        self.__loader = loader
        self.__environment = environment

    def __get_lib(self, base_directory):
        return self.__environment.build(
            loader=self.__loader.build(searchpath=base_directory),
            trim_blocks=True,
            extensions=['jinja2.ext.do'])

    def render(self, base_directory, template_file, template_vars):
        """
        Render the given template with the template_vars
        :param base_directory: string Path where the template can be found
        :param template_file: string
        :param template_vars: dict
        :return: string
        """
        return self.__get_lib(base_directory).get_template(
            template_file).render(template_vars)
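
A minimal usage sketch, relying on the injector's auto-binding to construct the class; the base directory, template file name and template variables below are hypothetical:

from injector import Injector

injector = Injector()
templates = injector.get(TemplatesLib)
# Renders the hypothetical 'report.html.j2' found under the given base directory.
html = templates.render('/path/to/templates', 'report.html.j2', {'project': 'SlippinJ'})
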
Example #2
 def get_driver(self,
                injector,
                driver,
                db_host='localhost',
                db_user='******',
                db_pwd=None,
                db_port=None,
                db_name=None,
                db_schema=None):
     """
     Return a connection to the given database using the selected driver.
     Raises a NotImplementedError if the given driver is not implemented.
     :param injector: Injector
     :param driver: string
     :param db_host: string
     :param db_user: string
     :param db_pwd: string
     :param db_port: int
     :param db_name: string
     :param db_schema: string
     :return: Selected driver instance
     """
     try:
         return injector.get(
             AssistedBuilder('database_driver_' + driver.lower())).build(
                 db_host=db_host,
                 db_user=db_user,
                 db_pwd=db_pwd,
                 db_port=db_port,
                 db_name=db_name,
                 db_schema=db_schema)
     except Error:
         raise NotImplementedError(
             'Given driver has not been implemented on SlippinJ')
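
A hedged call sketch: db_factory is a hypothetical name for the object exposing get_driver, and the Injector is assumed to carry bindings following the 'database_driver_<name>' convention (e.g. 'database_driver_mysql'); all connection values are placeholders:

connection = db_factory.get_driver(
    injector,
    'mysql',
    db_host='db.example.com',
    db_user='reader',
    db_pwd='secret',
    db_name='metastore')
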
Example #3
def test_assisted_builder_uses_bindings():
    Interface = Key('Interface')

    def configure(binder):
        binder.bind(Interface, to=NeedsAssistance)

    injector = Injector(configure)
    builder = injector.get(AssistedBuilder(Interface))
    x = builder.build(b=333)
    assert ((type(x), x.b) == (NeedsAssistance, 333))
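
The NeedsAssistance fixture used here and in later examples (#9, #10, #11, #16) is not shown; a plausible reconstruction, consistent with the assertions in this example and Example #10, could look like the following (the real fixture may differ):

class NeedsAssistance(object):
    # 'a' is injected by the injector (str defaults to ''),
    # 'b' must be supplied through builder.build() at call time.
    @inject(a=str)
    def __init__(self, a, b):
        self.a = a
        self.b = b
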
Example #4
def test_assisted_builder_uses_concrete_class_when_specified():
    class X(object):
        pass

    def configure(binder):
        # meant only to show that provider isn't called
        binder.bind(X, to=lambda: 1 / 0)

    injector = Injector(configure)
    builder = injector.get(AssistedBuilder(cls=X))
    builder.build()
Example #5
class DefaultConfiguration(object):
    """Get the default configuration for the workflows"""

    @inject(deploy_configuration=AssistedBuilder(DeployConfiguration), configuration_parser='configuration_parser')
    def __init__(self, deploy_configuration, configuration_parser):
        """
        Initialize the class
        :param deploy_configuration: DeployConfiguration
        :param configuration_parser: ConfigParser
        """
        super(DefaultConfiguration, self).__init__()

        self.__deploy_configuration = deploy_configuration
        self.__configuration_parser = configuration_parser

    def get(self, environment, arguments, workflow_configuration):
        """
        Get configuration parameters that are common to the workflows
        :param environment: string
        :param arguments: Namespace
        :param workflow_configuration: dict
        :return: dict
        """
        default_variables = ['hive_metastore_bucket', 'hdfs_deploy_folder']
        default_configuration = {}
        interactive_provided = False
        deploy_configuration = self.__deploy_configuration.build(environment=environment,
                                                                 configuration_parser=self.__configuration_parser)
        args = vars(arguments)

        for variable in default_variables:
            if variable in args and False != args[variable]:
                default_configuration[variable] = args[variable]
                interactive_provided = True
            elif variable in workflow_configuration:
                default_configuration[variable] = workflow_configuration[variable]
            elif deploy_configuration.get(variable):
                default_configuration[variable] = deploy_configuration.get(variable)
            else:
                default_configuration[variable] = raw_input(
                    'Please, provide the {var_name} value: '.format(var_name=variable.replace('-', ' ')))
                interactive_provided = True

        if interactive_provided and 'y' == (
                raw_input('Would you like to save the provided information in the config file: [Y/N] ')).lower():
            for key in default_configuration:
                deploy_configuration.set(key, default_configuration[key])

        return default_configuration
Example #6
def test_providers_arent_called_for_dependencies_that_are_already_provided():
    def configure(binder):
        binder.bind(int, to=lambda: 1 / 0)

    class A(object):
        @inject(i=int)
        def __init__(self, i):
            pass

    injector = Injector(configure)
    builder = injector.get(AssistedBuilder(A))

    with pytest.raises(ZeroDivisionError):
        builder.build()

    builder.build(i=3)
Example #7
def main():
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    root_logger.level = logging.DEBUG

    parser = ArgumentParser(description='HAL bot')
    parser.add_argument('--adapter', dest='adapter', default='shell')
    parser.add_argument('--name', dest='name', default='HAL')
    arguments = parser.parse_args()

    injector = Injector([ApplicationModule])
    bot_builder = injector.get(AssistedBuilder(Bot))
    bot = bot_builder.build(name=arguments.name)

    _attach_adapter(bot, arguments.adapter)
    bot.run()
Example #8
def test_special_interfaces_work_with_auto_bind_disabled():
    class InjectMe(object):
        pass

    def configure(binder):
        binder.bind(InjectMe, to=InstanceProvider(InjectMe()))

    injector = Injector(configure, auto_bind=False)

    # This line used to fail with:
    # Traceback (most recent call last):
    #   File "/projects/injector/injector_test.py", line 1171,
    #   in test_auto_bind_disabled_regressions
    #     injector.get(ProviderOf(InjectMe))
    #   File "/projects/injector/injector.py", line 687, in get
    #     binding = self.binder.get_binding(None, key)
    #   File "/projects/injector/injector.py", line 459, in get_binding
    #     raise UnsatisfiedRequirement(cls, key)
    # UnsatisfiedRequirement: unsatisfied requirement on
    # <injector.ProviderOf object at 0x10ff01550>
    injector.get(ProviderOf(InjectMe))

    # This used to fail with an error similar to the ProviderOf one
    injector.get(AssistedBuilder(cls=InjectMe))
Example #9
 class X(object):
     @inject(builder=AssistedBuilder(NeedsAssistance))
     def __init__(self, builder):
         self.obj = builder.build(b=234)
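
Assuming the NeedsAssistance sketch above, building X through an injector also satisfies the nested builder (hypothetical usage):

from injector import Injector

injector = Injector()
x = injector.get(X)
assert x.obj.b == 234  # 'b' was passed through the injected AssistedBuilder
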
Example #10
def test_assisted_builder_works_when_got_directly_from_injector():
    injector = Injector()
    builder = injector.get(AssistedBuilder(NeedsAssistance))
    obj = builder.build(b=123)
    assert ((obj.a, obj.b) == (str(), 123))
Example #11
def test_assisted_builder_injection_uses_the_same_binding_key_every_time():
    # if we have different BindingKey for every AssistedBuilder(...) we will get memory leak
    gen_key = lambda: BindingKey(AssistedBuilder(NeedsAssistance))
    assert gen_key() == gen_key()
Example #12
class Oracle(object):
    """Wrapper to connect to Oracle Servers and get all the metastore information"""
    @inject(oracle=AssistedBuilder(callable=pyoracle.connect), logger='logger')
    def __init__(self,
                 oracle,
                 logger,
                 db_host=None,
                 db_user='******',
                 db_name=None,
                 db_schema=None,
                 db_pwd=None,
                 db_port=None):

        super(Oracle, self).__init__()

        self.__db_name = db_name
        self.__db_user = db_user
        self.__db_schema = db_schema
        self.__db_dsn = pyoracle.makedsn(
            host=db_host,
            port=int(db_port) if None != db_port else 1521,
            service_name=db_name)
        self.__conn = oracle.build(user=db_user,
                                   password=db_pwd,
                                   dsn=self.__db_dsn)
        if self.__db_schema is not None:
            cursor = self.__conn.cursor()
            cursor.execute(
                "ALTER SESSION SET CURRENT_SCHEMA = {schema}".format(
                    schema=self.__db_schema))

        self.__db_connection_string = 'jdbc:oracle:thin:@//' + db_host + (
            (':' + db_port) if db_port else '') + (
                ('/' + db_name) if db_name else '')

        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )

        self.__logger = logger

    def __makedict(self, cursor):
        """
        Convert cx_oracle query result to be a dictionary
        """

        cols = [d[0] for d in cursor.description]

        def createrow(*args):
            return dict(zip(cols, args))

        return createrow

    def __join_tables_list(self, tables):
        return ','.join('\'%s\'' % table for table in tables)

    def __get_table_list(self, table_list_query=False):
        self.__logger.debug('Getting table list')
        query_with_db_schema = "= '{schema}'".format(schema=self.__db_schema)
        query = "SELECT DISTINCT table_name " \
                "FROM all_tables WHERE OWNER " \
                "{owner} {table_list_query}".format(owner=query_with_db_schema if self.__db_schema else "NOT LIKE '%SYS%' AND OWNER NOT LIKE 'APEX%'AND OWNER NOT LIKE 'XDB'" ,table_list_query=' AND ' + table_list_query if table_list_query else '')

        cursor = self.__conn.cursor()
        cursor.execute(query)
        cursor.rowfactory = self.__makedict(cursor)

        tablelist = map(lambda x: x['TABLE_NAME'], cursor.fetchall())
        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))

        return tablelist

    def __get_columns_for_tables(self, tables):
        self.__logger.debug('Getting columns information')

        query_with_owner = "AND owner = '{schema}'".format(
            schema=self.__db_schema)
        info_query = "SELECT table_name, column_name, data_type, data_length, nullable, data_default, data_scale " \
                     "FROM ALL_TAB_COLUMNS " \
                     "WHERE table_name IN ({tables}) " \
                     "{owner}" \
                     "ORDER BY COLUMN_ID".format(tables=self.__join_tables_list(tables), owner=query_with_owner if self.__db_schema else '')

        cursor = self.__conn.cursor()
        cursor.execute(info_query)
        cursor.rowfactory = self.__makedict(cursor)

        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['TABLE_NAME']))
            if not row['TABLE_NAME'] in tables_information:
                tables_information[row['TABLE_NAME']] = {'columns': []}

            tables_information[row['TABLE_NAME']]['columns'].append({
                'column_name':
                row['COLUMN_NAME'],
                'data_type':
                row['DATA_TYPE'].lower(),
                'character_maximum_length':
                row['DATA_LENGTH'],
                'is_nullable':
                row['NULLABLE'],
                'column_default':
                row['DATA_DEFAULT'],
            })

        return tables_information

    def __get_count_for_tables(self, tables):

        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM {table}'.format(table=table)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except:
                self.__logger.debug(
                    'The count query for table {table} has failed'.format(
                        table=table))
                pass

        return tables_information

    def __get_top_for_tables(self, tables, top=30):

        tables_information = {}

        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    query = 'SELECT * FROM {table} WHERE ROWNUM < {top}'.format(
                        top=top, table=table)
                    cursor.execute(query)
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            try:
                                if type(column) is unicode:
                                    column = unicodedata.normalize(
                                        'NFKD',
                                        column).encode('iso-8859-1', 'replace')

                                else:
                                    column = str(column).decode(
                                        'utf8', 'replace').encode(
                                            'iso-8859-1', 'replace')
                                    if self.__illegal_characters.search(
                                            column):
                                        column = re.sub(
                                            self.__illegal_characters, '?',
                                            column)

                                if column == 'None':
                                    column = 'NULL'

                            except:
                                column = 'Parse_error'

                            table_row.append(column)

                        tables_information[table]['rows'].append(table_row)

                except pyoracle.ProgrammingError:
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(
                            error=pyoracle.ProgrammingError.message))

        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database
        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """

        if table_list:
            tables = map(lambda x: unicode(x), table_list.split(','))
        else:
            tables = self.__get_table_list(table_list_query)

        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)
        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])

        tables_info['db_connection_string'] = self.__db_connection_string

        return tables_info
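
How such a driver is obtained is not shown in this example; a hedged sketch, assuming an Injector with the 'logger' binding in place, builds it through an AssistedBuilder and passes only the connection details (all values below are placeholders). The same pattern applies to the Sqlserver, Postgresql and Mysql wrappers in the later examples:

builder = injector.get(AssistedBuilder(Oracle))
oracle_driver = builder.build(db_host='oracle.example.com',
                              db_user='reader',
                              db_pwd='secret',
                              db_name='ORCL',
                              db_schema='HR')
tables_info = oracle_driver.get_all_tables_info(None, False, 30)
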
Example #13
class Sqlserver(object):
    """Wrapper to connect to SQL Servers and get all the metastore information"""
    @inject(mssql=AssistedBuilder(callable=pymssql.connect), logger='logger')
    def __init__(self,
                 mssql,
                 logger,
                 db_host=None,
                 db_user='******',
                 db_name=None,
                 db_schema=None,
                 db_pwd=None,
                 db_port=None):
        """
        Initialize the SQLServer driver to get all the tables information
        :param mssql: Pymssql
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Sqlserver, self).__init__()

        self.__db_name = db_name
        self.__db_schema = db_schema if None != db_schema else 'dbo'
        self.__conn = mssql.build(server=db_host,
                                  user=db_user,
                                  password=db_pwd,
                                  database=db_name,
                                  tds_version='7.0',
                                  port=db_port if None != db_port else 1433)

        self.__db_connection_string = 'jdbc:sqlserver://' + db_host + (
            (':' + db_port) if db_port else '') + (
                (';DatabaseName=' + db_name) if db_name else '')

        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xbf]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )

        self.__logger = logger

    def __join_tables_list(self, tables):
        return ','.join('\'%s\'' % table for table in tables)

    def __get_table_list(self, table_list_query=False):

        self.__logger.debug('Getting table list')
        query = 'SELECT table_name FROM information_schema.tables WHERE table_catalog = %(db_name)s and table_schema = %(schema)s {table_list_query}'.format(
            table_list_query=' AND ' +
            table_list_query if table_list_query else '')
        cursor = self.__conn.cursor(as_dict=True)
        cursor.execute(query, {
            'db_name': self.__db_name,
            'schema': self.__db_schema
        })

        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))

        return map(lambda x: x['table_name'], cursor.fetchall())

    def __get_columns_for_tables(self, tables):

        self.__logger.debug('Getting columns information')
        info_query = 'SELECT table_name, column_name, data_type, character_maximum_length, is_nullable, column_default FROM information_schema.columns WHERE table_name IN ({tables}) AND table_catalog=%(db_name)s AND table_schema=%(schema)s'.format(
            tables=self.__join_tables_list(tables))

        cursor = self.__conn.cursor(as_dict=True)
        cursor.execute(info_query, {
            'db_name': self.__db_name,
            'schema': self.__db_schema
        })

        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['table_name']))
            if not row['table_name'] in tables_information:
                tables_information[row['table_name']] = {'columns': []}

            tables_information[row['table_name']]['columns'].append({
                'column_name':
                row['column_name'],
                'data_type':
                row['data_type'],
                'character_maximum_length':
                row['character_maximum_length'],
                'is_nullable':
                row['is_nullable'],
                'column_default':
                row['column_default'],
            })

        return tables_information

    def __get_count_for_tables(self, tables):

        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM [{schema}].[{table}]'.format(
                    table=table, schema=self.__db_schema)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except:
                pass

        return tables_information

    def __get_top_for_tables(self, tables, top=30):

        tables_information = {}

        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    cursor.execute(
                        'SELECT TOP {top} * FROM [{schema}].[{table}]'.format(
                            top=top, table=table, schema=self.__db_schema))
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            try:
                                if type(column) is unicode:
                                    column = unicodedata.normalize(
                                        'NFKD',
                                        column).encode('iso-8859-1', 'replace')
                                else:
                                    column = str(column).decode(
                                        'utf8', 'replace').encode(
                                            'iso-8859-1', 'replace')
                                    if self.__illegal_characters.search(
                                            column):
                                        column = re.sub(
                                            self.__illegal_characters, '?',
                                            column)
                                if column == 'None':
                                    column = 'NULL'
                            except:
                                column = 'Parse_error'

                            table_row.append(column)

                        tables_information[table]['rows'].append(table_row)

                except pymssql.ProgrammingError:
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(
                            error=pymssql.ProgrammingError.message))

        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database
        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """

        if table_list:
            tables = map(lambda x: unicode(x), table_list.split(','))
        else:
            tables = self.__get_table_list(table_list_query)

        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)

        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])

        tables_info['db_connection_string'] = self.__db_connection_string

        return tables_info
Example #14
class EmrCluster(object):
    """Handle all EMR cluster information and configuration"""
    @inject(aws_emr_client='aws_emr_client',
            ssh_client=AssistedBuilder(SSHClient),
            args='args',
            logger='logger')
    def __init__(self, aws_emr_client, ssh_client, args, logger):
        """
        Initialize the class
        :param aws_emr_client: EMR.client
        :param ssh_client: paramiko.SSHClient
        :param args: Namespace
        :param logger: Logging
        """
        super(EmrCluster, self).__init__()
        self.__aws_emr_client = aws_emr_client
        self.__ssh_client = ssh_client.build(pem_files_dir=args.pem_dir)
        self.__cluster_information = {}
        self.__logger = logger

    def __get_cluster_environment(self, tags):
        for tag in tags:
            if 'environment' == tag['Key'].lower():
                return tag['Value']

        return False

    def __get_cluster(self, cluster_id):
        cluster = self.__aws_emr_client.describe_cluster(ClusterId=cluster_id)
        instances = self.__aws_emr_client.list_instances(
            ClusterId=cluster_id, InstanceGroupTypes=['MASTER'])

        return dict(cluster.items() + instances.items())

    def __get_master_ips(self, cluster_information):
        for instance in cluster_information['Instances']:
            if cluster_information['Cluster'][
                    'MasterPublicDnsName'] == instance['PublicDnsName']:
                return instance['PublicIpAddress'], instance[
                    'PrivateIpAddress']

    def get_cluster_information(self, cluster_id):
        """
        Get the cluster information from the AWS API
        :param cluster_id: string
        :return: dict
        """
        if not cluster_id in self.__cluster_information:
            self.__logger.debug(
                'Getting information from AWS for cluster {cluster_id}'.format(
                    cluster_id=cluster_id))
            cluster_information = self.__get_cluster(cluster_id)

            cluster_ips = self.__get_master_ips(cluster_information)

            self.__cluster_information[cluster_id] = {
                'public_dns':
                cluster_information['Cluster']['MasterPublicDnsName'],
                'public_ip':
                cluster_ips[0],
                'private_ip':
                cluster_ips[1],
                'environment':
                self.__get_cluster_environment(
                    cluster_information['Cluster']['Tags']),
                'key_name':
                cluster_information['Cluster']['Ec2InstanceAttributes']
                ['Ec2KeyName']
            }

        return self.__cluster_information[cluster_id]

    def exec_command(self, command, cluster_id, stop_on_error=False):
        """
        Execute given command in the master of the selected cluster
        :param command: string
        :param cluster_id: string
        :param stop_on_error: boolean
        :return: string
        """
        cluster_information = self.get_cluster_information(cluster_id)

        self.__logger.debug(
            'Executing command {command} in cluster {cluster_id}'.format(
                command=command, cluster_id=cluster_id))
        return self.__ssh_client.exec_command(
            command, cluster_information['public_dns'],
            cluster_information['key_name'], stop_on_error)

    def open_sftp(self, cluster_id):
        """
        Open an SFTP connection to the given cluster
        :param cluster_id: string
        :return: SFTP
        """
        cluster_information = self.get_cluster_information(cluster_id)

        self.__logger.debug(
            'Opening SFTP connection to cluster {cluster_id}'.format(
                cluster_id=cluster_id))
        return self.__ssh_client.open_sftp(cluster_information['public_dns'],
                                           cluster_information['key_name'])

    def get_pem_path(self, cluster_id):
        """
        Get the path to the private key file to connect to the cluster
        :param cluster_id: string
        :return: string
        """
        cluster_information = self.get_cluster_information(cluster_id)

        return self.__ssh_client.get_pem_path(cluster_information['key_name'])
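
A hedged usage sketch, assuming an Injector whose modules bind 'aws_emr_client', 'args' and 'logger'; the cluster id and command are placeholders:

emr = injector.get(EmrCluster)
info = emr.get_cluster_information('j-EXAMPLE12345')
output = emr.exec_command('hive -e "show databases"', 'j-EXAMPLE12345')
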
Example #15
def test_assisted_builder_accepts_callables():
    injector = Injector()
    builder = injector.get(AssistedBuilder(callable=lambda x: x * 2))
    assert builder.build(x=3) == 6
Example #16
 class X(object):
     @inject(builder=AssistedBuilder(NeedsAssistance))
     def y(self, builder):
         return builder
Example #17
class Postgresql(object):
    """Wrapper to connect to SQL Servers and get all the metastore information"""
    @inject(postgresql=AssistedBuilder(callable=psycopg2.connect),
            logger='logger')
    def __init__(self,
                 postgresql,
                 logger,
                 db_host=None,
                 db_user='******',
                 db_name=None,
                 db_schema=None,
                 db_pwd=None,
                 db_port=None):
        """
        Initialize the Postgresql driver to get all the tables information
        :param postgresql: Psycopg2
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Postgresql, self).__init__()

        self.__db_name = db_name
        self.__db_schema = db_schema if None != db_schema else 'public'
        self.__conn = postgresql.build(
            host=db_host,
            user=db_user,
            password=db_pwd,
            database=db_name,
            port=db_port if None != db_port else 5432)

        self.__column_types = {
            'timestamp without time zone': 'timestamp',
            'timestamp with time zone': 'timestamp',
            'uuid': 'string',
            'character': 'string',
            'character varying': 'string',
            'integer': 'int',
            'smallint': 'int',
            'text': 'string',
            'real': 'double',
            'numeric': 'double',
            'json': 'string',
            'USER-DEFINED': 'string'
        }

        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xbf]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )

        self.__logger = logger

    def __join_tables_list(self, tables):
        return ','.join('\'%s\'' % table for table in tables)

    def __get_valid_column_name(self, column_name):
        return re.sub("[ ,;{}()\n\t=]", "", column_name)

    def __get_table_list(self, table_list_query=False):

        self.__logger.debug('Getting table list')
        query = 'SELECT table_name FROM information_schema.tables WHERE table_catalog = %(db_name)s and table_schema = %(db_schema)s {table_list_query}'.format(
            table_list_query=' AND ' +
            table_list_query if table_list_query else '')
        cursor = self.__conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        cursor.execute(query, {
            'db_name': self.__db_name,
            'db_schema': self.__db_schema
        })

        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))

        return map(lambda x: x[0], cursor.fetchall())

    def __get_tables_to_exclude(self, tables):
        return self.__get_table_list('table_name NOT IN ({tables})'.format(
            tables=self.__join_tables_list(tables)))

    def __get_database_collation(self):

        self.__logger.debug('Getting database collation')
        info_query = 'SELECT datcollate FROM pg_database WHERE datname = %(db_name)s'

        cursor = self.__conn.cursor()
        cursor.execute(info_query, {'db_name': self.__db_name})
        return cursor.fetchone()[0].lower()

    def __get_columns_for_tables(self, tables):

        self.__logger.debug('Getting columns information')
        info_query = 'SELECT table_name, column_name, data_type, character_maximum_length, is_nullable, column_default FROM information_schema.columns WHERE table_name IN ({tables}) AND table_catalog=%(db_name)s AND table_schema=%(db_schema)s'.format(
            tables=self.__join_tables_list(tables))

        cursor = self.__conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        cursor.execute(info_query, {
            'db_name': self.__db_name,
            'db_schema': self.__db_schema
        })

        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['table_name']))
            if not row['table_name'] in tables_information:
                tables_information[row['table_name']] = {'columns': []}

            tables_information[row['table_name']]['columns'].append({
                'source_column_name':
                row['column_name'],
                'column_name':
                self.__get_valid_column_name(row['column_name']),
                'source_data_type':
                row['data_type'],
                'data_type':
                row['data_type'] if row['data_type'] not in self.__column_types
                else self.__column_types[row['data_type']],
                'character_maximum_length':
                row['character_maximum_length'],
                'is_nullable':
                row['is_nullable'],
                'column_default':
                row['column_default'],
            })

        return tables_information

    def __get_count_for_tables(self, tables):

        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM {schema}.{table}'.format(
                    table=table, schema=self.__db_schema)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except:
                pass

        return tables_information

    def __get_top_for_tables(self, tables, top=30):

        tables_information = {}

        collation = self.__get_database_collation()
        utf8_collation = 'utf-8' in collation or 'utf8' in collation

        cursor = self.__conn.cursor()

        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    info_query = 'SELECT * FROM {schema}.{table} LIMIT {top}'.format(
                        top=top, table=table, schema=self.__db_schema)
                    cursor.execute(info_query)
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            if not utf8_collation:
                                try:
                                    if type(column) is unicode:
                                        column = unicodedata.normalize(
                                            'NFKD', column).encode(
                                                'iso-8859-1', 'replace')
                                    else:
                                        column = str(column).decode(
                                            'utf8', 'replace').encode(
                                                'iso-8859-1', 'replace')
                                        if self.__illegal_characters.search(
                                                column):
                                            column = re.sub(
                                                self.__illegal_characters, '?',
                                                column)
                                except:
                                    column = 'Parse_error'
                            if column == 'None':
                                column = 'NULL'
                            table_row.append(column)

                        tables_information[table]['rows'].append(table_row)

                except psycopg2.ProgrammingError:
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(
                            error=psycopg2.ProgrammingError.message))

        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database
        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """
        tables_to_exclude = {}

        if table_list:
            tables = table_list.split(',')
            tables_to_exclude = self.__get_tables_to_exclude(tables)
        else:
            tables = self.__get_table_list(table_list_query)

        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)

        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])

        if tables_to_exclude:
            tables_info['excluded_tables'] = tables_to_exclude

        return tables_info
Example #18
    def test_assisted_injection_works(self):
        builder = self.injector.get(AssistedBuilder(self.C))
        c = builder.build(noninjectable=5)

        assert ((type(c.injectable), c.noninjectable) == (self.A, 5))
Example #19
class Mysql(object):
    """Wrapper to connect to MySQL Servers and get all the metastore information"""
    @inject(mysql=AssistedBuilder(callable=pymysql.connect), logger='logger')
    def __init__(self,
                 mysql,
                 logger,
                 db_host=None,
                 db_user='******',
                 db_name=None,
                 db_schema=None,
                 db_pwd=None,
                 db_port=None):
        """
        Initialize the MySQL driver to get all the tables information
        :param mysql: Mysql
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Mysql, self).__init__()

        self.__db_name = db_name
        self.__db_schema = db_schema if None != db_schema else 'mysql'
        self.__conn = mysql.build(
            host=db_host,
            user=db_user,
            password=db_pwd,
            database=db_name,
            port=int(db_port) if None != db_port else 3306)

        self.__column_types = {
            'tinyint': 'tinyint',
            'boolean': 'boolean',
            'smallint': 'smallint',
            'mediumint': 'int',
            'int': 'int',
            'integer': 'int',
            'bigint': 'bigint',
            'decimal': 'double',
            'dec': 'double',
            'numeric': 'double',
            'fixed': 'double',
            'float': 'float',
            'double': 'double',
            'real': 'double',
            'double precision': 'double',
            'bit': 'boolean',
            'char': 'string',
            'varchar': 'string',
            'binary': 'binary',
            'char byte': 'binary',
            'varbinary': 'binary',
            'tinyblob': 'binary',
            'blob': 'binary',
            'mediumblob': 'binary',
            'tinytext': 'string',
            'text': 'string',
            'mediumtext': 'string',
            'longtext': 'string',
            'enum': 'string',
            'date': 'timestamp',
            'time': 'timestamp',
            'datetime': 'timestamp',
            'timestamp': 'timestamp'
        }

        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xbf]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )

        self.__logger = logger

    def __join_tables_list(self, tables):
        return ','.join('\'%s\'' % table for table in tables)

    def __get_valid_column_name(self, column_name):
        return re.sub("[ ,;{}()\n\t=]", "", column_name)

    def __get_table_list(self, table_list_query=False):
        self.__logger.debug('Getting table list')
        query = 'SELECT table_name FROM information_schema.tables WHERE table_schema = %(schema)s {table_list_query}'.format(
            table_list_query=' AND ' +
            table_list_query if table_list_query else '')

        cursor = self.__conn.cursor(pymysql.cursors.DictCursor)
        cursor.execute(query, {
            'db_name': self.__db_name,
            'schema': self.__db_name
        })

        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))

        return map(lambda x: x['table_name'], cursor.fetchall())

    def __get_tables_to_exclude(self, tables):
        return self.__get_table_list('table_name NOT IN ({tables})'.format(
            tables=self.__join_tables_list(tables)))

    def __get_columns_for_tables(self, tables):

        self.__logger.debug('Getting columns information')
        info_query = 'SELECT table_name, column_name, data_type, character_maximum_length, is_nullable, ' \
                     'column_default FROM information_schema.columns WHERE table_name IN ({tables}) AND table_schema=%(schema)s'.format(tables=self.__join_tables_list(tables))

        cursor = self.__conn.cursor(pymysql.cursors.DictCursor)
        cursor.execute(info_query, {
            'db_name': self.__db_name,
            'schema': self.__db_name
        })

        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['table_name']))
            if not row['table_name'] in tables_information:
                tables_information[row['table_name']] = {'columns': []}

            tables_information[row['table_name']]['columns'].append({
                'source_column_name':
                row['column_name'],
                'column_name':
                self.__get_valid_column_name(row['column_name']),
                'source_data_type':
                row['data_type'],
                'data_type':
                row['data_type'] if row['data_type'] not in self.__column_types
                else self.__column_types[row['data_type']],
                'character_maximum_length':
                row['character_maximum_length'],
                'is_nullable':
                row['is_nullable'],
                'column_default':
                row['column_default'],
            })

        return tables_information

    def __get_count_for_tables(self, tables):

        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM {schema}.{table}'.format(
                    table=table, schema=self.__db_name)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except:
                self.__logger.debug(
                    'The count query for table {table} has failed'.format(
                        table=table))
                pass

        return tables_information

    def __get_top_for_tables(self, tables, top=30):

        tables_information = {}

        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    cursor.execute(
                        'SELECT * FROM {schema}.{table} LIMIT {top}'.format(
                            top=top, table=table, schema=self.__db_name))

                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            try:
                                if type(column) is unicode:
                                    column = unicodedata.normalize(
                                        'NFKD',
                                        column).encode('iso-8859-1', 'replace')

                                else:
                                    column = str(column).decode(
                                        'utf8', 'replace').encode(
                                            'iso-8859-1', 'replace')
                                    if self.__illegal_characters.search(
                                            column):
                                        column = re.sub(
                                            self.__illegal_characters, '?',
                                            column)

                                if column == 'None':
                                    column = 'NULL'

                            except:
                                column = 'Parse_error'

                            table_row.append(column)

                        tables_information[table]['rows'].append(table_row)

                except pymysql.ProgrammingError:
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(
                            error=pymysql.ProgrammingError.message))

        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database
        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """
        tables_to_exclude = {}

        if table_list:
            tables = map(lambda x: unicode(x), table_list.split(','))
            tables_to_exclude = self.__get_tables_to_exclude(tables)
        else:
            tables = self.__get_table_list(table_list_query)

        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)

        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])

        if tables_to_exclude:
            tables_info['excluded_tables'] = tables_to_exclude

        return tables_info