def execute(self, context):
    """Insert every batch produced by ``self._bq_get_data()`` into MySQL.

    One ``MySqlHook`` (built from ``self.database`` and
    ``self.mysql_conn_id``) is reused for all batches, and ``self.replace``
    is forwarded to ``insert_rows`` unchanged.
    """
    hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id)
    for batch in self._bq_get_data():
        hook.insert_rows(self.mysql_table, batch, replace=self.replace)
def test_mysql_hook_test_bulk_load(self, client):
    """Bulk-load a three-line temp file into ``test_airflow`` and read it back."""
    import tempfile

    with MySqlContext(client):
        expected_rows = ("foo", "bar", "baz")
        payload = "\n".join(expected_rows).encode('utf8')

        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write(payload)
            # Flush so the data is on disk before the hook reads it by name.
            tmp.flush()

            hook = MySqlHook('airflow_db')
            with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
                cursor.execute(
                    """ CREATE TABLE IF NOT EXISTS test_airflow ( dummy VARCHAR(50) ) """
                )
                cursor.execute("TRUNCATE TABLE test_airflow")
                hook.bulk_load("test_airflow", tmp.name)
                cursor.execute("SELECT dummy FROM test_airflow")
                loaded_rows = tuple(row[0] for row in cursor.fetchall())
                assert sorted(loaded_rows) == sorted(expected_rows)
def _call_preoperator(self):
    """Create the MySQL hook, run the optional pre-operator SQL, and return the hook."""
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    preoperator_sql = self.mysql_preoperator
    if preoperator_sql:
        self.log.info("Running MySQL preoperator")
        hook.run(preoperator_sql)

    self.log.info("Inserting rows into MySQL")
    return hook
def test_get_conn_from_connection_with_schema(self, mock_connect):
    """The ``schema`` given to the hook is the ``db`` passed to the driver's connect call."""
    # NOTE(review): the credential literals below look redacted in this copy
    # of the file; they are reproduced as-is.
    connection = Connection(login='******', password='******', host='host', schema='schema')

    MySqlHook(connection=connection, schema='schema-override').get_conn()

    expected_kwargs = dict(
        user='******',
        passwd='password-conn',
        host='host',
        db='schema-override',
        port=3306,
    )
    mock_connect.assert_called_once_with(**expected_kwargs)
def execute(self, context: Dict[str, str]):
    """Dump the result of ``self.sql`` to a temporary CSV file and load it into Hive.

    The MySQL cursor and connection are always closed once the dump is
    finished, even when the query or the CSV write fails, and before the
    Hive load starts.

    :param context: task context supplied by the caller (not read here).
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    with NamedTemporaryFile("wb") as f:
        conn = mysql.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(self.sql)
                # NOTE(review): the ``encoding`` keyword suggests ``csv`` is
                # unicodecsv rather than the stdlib module -- confirm against
                # the file's imports.
                csv_writer = csv.writer(
                    f,
                    delimiter=self.delimiter,
                    quoting=self.quoting,
                    quotechar=self.quotechar,
                    escapechar=self.escapechar,
                    encoding="utf-8",
                )
                # Map every result column name to the type produced by
                # ``self.type_map``, preserving column order.
                field_dict = OrderedDict()
                for field in cursor.description:
                    field_dict[field[0]] = self.type_map(field[1])
                csv_writer.writerows(cursor)
                # Make sure everything is on disk before Hive reads the file by name.
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()

        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties,
        )
def partitions(self):
    """
    Retrieve table partitions

    Reads the ``table`` request argument (expected as ``schema.table``),
    looks up the matching partitions through ``MySqlHook`` and returns them
    rendered as an HTML table.

    The schema and table names come from the HTTP request, so they are bound
    as query parameters instead of being formatted into the SQL text.
    """
    schema, table = request.args.get("table").split('.')
    sql = """
    SELECT
        a.PART_NAME,
        a.CREATE_TIME,
        c.LOCATION,
        c.IS_COMPRESSED,
        c.INPUT_FORMAT,
        c.OUTPUT_FORMAT
    FROM PARTITIONS a
    JOIN TBLS b ON a.TBL_ID = b.TBL_ID
    JOIN DBS d ON b.DB_ID = d.DB_ID
    JOIN SDS c ON a.SD_ID = c.SD_ID
    WHERE
        b.TBL_NAME like %s AND
        d.NAME like %s
    ORDER BY PART_NAME DESC
    """
    hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
    df = hook.get_pandas_df(sql, parameters=(table, schema))
    return df.to_html(
        classes="table table-striped table-bordered table-hover",
        index=False,
        na_rep='',
    )
def objects(self):
    """
    Retrieve objects from TBLS and DBS

    Returns a JSON list of ``{'id': ..., 'text': ...}`` entries, one per
    ``database.table`` name returned by the query.

    When both ``DB_ALLOW_LIST`` and ``DB_DENY_LIST`` are configured, both
    filters are applied. Previously the deny-list clause overwrote the
    allow-list clause, so the allow-list was silently ignored.
    """
    filters = []
    if DB_ALLOW_LIST:
        dbs = ",".join("'" + db + "'" for db in DB_ALLOW_LIST)
        filters.append("AND b.name IN ({})".format(dbs))
    if DB_DENY_LIST:
        dbs = ",".join("'" + db + "'" for db in DB_DENY_LIST)
        filters.append("AND b.name NOT IN ({})".format(dbs))
    where_clause = " ".join(filters)

    sql = """
    SELECT CONCAT(b.NAME, '.', a.TBL_NAME), TBL_TYPE
    FROM TBLS a
    JOIN DBS b ON a.DB_ID = b.DB_ID
    WHERE
        a.TBL_NAME NOT LIKE '%tmp%' AND
        a.TBL_NAME NOT LIKE '%temp%' AND
        b.NAME NOT LIKE '%tmp%' AND
        b.NAME NOT LIKE '%temp%'
    {where_clause}
    LIMIT {LIMIT};
    """.format(where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT)

    hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
    data = [{'id': row[0], 'text': row[0]} for row in hook.get_records(sql)]
    return json.dumps(data)
def execute(self, context: 'Context') -> None:
    """Log ``self.sql`` and run it through a ``MySqlHook``."""
    self.log.info('Executing: %s', self.sql)
    MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database).run(
        self.sql,
        autocommit=self.autocommit,
        parameters=self.parameters,
    )
def test_mysql_to_hive_verify_loaded_values(self):
    """Transfer one row of integer values from MySQL to Hive and compare it with the source."""
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'
    hook = MySqlHook()

    # One value per column c0..c9 of the table created below.
    minmax = (
        255,
        65535,
        16777215,
        4294967295,
        18446744073709551615,
        -128,
        -32768,
        -8388608,
        -2147483648,
        -9223372036854775808,
    )
    drop_sql = "DROP TABLE IF EXISTS {}".format(mysql_table)
    create_sql = (
        " CREATE TABLE {} ( c0 TINYINT UNSIGNED, c1 SMALLINT UNSIGNED, "
        "c2 MEDIUMINT UNSIGNED, c3 INT UNSIGNED, c4 BIGINT UNSIGNED, "
        "c5 TINYINT, c6 SMALLINT, c7 MEDIUMINT, c8 INT, c9 BIGINT ) "
    ).format(mysql_table)
    insert_sql = (
        " INSERT INTO {} VALUES ( {}, {}, {}, {}, {}, {}, {}, {}, {}, {} ) "
    ).format(mysql_table, *minmax)

    try:
        with hook.get_conn() as conn:
            for statement in (drop_sql, create_sql, insert_sql):
                conn.execute(statement)

        transfer = MySqlToHiveTransferOperator(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            dag=self.dag,
        )
        transfer.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        hive_rows = HiveServer2Hook().get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(hive_rows[0], minmax)
    finally:
        with hook.get_conn() as conn:
            conn.execute(drop_sql)
def test_mysql_hook_test_bulk_dump(self):
    """Dump ``INFORMATION_SCHEMA.TABLES`` into the server's ``secure_file_priv`` directory.

    The test is skipped when ``@@global.secure_file_priv`` is empty or unset.
    """
    hook = MySqlHook('airflow_db')
    priv = hook.get_first("SELECT @@global.secure_file_priv")
    if priv and priv[0]:
        # Confirm that no error occurs
        hook.bulk_dump("INFORMATION_SCHEMA.TABLES", os.path.join(priv[0], "TABLES"))
    else:
        # The message previously named the bulk_load test (copy-paste slip).
        self.skipTest("Skip test_mysql_hook_test_bulk_dump "
                      "since file output is not permitted")
def load_review_data(data):
    """Pass every row of ``data`` to ``replace_into`` over one SQLAlchemy connection.

    The connection is closed and the engine disposed even when a row fails,
    so a failing run no longer leaks them.

    :param data: object exposing ``itertuples(index=False, name=None)``
        (presumably a pandas DataFrame -- confirm against the caller).
    """
    mysql_hook = MySqlHook(mysql_conn_id='sky')
    engine = mysql_hook.get_sqlalchemy_engine(
        engine_kwargs={'connect_args': {'charset': 'utf8mb4'}})
    try:
        connection = engine.connect()
        try:
            for review_values in data.itertuples(index=False, name=None):
                replace_into(connection, review_values)
        finally:
            connection.close()
    finally:
        engine.dispose()
def setUp(self):
    """Prepare a ``MySqlHook`` whose ``get_connection`` returns a canned ``Connection``."""
    super().setUp()

    self.connection = Connection(
        login='******',
        password='******',
        host='host',
        schema='schema',
        extra='{"client": "mysql-connector-python"}',
    )

    hook = MySqlHook()
    # Replace the connection lookup with a mock that returns the fixture.
    hook.get_connection = mock.Mock(return_value=self.connection)
    self.db_hook = hook
def query(self):
    """Queries mysql and returns a cursor to the results."""
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    connection = hook.get_conn()
    cursor = connection.cursor()

    statements = []
    if self.ensure_utc:
        # Ensure TIMESTAMP results are in UTC
        statements.append("SET time_zone = '+00:00'")
    statements.append(self.sql)

    for statement in statements:
        self.log.info('Executing: %s', statement)
        cursor.execute(statement)
    return cursor
def csv_load_to_db(destination_folder, filename, insert_query_file, by_rows_batch=10000):
    """
    Parse the CSV file and insert the next batch of rows into the ``sales`` table.

    The number of rows already in ``sales`` decides where to resume: the CSV
    header plus that many data lines are skipped, then up to
    ``by_rows_batch`` lines are inserted and committed once at the end of
    the batch.

    Args:
        destination_folder: downloaded files directory, e.g. ``'data/'``
            (concatenated with ``filename`` as-is, so keep the trailing slash).
        filename: name of the CSV file, e.g. ``'filename.csv'``.
        insert_query_file: path to a ``.sql`` file; the SECOND
            ``;``-separated statement in it is used as the insert query.
        by_rows_batch: maximum number of CSV rows inserted per call.
    """
    with open(insert_query_file, 'r') as sql_file:
        insert_query = sql_file.read().split(';')[1]

    with open(destination_folder + filename, 'r') as csv_file:
        conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
        try:
            cur = conn.cursor()
            cur.execute('use sales_records_airflow')
            cur.execute('select count(*) from sales LIMIT 1')
            # add one because we want to exclude header when slicing csv
            row_count = cur.fetchone()[0] + 1

            # Compare by value: ``is`` on an int only works by accident of
            # CPython's small-int cache.
            print('empty' if row_count == 1 else 'not empty')

            # Both the empty and the non-empty case read the same window:
            # lines [row_count, row_count + by_rows_batch) of the file.
            # Once every row is loaded the slice is simply empty, so no
            # special case is needed for a fully loaded table.
            for row in islice(csv_file, row_count, row_count + by_rows_batch):
                val = row.rstrip().split(',')
                # Columns 5 and 7 hold m/d/Y dates; convert them to ``date``.
                val[5] = datetime.strptime(val[5], '%m/%d/%Y').date()
                val[7] = datetime.strptime(val[7], '%m/%d/%Y').date()
                cur.execute(query=insert_query, args=val)
            conn.commit()
        finally:
            conn.close()
def execute(self, context) -> None:
    """Run ``self.query`` on MySQL, write the result to a temp CSV and upload it to S3."""
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

    frame = mysql.get_pandas_df(self.query)
    self.log.info("Data from MySQL obtained")
    self._fix_int_dtypes(frame)

    with NamedTemporaryFile(mode='r+', suffix='.csv') as csv_file:
        csv_path = csv_file.name
        frame.to_csv(csv_path, **self.pd_csv_kwargs)
        s3.load_file(filename=csv_path, key=self.s3_key, bucket_name=self.s3_bucket)

    if not s3.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
        return
    self.log.info("File saved correctly in %s", os.path.join(self.s3_bucket, self.s3_key))
def mysql_to_pq(source_transform, name_of_dataset='project_four_airflow', by_row_batch=1000):
    '''
    Extract the next batch of rows from the MySQL ``sales`` table and save it
    as a local parquet file.

    The highest ``id`` already present in ``<name_of_dataset>.sales`` on
    BigQuery is used as the resume point, so only rows with a larger id are
    extracted. If the BigQuery lookup fails with reason ``notFound`` the
    dataset is created under the given name and extraction starts from id 1.
    Any other BigQuery error is re-raised.

    Args:
        source_transform: path of the parquet file to write,
            e.g. ``'path/local/file.pq'``.
        name_of_dataset: BigQuery dataset holding the ``sales`` table.
        by_row_batch: number of ids to extract per call (``int``).

    Returns:
        None. The parquet file at ``source_transform`` is the output.
    '''
    client = BigQueryHook(gcp_conn_id='google_cloud_default').get_client()
    row_id = client.query(
        'select id from {}.sales order by id desc limit 1'.format(name_of_dataset))

    # Default for an existing but empty table (the loop body never runs).
    last_row_id = 0
    try:
        for i in row_id:
            last_row_id = i[0]
            print(i[0])
    except GoogleAPIError:
        # Only a missing dataset/table is recoverable here.
        if (row_id.error_result or {}).get('reason') != 'notFound':
            raise
        last_row_id = 0
        print('no dataset.table')
        client.create_dataset(name_of_dataset)
        print('new dataset, {} created'.format(name_of_dataset))

    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    try:
        cur = conn.cursor()
        cur.execute('use sales_records_airflow')
        cur.execute(
            'select * from sales where id>=%s and id<=%s',
            (last_row_id + 1, last_row_id + by_row_batch))
        list_row = cur.fetchall()
    finally:
        conn.close()

    rows_of_extracted_mysql = [list(row) for row in list_row]
    print('extracting from mysql')
    df = pd.DataFrame(rows_of_extracted_mysql,
                      columns=[
                          'id', 'region', 'country', 'item_type',
                          'sales_channel', 'Order Priority', 'order_date',
                          'order_id', 'ship_date', 'units_sold', 'unit_price',
                          'unit_cost', 'total_revenue', 'total_cost',
                          'total_profit'
                      ])
    df.to_parquet(source_transform)
    print('task complete check,', source_transform)
def execute(self, context: 'Context'):
    """Transfer data from Vertica to MySQL, then run the optional MySQL post-operator."""
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    # Pick the transfer strategy once, then call it.
    transfer = self._bulk_load_transfer if self.bulk_load else self._non_bulk_load_transfer
    transfer(mysql, vertica)

    postoperator_sql = self.mysql_postoperator
    if postoperator_sql:
        self.log.info("Running MySQL postoperator...")
        mysql.run(postoperator_sql)

    self.log.info("Done")
def setUp(self):
    """Prepare a ``MySqlHook`` whose ``get_connection`` returns a canned MySQL ``Connection``."""
    super().setUp()

    connection_kwargs = dict(
        conn_type='mysql',
        login='******',
        password='******',
        host='host',
        schema='schema',
    )
    self.connection = Connection(**connection_kwargs)

    hook = MySqlHook()
    # Replace the connection lookup with a mock that returns the fixture.
    hook.get_connection = mock.Mock(return_value=self.connection)
    self.db_hook = hook
def test_mysql_to_hive_verify_csv_special_char(self):
    """Transfer a row containing quotes and brackets to Hive with CSV quoting disabled."""
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    from airflow.providers.mysql.hooks.mysql import MySqlHook

    hook = MySqlHook()
    drop_sql = "DROP TABLE IF EXISTS {}".format(mysql_table)

    try:
        db_record = ('c0', '["true"]')
        setup_statements = (
            drop_sql,
            """ CREATE TABLE {} ( c0 VARCHAR(25), c1 VARCHAR(25) ) """.format(mysql_table),
            """ INSERT INTO {} VALUES ( '{}', '{}' ) """.format(mysql_table, *db_record),
        )
        with hook.get_conn() as conn:
            for statement in setup_statements:
                conn.execute(statement)

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        import unicodecsv as csv

        transfer = MySqlToHiveTransfer(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            quoting=csv.QUOTE_NONE,
            quotechar='',
            escapechar='@',
            dag=self.dag,
        )
        transfer.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

        hive_rows = HiveServer2Hook().get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(hive_rows[0], db_record)
    finally:
        with hook.get_conn() as conn:
            conn.execute(drop_sql)
def tearDown(self):
    """Drop the tables listed below, each through its own cursor."""
    tables_to_drop = {'test_mysql_to_mysql', 'test_airflow'}
    with closing(MySqlHook().get_conn()) as connection:
        for table_name in tables_to_drop:
            # Previous version tried to run execute directly on dbapi call,
            # which was accidentally working
            with closing(connection.cursor()) as cursor:
                cursor.execute("DROP TABLE IF EXISTS {}".format(table_name))
def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn):
    """``bulk_dump`` issues exactly one ``SELECT ... INTO OUTFILE`` on the mocked cursor."""
    execute_mock = mock.MagicMock()
    mock_get_conn.return_value.cursor.return_value.execute = execute_mock

    hook = MySqlHook('airflow_db')
    table = "INFORMATION_SCHEMA.TABLES"
    tmp_file = "/path/to/output/file"
    hook.bulk_dump(table, tmp_file)

    from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces

    assert execute_mock.call_count == 1

    expected_query = """ SELECT * INTO OUTFILE '{tmp_file}' FROM {table} """.format(
        tmp_file=tmp_file, table=table
    )
    # First positional argument of the single ``execute`` call.
    executed_query = execute_mock.call_args[0][0]
    assert_equal_ignore_multiple_spaces(self, executed_query, expected_query)
def test_mysql_to_hive_type_conversion(self, mock_load_file):
    """Check the ``field_dict`` handed to the mocked Hive ``load_file`` for each MySQL column."""
    mysql_table = 'test_mysql_to_hive'
    hook = MySqlHook()
    drop_sql = "DROP TABLE IF EXISTS {}".format(mysql_table)
    create_sql = (
        " CREATE TABLE {} ( c0 TINYINT, c1 SMALLINT, c2 MEDIUMINT, "
        "c3 INT, c4 BIGINT, c5 TIMESTAMP ) "
    ).format(mysql_table)

    try:
        with hook.get_conn() as conn:
            conn.execute(drop_sql)
            conn.execute(create_sql)

        transfer = MySqlToHiveOperator(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table='test_mysql_to_hive',
            dag=self.dag,
        )
        transfer.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        assert mock_load_file.call_count == 1
        expected_fields = OrderedDict(
            [
                ("c0", "SMALLINT"),
                ("c1", "INT"),
                ("c2", "INT"),
                ("c3", "BIGINT"),
                ("c4", "DECIMAL(38,0)"),
                ("c5", "TIMESTAMP"),
            ]
        )
        actual_fields = mock_load_file.call_args[1]["field_dict"]
        self.assertEqual(actual_fields, expected_fields)
    finally:
        with hook.get_conn() as conn:
            conn.execute(drop_sql)
def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySqlHook]:
    """
    Retrieve database hook.

    A ``PostgresHook`` is built when ``self.database_type`` is ``'postgres'``
    and a ``MySqlHook`` otherwise. The hook is stored on ``self.db_hook`` and
    returned.
    """
    hook_class = PostgresHook if self.database_type == 'postgres' else MySqlHook
    self.db_hook = hook_class(connection=connection, schema=self.database)
    return self.db_hook
def test_mysql_hook_test_bulk_load(self):
    """Bulk-load a three-line temp file into ``test_airflow`` and read it back."""
    import tempfile

    expected_rows = ("foo", "bar", "baz")
    payload = "\n".join(expected_rows).encode('utf8')

    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(payload)
        # Flush so the data is on disk before the hook reads it by name.
        tmp.flush()

        hook = MySqlHook('airflow_db')
        with hook.get_conn() as conn:
            conn.execute(""" CREATE TABLE IF NOT EXISTS test_airflow ( dummy VARCHAR(50) ) """)
            conn.execute("TRUNCATE TABLE test_airflow")
            hook.bulk_load("test_airflow", tmp.name)
            conn.execute("SELECT dummy FROM test_airflow")
            loaded_rows = tuple(row[0] for row in conn.fetchall())
            self.assertEqual(sorted(loaded_rows), sorted(expected_rows))
def index(self):
    """Create default view"""
    sql = (
        " SELECT a.name as db, db_location_uri as location, "
        "count(1) as object_count, a.desc as description "
        "FROM DBS a JOIN TBLS b ON a.DB_ID = b.DB_ID "
        "GROUP BY a.name, db_location_uri, a.desc "
    )
    df = MySqlHook(METASTORE_MYSQL_CONN_ID).get_pandas_df(sql)

    # Turn every database name into a link to its detail view.
    db_names = df.db
    df.db = '<a href="/metastorebrowserview/db/?db=' + db_names + '">' + db_names + '</a>'

    html_table = df.to_html(
        classes="table table-striped table-bordered table-hover",
        index=False,
        escape=False,
        na_rep='',
    )
    return self.render_template("metastore_browser/dbs.html", table=Markup(html_table))
def execute(self, context) -> None:
    """Run ``self.query`` on MySQL, write the result as CSV or parquet and upload it to S3."""
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

    frame = mysql.get_pandas_df(self.query)
    self.log.info("Data from MySQL obtained")
    self._fix_int_dtypes(frame)

    options = FILE_OPTIONS_MAP[self.file_format]
    with NamedTemporaryFile(mode=options.mode, suffix=options.suffix) as tmp_file:
        # CSV gets ``to_csv``; every other format goes through ``to_parquet``.
        write = frame.to_csv if self.file_format == FILE_FORMAT.CSV else frame.to_parquet
        write(tmp_file.name, **self.pd_kwargs)
        s3.load_file(filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket)

    if not s3.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
        return
    self.log.info("File saved correctly in %s", os.path.join(self.s3_bucket, self.s3_key))
def init_db():
    """Run ``CREATE_QUERY`` then ``LOAD_QUERY`` through a default ``MySqlHook``.

    This stays best-effort: ``OperationalError`` and ``ProgrammingError`` do
    not propagate. They are now logged instead of being discarded silently,
    so a failed initialisation can be diagnosed.
    """
    import logging

    try:
        hook = MySqlHook()
        hook.run(CREATE_QUERY)
        hook.run(LOAD_QUERY)
    except (OperationalError, ProgrammingError):
        logging.getLogger(__name__).warning(
            "init_db: running CREATE_QUERY/LOAD_QUERY failed; continuing", exc_info=True
        )
def check_data(task_instance, create_table_query_file):
    """Inspect the ``sales`` table and return the id of the next step to run.

    The row count is pushed to XCom under ``mysql_total_rows``.

    Returns:
        ``'check_dataset'`` when the table holds exactly 50000 rows,
        ``'csv_file_exist'`` for any other integer count, and
        ``'csv_file_not_exist'`` when the count query raises
        ``OperationalError``; in that case the statements from
        ``create_table_query_file`` are executed first.

    Args:
        task_instance: object exposing ``xcom_push(key=..., value=...)``.
        create_table_query_file: path of the ``.sql`` file run when the
            count query fails.
    """
    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.execute('use sales_records_airflow')
            cur.execute('select count(*) from sales')
            total_rows = cur.fetchone()[0]
            task_instance.xcom_push(key='mysql_total_rows', value=total_rows)
            # Test the "fully loaded" case first: the generic int check used
            # to come first and made this branch unreachable.
            if total_rows == 50000:
                print('up to date')
                return 'check_dataset'
            if isinstance(total_rows, int):
                print('appending new data')
                return 'csv_file_exist'
        except cur.OperationalError:
            print('sql_file execute')
            with open(create_table_query_file, 'r') as sql_file:
                sql_query = sql_file.read()
            for query in sql_query.split(';', maxsplit=2):
                cur.execute(query)
            conn.commit()
            return 'csv_file_not_exist'
    finally:
        conn.close()
def execute(self, context: 'Context') -> None:
    """Copy batches returned by ``bigquery_get_data`` into ``self.mysql_table``."""
    big_query_hook = BigQueryHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        location=self.location,
        impersonation_chain=self.impersonation_chain,
    )
    mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id)

    batches = bigquery_get_data(
        self.log,
        self.dataset_id,
        self.table_id,
        big_query_hook,
        self.batch_size,
        self.selected_fields,
    )
    for batch in batches:
        mysql_hook.insert_rows(
            table=self.mysql_table,
            rows=batch,
            target_fields=self.selected_fields,
            replace=self.replace,
        )
def execute(self, context):
    """Extract the result of ``self.sql`` from Hive and insert it into ``self.mysql_table``.

    With ``self.bulk_load`` the rows are written to a temporary
    tab-separated file and loaded via ``MySqlHook.bulk_load``; otherwise
    they are fetched and passed to ``insert_rows``. Optional MySQL pre- and
    post-operator statements run before and after the load.

    The temporary file is closed even when the extraction, the preoperator
    or the load fails.

    :param context: task context; converted to Hive configuration through
        ``context_to_airflow_vars`` and extended with ``self.hive_conf``.
    """
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

    self.log.info("Extracting data from Hive: %s", self.sql)
    hive_conf = context_to_airflow_vars(context)
    if self.hive_conf:
        hive_conf.update(self.hive_conf)

    tmp_file = None
    try:
        if self.bulk_load:
            tmp_file = NamedTemporaryFile()
            hive.to_csv(
                self.sql,
                tmp_file.name,
                delimiter='\t',
                lineterminator='\n',
                output_header=False,
                hive_conf=hive_conf,
            )
        else:
            hive_results = hive.get_records(self.sql, hive_conf=hive_conf)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator")
            mysql.run(self.mysql_preoperator)

        self.log.info("Inserting rows into MySQL")
        if self.bulk_load:
            mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
        else:
            mysql.insert_rows(table=self.mysql_table, rows=hive_results)
    finally:
        # As before, the temp file is gone before the postoperator runs;
        # the difference is that it is now also released on failure.
        if tmp_file is not None:
            tmp_file.close()

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator")
        mysql.run(self.mysql_postoperator)

    self.log.info("Done.")