class TemplatesLib(object):
    """Facade over the templates library: builds Jinja2 environments on
    demand and renders templates found under a given directory."""

    @inject(loader=AssistedBuilder(FileSystemLoader),
            environment=AssistedBuilder(Environment))
    def __init__(self, loader, environment):
        """
        Store the assisted builders used later to assemble environments

        :param loader: AssistedBuilder wrapping FileSystemLoader
        :param environment: AssistedBuilder wrapping Environment
        """
        super(TemplatesLib, self).__init__()
        self.__loader = loader
        self.__environment = environment

    def __build_environment(self, base_directory):
        # A fresh Environment per call, rooted at the requested directory.
        fs_loader = self.__loader.build(searchpath=base_directory)
        return self.__environment.build(loader=fs_loader,
                                        trim_blocks=True,
                                        extensions=['jinja2.ext.do'])

    def render(self, base_directory, template_file, template_vars):
        """
        Render the given template with the template_vars

        :param base_directory: string Path where the template can be found
        :param template_file: string
        :param template_vars: dict
        :return: string
        """
        environment = self.__build_environment(base_directory)
        return environment.get_template(template_file).render(template_vars)
def get_driver(self, injector, driver, db_host='localhost', db_user='******', db_pwd=None, db_port=None, db_name=None, db_schema=None):
    """
    Return a connection to the given database using selected driver. In case
    given driver is not implemented a NotImplementedError is raised

    :param injector: Injector
    :param driver: string
    :param db_host: string
    :param db_user: string
    :param db_pwd: string
    :param db_port: int
    :param db_name: string
    :param db_schema: string
    :return: Selected driver instance
    """
    # Drivers are registered under lowercase binding names.
    binding_name = 'database_driver_' + driver.lower()
    try:
        builder = injector.get(AssistedBuilder(binding_name))
        return builder.build(db_host=db_host,
                             db_user=db_user,
                             db_pwd=db_pwd,
                             db_port=db_port,
                             db_name=db_name,
                             db_schema=db_schema)
    except Error:
        # An unresolved binding means no driver exists for this engine.
        raise NotImplementedError(
            'Given driver has not been implemented on SlippinJ')
def test_assisted_builder_uses_bindings():
    # The builder for an interface key must honour the configured binding.
    Interface = Key('Interface')

    def configure(binder):
        binder.bind(Interface, to=NeedsAssistance)

    builder = Injector(configure).get(AssistedBuilder(Interface))
    instance = builder.build(b=333)
    assert type(instance) == NeedsAssistance
    assert instance.b == 333
def test_assisted_builder_uses_concrete_class_when_specified():
    class X(object):
        pass

    def configure(binder):
        # The provider would raise ZeroDivisionError if it were invoked;
        # binding it proves AssistedBuilder(cls=...) bypasses providers.
        binder.bind(X, to=lambda: 1 / 0)

    builder = Injector(configure).get(AssistedBuilder(cls=X))
    builder.build()
class DefaultConfiguration(object):
    """Get the default configuration for the workflows"""

    @inject(deploy_configuration=AssistedBuilder(DeployConfiguration),
            configuration_parser='configuration_parser')
    def __init__(self, deploy_configuration, configuration_parser):
        """
        Initialize the class

        :param deploy_configuration: DeployConfiguration
        :param configuration_parser: ConfigParser
        """
        super(DefaultConfiguration, self).__init__()
        self.__deploy_configuration = deploy_configuration
        self.__configuration_parser = configuration_parser

    def get(self, environment, arguments, workflow_configuration):
        """
        Get configuration parameters that are common to the workflows.

        Resolution order per variable: CLI argument, workflow configuration,
        stored deploy configuration, interactive prompt.

        :param environment: string
        :param arguments: Namespace
        :param workflow_configuration: dict
        :return: dict
        """
        wanted_variables = ['hive_metastore_bucket', 'hdfs_deploy_folder']
        resolved = {}
        offer_to_save = False
        deploy_configuration = self.__deploy_configuration.build(
            environment=environment,
            configuration_parser=self.__configuration_parser)
        cli_values = vars(arguments)

        for name in wanted_variables:
            if name in cli_values and False != cli_values[name]:
                resolved[name] = cli_values[name]
                offer_to_save = True
            elif name in workflow_configuration:
                resolved[name] = workflow_configuration[name]
            elif deploy_configuration.get(name):
                resolved[name] = deploy_configuration.get(name)
            else:
                prompt = 'Please, provide the {var_name} value: '.format(
                    var_name=name.replace('-', ' '))
                resolved[name] = raw_input(prompt)
                offer_to_save = True

        if offer_to_save:
            answer = raw_input(
                'Would you like to save the provided information in the config file: [Y/N] ')
            if 'y' == answer.lower():
                for key in resolved:
                    deploy_configuration.set(key, resolved[key])

        return resolved
def test_providers_arent_called_for_dependencies_that_are_already_provided():
    def configure(binder):
        # Raises ZeroDivisionError if the provider is ever consulted.
        binder.bind(int, to=lambda: 1 / 0)

    class A(object):
        @inject(i=int)
        def __init__(self, i):
            pass

    builder = Injector(configure).get(AssistedBuilder(A))

    # Without the dependency supplied, the exploding provider runs.
    with pytest.raises(ZeroDivisionError):
        builder.build()

    # With the dependency supplied explicitly, the provider is skipped.
    builder.build(i=3)
def main():
    """Entry point: configure root logging, parse CLI options, build and run the bot."""
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    # Fix: use the public setLevel() API instead of assigning .level directly,
    # which bypasses logging's internal handling of level changes.
    root_logger.setLevel(logging.DEBUG)

    parser = ArgumentParser(description='HAL bot')
    parser.add_argument('--adapter', dest='adapter', default='shell')
    parser.add_argument('--name', dest='name', default='HAL')
    arguments = parser.parse_args()

    injector = Injector([ApplicationModule])
    bot_builder = injector.get(AssistedBuilder(Bot))
    bot = bot_builder.build(name=arguments.name)
    _attach_adapter(bot, arguments.adapter)
    bot.run()
def test_special_interfaces_work_with_auto_bind_disabled():
    class InjectMe(object):
        pass

    def configure(binder):
        binder.bind(InjectMe, to=InstanceProvider(InjectMe()))

    injector = Injector(configure, auto_bind=False)

    # Regression check: with auto_bind disabled, requesting
    # ProviderOf(InjectMe) used to raise UnsatisfiedRequirement from
    # Injector.get -> Binder.get_binding instead of resolving the
    # special interface.
    injector.get(ProviderOf(InjectMe))

    # AssistedBuilder used to fail with a similar error.
    injector.get(AssistedBuilder(cls=InjectMe))
class X(object):
    # Demonstrates assisted-builder injection into a constructor: the
    # injected builder supplies the non-injectable argument at build time.
    @inject(builder=AssistedBuilder(NeedsAssistance))
    def __init__(self, builder):
        built_instance = builder.build(b=234)
        self.obj = built_instance
def test_assisted_builder_works_when_got_directly_from_injector():
    # A builder fetched straight from a bare Injector must still work.
    builder = Injector().get(AssistedBuilder(NeedsAssistance))
    obj = builder.build(b=123)
    assert obj.a == str()
    assert obj.b == 123
def test_assisted_builder_injection_uses_the_same_binding_key_every_time():
    # If every AssistedBuilder(...) produced a distinct BindingKey we would
    # leak memory, so two independently built keys must compare equal.
    def make_key():
        return BindingKey(AssistedBuilder(NeedsAssistance))

    assert make_key() == make_key()
class Oracle(object):
    """Wrapper to connect to Oracle Servers and get all the metastore information"""

    @inject(oracle=AssistedBuilder(callable=pyoracle.connect), logger='logger')
    def __init__(self, oracle, logger, db_host=None, db_user='******', db_name=None, db_schema=None, db_pwd=None, db_port=None):
        """
        Initialize the Oracle driver to get all the tables information

        :param oracle: AssistedBuilder around pyoracle.connect
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Oracle, self).__init__()
        self.__db_name = db_name
        self.__db_user = db_user
        self.__db_schema = db_schema
        # 1521 is the default Oracle listener port.
        self.__db_dsn = pyoracle.makedsn(
            host=db_host,
            port=int(db_port) if None != db_port else 1521,
            service_name=db_name)
        self.__conn = oracle.build(user=db_user, password=db_pwd, dsn=self.__db_dsn)
        if self.__db_schema is not None:
            cursor = self.__conn.cursor()
            cursor.execute(
                "ALTER SESSION SET CURRENT_SCHEMA = {schema}".format(
                    schema=self.__db_schema))
        # NOTE(review): db_port is concatenated as a string here, so callers
        # appear to pass it as a string -- confirm against get_driver callers.
        self.__db_connection_string = 'jdbc:oracle:thin:@//' + db_host + (
            (':' + db_port) if db_port else '') + (
            ('/' + db_name) if db_name else '')
        # Characters that cannot survive the iso-8859-1 re-encoding below.
        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )
        self.__logger = logger

    def __makedict(self, cursor):
        """
        Build a row factory converting cx_oracle rows to dictionaries
        """
        cols = [d[0] for d in cursor.description]

        def createrow(*args):
            return dict(zip(cols, args))

        return createrow

    def __join_tables_list(self, tables):
        # 'a','b','c' -- suitable for an SQL IN (...) clause.
        return ','.join('\'%s\'' % table for table in tables)

    def __get_table_list(self, table_list_query=False):
        # List table names, either scoped to the configured schema or
        # excluding Oracle system/APEX/XDB owners.
        self.__logger.debug('Getting table list')
        query_with_db_schema = "= '{schema}'".format(schema=self.__db_schema)
        # Bug fix: spaces were missing around the second AND
        # ("... 'APEX%'AND OWNER ..."), producing malformed SQL.
        query = "SELECT DISTINCT table_name " \
                "FROM all_tables WHERE OWNER " \
                "{owner} {table_list_query}".format(
                    owner=query_with_db_schema if self.__db_schema
                    else "NOT LIKE '%SYS%' AND OWNER NOT LIKE 'APEX%' AND OWNER NOT LIKE 'XDB'",
                    table_list_query=' AND ' + table_list_query if table_list_query else '')
        cursor = self.__conn.cursor()
        cursor.execute(query)
        cursor.rowfactory = self.__makedict(cursor)
        tablelist = map(lambda x: x['TABLE_NAME'], cursor.fetchall())
        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))
        return tablelist

    def __get_columns_for_tables(self, tables):
        # Column metadata for every table, keyed by table name.
        self.__logger.debug('Getting columns information')
        query_with_owner = "AND owner = '{schema}'".format(
            schema=self.__db_schema)
        # Bug fix: a space is required after {owner}; previously the statement
        # read "... = 'SCHEMA'ORDER BY ..." whenever a schema was configured.
        info_query = "SELECT table_name, column_name, data_type, data_length, nullable, data_default, data_scale " \
                     "FROM ALL_TAB_COLUMNS " \
                     "WHERE table_name IN ({tables}) " \
                     "{owner} " \
                     "ORDER BY COLUMN_ID".format(
                         tables=self.__join_tables_list(tables),
                         owner=query_with_owner if self.__db_schema else '')
        cursor = self.__conn.cursor()
        cursor.execute(info_query)
        cursor.rowfactory = self.__makedict(cursor)
        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['TABLE_NAME']))
            if not row['TABLE_NAME'] in tables_information:
                tables_information[row['TABLE_NAME']] = {'columns': []}
            tables_information[row['TABLE_NAME']]['columns'].append({
                'column_name': row['COLUMN_NAME'],
                'data_type': row['DATA_TYPE'].lower(),
                'character_maximum_length': row['DATA_LENGTH'],
                'is_nullable': row['NULLABLE'],
                'column_default': row['DATA_DEFAULT'],
            })
        return tables_information

    def __get_count_for_tables(self, tables):
        # Best-effort row counts: a failing count must not abort the scan.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM {table}'.format(table=table)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except Exception:
                # Narrowed from a bare except so ^C still interrupts.
                self.__logger.debug(
                    'The count query for table {table} has fail'.format(
                        table=table))
        return tables_information

    def __get_top_for_tables(self, tables, top=30):
        # Sample up to `top` rows per table, re-encoded to iso-8859-1.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    # Bug fix: ROWNUM < top returned top-1 rows; <= matches
                    # the LIMIT/TOP semantics of the sibling drivers.
                    query = 'SELECT * FROM {table} WHERE ROWNUM <= {top}'.format(
                        top=top, table=table)
                    cursor.execute(query)
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            try:
                                if type(column) is unicode:
                                    column = unicodedata.normalize(
                                        'NFKD', column).encode('iso-8859-1', 'replace')
                                else:
                                    column = str(column).decode(
                                        'utf8', 'replace').encode(
                                        'iso-8859-1', 'replace')
                                if self.__illegal_characters.search(column):
                                    column = re.sub(
                                        self.__illegal_characters, '?', column)
                                if column == 'None':
                                    column = 'NULL'
                            except Exception:
                                column = 'Parse_error'
                            table_row.append(column)
                        tables_information[table]['rows'].append(table_row)
                except pyoracle.ProgrammingError as exc:
                    # Bug fix: the message was previously read off the
                    # exception CLASS (pyoracle.ProgrammingError.message),
                    # which raises AttributeError; report the instance.
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(error=exc))
        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database

        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """
        if table_list:
            tables = map(lambda x: unicode(x), table_list.split(','))
        else:
            tables = self.__get_table_list(table_list_query)
        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)
        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])
        tables_info['db_connection_string'] = self.__db_connection_string
        return tables_info
class Sqlserver(object):
    """Wrapper to connect to SQL Servers and get all the metastore information"""

    @inject(mssql=AssistedBuilder(callable=pymssql.connect), logger='logger')
    def __init__(self, mssql, logger, db_host=None, db_user='******', db_name=None, db_schema=None, db_pwd=None, db_port=None):
        """
        Initialize the SQLServer driver to get all the tables information

        :param mssql: Pymssql
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Sqlserver, self).__init__()
        self.__db_name = db_name
        # 'dbo' is the SQL Server default schema.
        self.__db_schema = db_schema if None != db_schema else 'dbo'
        self.__conn = mssql.build(server=db_host,
                                  user=db_user,
                                  password=db_pwd,
                                  database=db_name,
                                  tds_version='7.0',
                                  port=db_port if None != db_port else 1433)
        # NOTE(review): db_port is concatenated as a string here, so callers
        # appear to pass it as a string -- confirm against get_driver callers.
        self.__db_connection_string = 'jdbc:sqlserver://' + db_host + (
            (':' + db_port) if db_port else '') + (
            (';DatabaseName=' + db_name) if db_name else '')
        # Characters that cannot survive the iso-8859-1 re-encoding below.
        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xbf]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )
        self.__logger = logger

    def __join_tables_list(self, tables):
        # 'a','b','c' -- suitable for an SQL IN (...) clause.
        return ','.join('\'%s\'' % table for table in tables)

    def __get_table_list(self, table_list_query=False):
        # List table names for the configured catalog and schema.
        self.__logger.debug('Getting table list')
        query = 'SELECT table_name FROM information_schema.tables WHERE table_catalog = %(db_name)s and table_schema = %(schema)s {table_list_query}'.format(
            table_list_query=' AND ' + table_list_query if table_list_query else '')
        cursor = self.__conn.cursor(as_dict=True)
        cursor.execute(query, {
            'db_name': self.__db_name,
            'schema': self.__db_schema
        })
        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))
        return map(lambda x: x['table_name'], cursor.fetchall())

    def __get_columns_for_tables(self, tables):
        # Column metadata for every table, keyed by table name.
        self.__logger.debug('Getting columns information')
        info_query = 'SELECT table_name, column_name, data_type, character_maximum_length, is_nullable, column_default FROM information_schema.columns WHERE table_name IN ({tables}) AND table_catalog=%(db_name)s AND table_schema=%(schema)s'.format(
            tables=self.__join_tables_list(tables))
        cursor = self.__conn.cursor(as_dict=True)
        cursor.execute(info_query, {
            'db_name': self.__db_name,
            'schema': self.__db_schema
        })
        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['table_name']))
            if not row['table_name'] in tables_information:
                tables_information[row['table_name']] = {'columns': []}
            tables_information[row['table_name']]['columns'].append({
                'column_name': row['column_name'],
                'data_type': row['data_type'],
                'character_maximum_length': row['character_maximum_length'],
                'is_nullable': row['is_nullable'],
                'column_default': row['column_default'],
            })
        return tables_information

    def __get_count_for_tables(self, tables):
        # Best-effort row counts: a failing count must not abort the scan.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM [{schema}].[{table}]'.format(
                    table=table, schema=self.__db_schema)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except Exception:
                # Narrowed from a silent bare except; the failure is now
                # logged for consistency with the other drivers.
                self.__logger.debug(
                    'The count query for table {table} has fail'.format(
                        table=table))
        return tables_information

    def __get_top_for_tables(self, tables, top=30):
        # Sample up to `top` rows per table, re-encoded to iso-8859-1.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    cursor.execute(
                        'SELECT TOP {top} * FROM [{schema}].[{table}]'.format(
                            top=top, table=table, schema=self.__db_schema))
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            try:
                                if type(column) is unicode:
                                    column = unicodedata.normalize(
                                        'NFKD', column).encode('iso-8859-1', 'replace')
                                else:
                                    column = str(column).decode(
                                        'utf8', 'replace').encode(
                                        'iso-8859-1', 'replace')
                                if self.__illegal_characters.search(column):
                                    column = re.sub(
                                        self.__illegal_characters, '?', column)
                                if column == 'None':
                                    column = 'NULL'
                            except Exception:
                                column = 'Parse_error'
                            table_row.append(column)
                        tables_information[table]['rows'].append(table_row)
                except pymssql.ProgrammingError as exc:
                    # Bug fix: the message was previously read off the
                    # exception CLASS (pymssql.ProgrammingError.message),
                    # which raises AttributeError; report the instance.
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(error=exc))
        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database

        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """
        if table_list:
            tables = map(lambda x: unicode(x), table_list.split(','))
        else:
            tables = self.__get_table_list(table_list_query)
        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)
        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])
        tables_info['db_connection_string'] = self.__db_connection_string
        return tables_info
class EmrCluster(object):
    """Handle all EMR cluster information and configuration"""

    @inject(aws_emr_client='aws_emr_client', ssh_client=AssistedBuilder(SSHClient), args='args', logger='logger')
    def __init__(self, aws_emr_client, ssh_client, args, logger):
        """
        Initialize the class

        :param aws_emr_client: EMR.client
        :param ssh_client: paramiko.SSHClient
        :param args: Namespace
        :param logger: Logging
        """
        super(EmrCluster, self).__init__()
        self.__aws_emr_client = aws_emr_client
        # Build the SSH client immediately, pointing it at the directory
        # holding the .pem key files.
        self.__ssh_client = ssh_client.build(pem_files_dir=args.pem_dir)
        # Per-cluster cache so AWS is queried at most once per cluster id.
        self.__cluster_information = {}
        self.__logger = logger

    def __get_cluster_environment(self, tags):
        # Return the value of the (case-insensitive) 'environment' tag, or
        # False when no such tag exists.
        for tag in tags:
            if 'environment' == tag['Key'].lower():
                return tag['Value']
        return False

    def __get_cluster(self, cluster_id):
        # Merge the describe_cluster and MASTER list_instances responses into
        # a single dict.
        # NOTE(review): dict(a.items() + b.items()) is a Python 2 idiom; it
        # raises TypeError on Python 3.
        cluster = self.__aws_emr_client.describe_cluster(ClusterId=cluster_id)
        instances = self.__aws_emr_client.list_instances(
            ClusterId=cluster_id, InstanceGroupTypes=['MASTER'])
        return dict(cluster.items() + instances.items())

    def __get_master_ips(self, cluster_information):
        # Locate the master instance by matching its public DNS name and
        # return its (public_ip, private_ip) pair.
        # NOTE(review): implicitly returns None when no instance matches --
        # the caller indexes the result, so that would raise TypeError.
        for instance in cluster_information['Instances']:
            if cluster_information['Cluster'][
                    'MasterPublicDnsName'] == instance['PublicDnsName']:
                return instance['PublicIpAddress'], instance[
                    'PrivateIpAddress']

    def get_cluster_information(self, cluster_id):
        """
        Get the cluster information from the AWS API

        Results are cached per cluster id for the lifetime of this object.

        :param cluster_id: string
        :return: dict
        """
        if not cluster_id in self.__cluster_information:
            self.__logger.debug(
                'Getting information from AWS for cluster {cluster_id}'.format(
                    cluster_id=cluster_id))
            cluster_information = self.__get_cluster(cluster_id)
            cluster_ips = self.__get_master_ips(cluster_information)
            self.__cluster_information[cluster_id] = {
                'public_dns':
                cluster_information['Cluster']['MasterPublicDnsName'],
                'public_ip': cluster_ips[0],
                'private_ip': cluster_ips[1],
                # False when the cluster has no 'environment' tag.
                'environment': self.__get_cluster_environment(
                    cluster_information['Cluster']['Tags']),
                'key_name':
                cluster_information['Cluster']['Ec2InstanceAttributes']
                ['Ec2KeyName']
            }
        return self.__cluster_information[cluster_id]

    def exec_command(self, command, cluster_id, stop_on_error=False):
        """
        Execute given command in the master of the selected cluster

        :param command: string
        :param cluster_id: string
        :param stop_on_error: boolean
        :return: string
        """
        cluster_information = self.get_cluster_information(cluster_id)
        self.__logger.debug(
            'Executing command {command} in cluster {cluster_id}'.format(
                command=command, cluster_id=cluster_id))
        return self.__ssh_client.exec_command(
            command, cluster_information['public_dns'],
            cluster_information['key_name'], stop_on_error)

    def open_sftp(self, cluster_id):
        """
        Open an SFTP connection to the given cluster

        :param cluster_id: string
        :return: SFTP
        """
        cluster_information = self.get_cluster_information(cluster_id)
        self.__logger.debug(
            'Opening SFTP connection to cluster {cluster_id}'.format(
                cluster_id=cluster_id))
        return self.__ssh_client.open_sftp(cluster_information['public_dns'],
                                           cluster_information['key_name'])

    def get_pem_path(self, cluster_id):
        """
        Get the path to the private key file to connect to the cluster

        :param cluster_id: string
        :return: string
        """
        cluster_information = self.get_cluster_information(cluster_id)
        return self.__ssh_client.get_pem_path(cluster_information['key_name'])
def test_assisted_builder_accepts_callables():
    # A plain callable may serve as the build target, not only a class.
    doubler = lambda x: x * 2
    builder = Injector().get(AssistedBuilder(callable=doubler))
    assert builder.build(x=3) == 6
class X(object):
    # Demonstrates assisted-builder injection into a regular method: the
    # builder itself is handed back to the caller untouched.
    @inject(builder=AssistedBuilder(NeedsAssistance))
    def y(self, builder):
        return builder
class Postgresql(object):
    """Wrapper to connect to SQL Servers and get all the metastore information"""

    @inject(postgresql=AssistedBuilder(callable=psycopg2.connect), logger='logger')
    def __init__(self, postgresql, logger, db_host=None, db_user='******', db_name=None, db_schema=None, db_pwd=None, db_port=None):
        """
        Initialize the Postgresql driver to get all the tables information

        :param postgresql: Psycopg2
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Postgresql, self).__init__()
        self.__db_name = db_name
        # 'public' is the PostgreSQL default schema.
        self.__db_schema = db_schema if None != db_schema else 'public'
        self.__conn = postgresql.build(
            host=db_host,
            user=db_user,
            password=db_pwd,
            database=db_name,
            port=db_port if None != db_port else 5432)
        # Map PostgreSQL types onto Hive-compatible type names; anything not
        # listed passes through unchanged.
        self.__column_types = {
            'timestamp without time zone': 'timestamp',
            'timestamp with time zone': 'timestamp',
            'uuid': 'string',
            'character': 'string',
            'character varying': 'string',
            'integer': 'int',
            'smallint': 'int',
            'text': 'string',
            'real': 'double',
            'numeric': 'double',
            'json': 'string',
            'USER-DEFINED': 'string'
        }
        # Characters that cannot survive the iso-8859-1 re-encoding below.
        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xbf]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )
        self.__logger = logger

    def __join_tables_list(self, tables):
        # 'a','b','c' -- suitable for an SQL IN (...) clause.
        return ','.join('\'%s\'' % table for table in tables)

    def __get_valid_column_name(self, column_name):
        # Strip characters that are not legal in Hive column names.
        return re.sub("[ ,;{}()\n\t=]", "", column_name)

    def __get_table_list(self, table_list_query=False):
        # List table names for the configured catalog and schema.
        self.__logger.debug('Getting table list')
        query = 'SELECT table_name FROM information_schema.tables WHERE table_catalog = %(db_name)s and table_schema = %(db_schema)s {table_list_query}'.format(
            table_list_query=' AND ' + table_list_query if table_list_query else '')
        cursor = self.__conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        cursor.execute(query, {
            'db_name': self.__db_name,
            'db_schema': self.__db_schema
        })
        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))
        return map(lambda x: x[0], cursor.fetchall())

    def __get_tables_to_exclude(self, tables):
        # Every table in the schema that is NOT in the provided list.
        return self.__get_table_list('table_name NOT IN ({tables})'.format(
            tables=self.__join_tables_list(tables)))

    def __get_database_collation(self):
        # Lowercased datcollate of the configured database, e.g. 'en_us.utf8'.
        self.__logger.debug('Getting database collation')
        info_query = 'SELECT datcollate FROM pg_database WHERE datname = %(db_name)s'
        cursor = self.__conn.cursor()
        cursor.execute(info_query, {'db_name': self.__db_name})
        return cursor.fetchone()[0].lower()

    def __get_columns_for_tables(self, tables):
        # Column metadata for every table, keyed by table name.
        self.__logger.debug('Getting columns information')
        info_query = 'SELECT table_name, column_name, data_type, character_maximum_length, is_nullable, column_default FROM information_schema.columns WHERE table_name IN ({tables}) AND table_catalog=%(db_name)s AND table_schema=%(db_schema)s'.format(
            tables=self.__join_tables_list(tables))
        cursor = self.__conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        cursor.execute(info_query, {
            'db_name': self.__db_name,
            'db_schema': self.__db_schema
        })
        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['table_name']))
            if not row['table_name'] in tables_information:
                tables_information[row['table_name']] = {'columns': []}
            tables_information[row['table_name']]['columns'].append({
                'source_column_name': row['column_name'],
                'column_name': self.__get_valid_column_name(row['column_name']),
                'source_data_type': row['data_type'],
                'data_type': row['data_type']
                if row['data_type'] not in self.__column_types
                else self.__column_types[row['data_type']],
                'character_maximum_length': row['character_maximum_length'],
                'is_nullable': row['is_nullable'],
                'column_default': row['column_default'],
            })
        return tables_information

    def __get_count_for_tables(self, tables):
        # Best-effort row counts: a failing count must not abort the scan.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM {schema}.{table}'.format(
                    table=table, schema=self.__db_schema)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except Exception:
                # Narrowed from a silent bare except; the failure is now
                # logged for consistency with the other drivers.
                self.__logger.debug(
                    'The count query for table {table} has fail'.format(
                        table=table))
        return tables_information

    def __get_top_for_tables(self, tables, top=30):
        # Sample up to `top` rows per table; rows from non-UTF8 databases are
        # re-encoded to iso-8859-1 first.
        tables_information = {}
        # Bug fix: ('utf-8' or 'utf8') evaluated to just 'utf-8', so
        # collations spelled 'utf8' (e.g. en_US.utf8) were never detected.
        collation = self.__get_database_collation()
        utf8_collation = 'utf-8' in collation or 'utf8' in collation
        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    info_query = 'SELECT * FROM {schema}.{table} LIMIT {top}'.format(
                        top=top, table=table, schema=self.__db_schema)
                    cursor.execute(info_query)
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            if not utf8_collation:
                                try:
                                    if type(column) is unicode:
                                        column = unicodedata.normalize(
                                            'NFKD', column).encode(
                                            'iso-8859-1', 'replace')
                                    else:
                                        column = str(column).decode(
                                            'utf8', 'replace').encode(
                                            'iso-8859-1', 'replace')
                                    if self.__illegal_characters.search(column):
                                        column = re.sub(
                                            self.__illegal_characters, '?',
                                            column)
                                except Exception:
                                    column = 'Parse_error'
                            if column == 'None':
                                column = 'NULL'
                            table_row.append(column)
                        tables_information[table]['rows'].append(table_row)
                except psycopg2.ProgrammingError as exc:
                    # Bug fix: the message was previously read off the
                    # exception CLASS (psycopg2.ProgrammingError.message),
                    # which raises AttributeError; report the instance.
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(error=exc))
        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database

        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """
        tables_to_exclude = {}
        if table_list:
            tables = table_list.split(',')
            tables_to_exclude = self.__get_tables_to_exclude(tables)
        else:
            tables = self.__get_table_list(table_list_query)
        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)
        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])
        if tables_to_exclude:
            tables_info['excluded_tables'] = tables_to_exclude
        return tables_info
def test_assisted_injection_works(self):
    # Injectable deps come from the injector; the rest at build time.
    builder = self.injector.get(AssistedBuilder(self.C))
    instance = builder.build(noninjectable=5)
    assert type(instance.injectable) == self.A
    assert instance.noninjectable == 5
class Mysql(object):
    """Wrapper to connect to MySQL Servers and get all the metastore information"""

    @inject(mysql=AssistedBuilder(callable=pymysql.connect), logger='logger')
    def __init__(self, mysql, logger, db_host=None, db_user='******', db_name=None, db_schema=None, db_pwd=None, db_port=None):
        """
        Initialize the MySQL driver to get all the tables information

        :param mysql: Mysql
        :param logger: Logger
        :param db_host: string
        :param db_user: string
        :param db_name: string
        :param db_schema: string
        :param db_pwd: string
        :param db_port: int
        """
        super(Mysql, self).__init__()
        self.__db_name = db_name
        # NOTE(review): this default ('mysql') is stored but the queries below
        # scope by __db_name instead -- confirm whether db_schema is used.
        self.__db_schema = db_schema if None != db_schema else 'mysql'
        self.__conn = mysql.build(
            host=db_host,
            user=db_user,
            password=db_pwd,
            database=db_name,
            port=int(db_port) if None != db_port else 3306)
        # Map MySQL types onto Hive-compatible type names; anything not
        # listed passes through unchanged.
        self.__column_types = {
            'tinyint': 'tinyint',
            'boolean': 'boolean',
            'smallint': 'smallint',
            'mediumint': 'int',
            'int': 'int',
            'integer': 'int',
            'bigint': 'bigint',
            'decimal': 'double',
            'dec': 'double',
            'numeric': 'double',
            'fixed': 'double',
            'float': 'float',
            'double': 'double',
            'real': 'double',
            'double precision': 'double',
            'bit': 'boolean',
            'char': 'string',
            'varchar': 'string',
            'binary': 'binary',
            'char byte': 'binary',
            'varbinary': 'binary',
            'tinyblob': 'binary',
            'blob': 'binary',
            'mediumblob': 'binary',
            'tinytext': 'string',
            'text': 'string',
            'mediumtext': 'string',
            'longtext': 'string',
            'enum': 'string',
            'date': 'timestamp',
            'time': 'timestamp',
            'datetime': 'timestamp',
            'timestamp': 'timestamp'
        }
        # Characters that cannot survive the iso-8859-1 re-encoding below.
        self.__illegal_characters = re.compile(
            r'[\000-\010]|[\013-\014]|[\016-\037]|[\xa1]|[\xbf]|[\xc1]|[\xc9]|[\xcd]|[\xd1]|[\xbf]|[\xda]|[\xdc]|[\xe1]|[\xf1]|[\xfa]|[\xf3]'
        )
        self.__logger = logger

    def __join_tables_list(self, tables):
        # 'a','b','c' -- suitable for an SQL IN (...) clause.
        return ','.join('\'%s\'' % table for table in tables)

    def __get_valid_column_name(self, column_name):
        # Strip characters that are not legal in Hive column names.
        return re.sub("[ ,;{}()\n\t=]", "", column_name)

    def __get_table_list(self, table_list_query=False):
        # List table names; in MySQL the schema and the database coincide,
        # hence %(schema)s is bound to __db_name.
        self.__logger.debug('Getting table list')
        query = 'SELECT table_name FROM information_schema.tables WHERE table_schema = %(schema)s {table_list_query}'.format(
            table_list_query=' AND ' + table_list_query if table_list_query else '')
        cursor = self.__conn.cursor(pymysql.cursors.DictCursor)
        # The extra 'db_name' key is harmless to pyformat parameter binding.
        cursor.execute(query, {
            'db_name': self.__db_name,
            'schema': self.__db_name
        })
        self.__logger.debug(
            'Found {count} tables'.format(count=cursor.rowcount))
        return map(lambda x: x['table_name'], cursor.fetchall())

    def __get_tables_to_exclude(self, tables):
        # Every table in the schema that is NOT in the provided list.
        return self.__get_table_list('table_name NOT IN ({tables})'.format(
            tables=self.__join_tables_list(tables)))

    def __get_columns_for_tables(self, tables):
        # Column metadata for every table, keyed by table name.
        self.__logger.debug('Getting columns information')
        info_query = 'SELECT table_name, column_name, data_type, character_maximum_length, is_nullable, ' \
                     'column_default FROM information_schema.columns WHERE table_name IN ({tables}) AND table_schema=%(schema)s'.format(
                         tables=self.__join_tables_list(tables))
        cursor = self.__conn.cursor(pymysql.cursors.DictCursor)
        cursor.execute(info_query, {
            'db_name': self.__db_name,
            'schema': self.__db_name
        })
        tables_information = {}
        for row in cursor.fetchall():
            self.__logger.debug('Columns found for table {table}'.format(
                table=row['table_name']))
            if not row['table_name'] in tables_information:
                tables_information[row['table_name']] = {'columns': []}
            tables_information[row['table_name']]['columns'].append({
                'source_column_name': row['column_name'],
                'column_name': self.__get_valid_column_name(row['column_name']),
                'source_data_type': row['data_type'],
                'data_type': row['data_type']
                if row['data_type'] not in self.__column_types
                else self.__column_types[row['data_type']],
                'character_maximum_length': row['character_maximum_length'],
                'is_nullable': row['is_nullable'],
                'column_default': row['column_default'],
            })
        return tables_information

    def __get_count_for_tables(self, tables):
        # Best-effort row counts: a failing count must not abort the scan.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            try:
                self.__logger.debug(
                    'Getting count for table {table}'.format(table=table))
                info_query = 'SELECT COUNT(*) FROM {schema}.{table}'.format(
                    table=table, schema=self.__db_name)
                cursor.execute(info_query)
                tables_information[table] = {'count': cursor.fetchone()[0]}
            except Exception:
                # Narrowed from a bare except so ^C still interrupts.
                self.__logger.debug(
                    'The count query for table {table} has fail'.format(
                        table=table))
        return tables_information

    def __get_top_for_tables(self, tables, top=30):
        # Sample up to `top` rows per table, re-encoded to iso-8859-1.
        tables_information = {}
        cursor = self.__conn.cursor()
        for table in tables:
            tables_information[table] = {'rows': []}
            if top > 0:
                try:
                    self.__logger.debug(
                        'Getting {top} rows for table {table}'.format(
                            top=top, table=table))
                    cursor.execute(
                        'SELECT * FROM {schema}.{table} LIMIT {top}'.format(
                            top=top, table=table, schema=self.__db_name))
                    for row in cursor.fetchall():
                        table_row = []
                        for column in row:
                            try:
                                if type(column) is unicode:
                                    column = unicodedata.normalize(
                                        'NFKD', column).encode('iso-8859-1', 'replace')
                                else:
                                    column = str(column).decode(
                                        'utf8', 'replace').encode(
                                        'iso-8859-1', 'replace')
                                if self.__illegal_characters.search(column):
                                    column = re.sub(
                                        self.__illegal_characters, '?', column)
                                if column == 'None':
                                    column = 'NULL'
                            except Exception:
                                column = 'Parse_error'
                            table_row.append(column)
                        tables_information[table]['rows'].append(table_row)
                except pymysql.ProgrammingError as exc:
                    # Bug fix: the message was previously read off the
                    # exception CLASS (pymysql.ProgrammingError.message),
                    # which raises AttributeError; report the instance.
                    tables_information[table]['rows'].append(
                        'Error getting table data {error}'.format(error=exc))
        return tables_information

    def get_all_tables_info(self, table_list, table_list_query, top_max):
        """
        Return all the tables information reading from the Information Schema database

        :param table_list: string
        :param table_list_query: string
        :param top_max: integer
        :return: dict
        """
        tables_to_exclude = {}
        if table_list:
            tables = map(lambda x: unicode(x), table_list.split(','))
            tables_to_exclude = self.__get_tables_to_exclude(tables)
        else:
            tables = self.__get_table_list(table_list_query)
        tables_counts = self.__get_count_for_tables(tables)
        tables_columns = self.__get_columns_for_tables(tables)
        tables_top = self.__get_top_for_tables(tables, top_max)
        tables_info = {'tables': {}}
        for table in tables_counts:
            tables_info['tables'][table] = {}
            tables_info['tables'][table].update(tables_columns[table])
            tables_info['tables'][table].update(tables_counts[table])
            tables_info['tables'][table].update(tables_top[table])
        if tables_to_exclude:
            tables_info['excluded_tables'] = tables_to_exclude
        return tables_info