def query(self):
    """Queries mysql and returns a cursor to the results."""
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    cursor = hook.get_conn().cursor()
    if self.ensure_utc:
        # Pin the session time zone so TIMESTAMP columns come back in UTC
        # regardless of the server's default.
        set_tz_sql = "SET time_zone = '+00:00'"
        self.log.info('Executing: %s', set_tz_sql)
        cursor.execute(set_tz_sql)
    self.log.info('Executing: %s', self.sql)
    cursor.execute(self.sql)
    return cursor
def execute(self, context):
    """Copy the result of ``self.sql`` from Presto into a MySQL table,
    running the optional preoperator SQL first."""
    presto_hook = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info("Extracting data from Presto: %s", self.sql)
    rows = presto_hook.get_records(self.sql)

    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator")
        self.log.info(self.mysql_preoperator)
        mysql_hook.run(self.mysql_preoperator)

    self.log.info("Inserting rows into MySQL")
    mysql_hook.insert_rows(table=self.mysql_table, rows=rows)
def execute(self, context: Dict) -> None:
    """Copy the result of ``self.sql`` from Trino into a MySQL table,
    running the optional preoperator SQL first."""
    trino_hook = TrinoHook(trino_conn_id=self.trino_conn_id)
    self.log.info("Extracting data from Trino: %s", self.sql)
    records = trino_hook.get_records(self.sql)

    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator")
        self.log.info(self.mysql_preoperator)
        mysql_hook.run(self.mysql_preoperator)

    self.log.info("Inserting rows into MySQL")
    mysql_hook.insert_rows(table=self.mysql_table, rows=records)
def load_review_data(data):
    """Upsert review rows from *data* into the 'sky' MySQL database.

    Each tuple of the DataFrame is passed to ``replace_into``. The
    connection and engine are always released, even when an insert raises
    (the original leaked both on error).

    :param data: DataFrame whose rows are iterated via ``itertuples``.
    """
    mysql_hook = MySqlHook(mysql_conn_id='sky')
    # utf8mb4 so 4-byte characters (e.g. emoji) in review text survive.
    engine = mysql_hook.get_sqlalchemy_engine(
        engine_kwargs={'connect_args': {'charset': 'utf8mb4'}})
    try:
        connection = engine.connect()
        try:
            for review_values in data.itertuples(index=False, name=None):
                replace_into(connection, review_values)
        finally:
            connection.close()
    finally:
        engine.dispose()
def get_database_hook(
        self, connection: Connection) -> Union[PostgresHook, MySqlHook]:
    """
    Retrieve database hook. This is the actual Postgres or MySQL database hook
    that uses proxy or connects directly to the Google Cloud SQL database.
    """
    if self.database_type == 'postgres':
        hook: Union[PostgresHook, MySqlHook] = PostgresHook(
            connection=connection, schema=self.database)
    else:
        hook = MySqlHook(connection=connection, schema=self.database)
    self.db_hook = hook
    return hook
def execute(self, context: 'Context'):
    """Transfer data from Vertica to MySQL (bulk or row-wise), then run the
    optional MySQL postoperator."""
    vertica_hook = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    # Select the transfer strategy once, then invoke it.
    transfer = self._bulk_load_transfer if self.bulk_load else self._non_bulk_load_transfer
    transfer(mysql_hook, vertica_hook)

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator...")
        mysql_hook.run(self.mysql_postoperator)

    self.log.info("Done")
def execute(self, context) -> None:
    """Dump the MySQL query result to a temporary CSV file and upload it
    to the configured S3 bucket/key."""
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

    frame = mysql_hook.get_pandas_df(self.query)
    self.log.info("Data from MySQL obtained")
    self._fix_int_dtypes(frame)

    # The temporary CSV exists only for the duration of the upload.
    with NamedTemporaryFile(mode='r+', suffix='.csv') as tmp_csv:
        frame.to_csv(tmp_csv.name, **self.pd_csv_kwargs)
        s3_hook.load_file(filename=tmp_csv.name, key=self.s3_key, bucket_name=self.s3_bucket)

    if s3_hook.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
        file_location = os.path.join(self.s3_bucket, self.s3_key)
        self.log.info("File saved correctly in %s", file_location)
def setUp(self):
    """Build a MySqlHook whose connection lookup returns a fixed
    mysql-connector-python test connection."""
    super().setUp()
    self.connection = Connection(
        login='******',
        password='******',
        host='host',
        schema='schema',
        extra='{"client": "mysql-connector-python"}',
    )
    hook = MySqlHook()
    # Stub out connection resolution so no metadata DB is needed.
    hook.get_connection = mock.Mock(return_value=self.connection)
    self.db_hook = hook
def mysql_to_pq(source_transform, name_of_dataset='project_four_airflow', by_row_batch=1000):
    '''
    Extract the MySQL database and save it locally as parquet
    ``tmp/sales-date.pq``.

    Looks up the highest ``id`` already loaded into the BigQuery table
    ``project_four_airflow.sales`` and extracts only the MySQL rows after
    it (up to ``by_row_batch`` rows) to avoid duplication. If the dataset
    does not exist it is created under ``name_of_dataset`` and extraction
    starts from the first row.

    Args:
        1. source_transform = 'path/local/file.pq'
        2. by_row_batch = number of rows to extract ``int``
    return: ``str`` of local pq file path (printed)
    '''
    client = BigQueryHook(gcp_conn_id='google_cloud_default').get_client()
    row_id = client.query(
        'select id from project_four_airflow.sales order by id desc limit 1')
    # Default to 0 so an empty result set cannot leave last_row_id unbound
    # (the original raised NameError when the query returned no rows).
    last_row_id = 0
    try:
        for row in row_id:
            last_row_id = row[0]
            print(row[0])
    except GoogleAPIError:
        # Dataset/table missing: create it and start from scratch. (The
        # original also contained a no-effect comparison on
        # ``row_id.error_result['reason']`` which has been dropped.)
        last_row_id = 0
        print('no dataset.table')
        client.create_dataset(name_of_dataset)
        print('new dataset, {} created'.format(name_of_dataset))

    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.execute('use sales_records_airflow')
            # Parameterized bounds instead of string interpolation.
            cur.execute('select * from sales where id>=%s and id<=%s',
                        (last_row_id + 1, last_row_id + by_row_batch))
            rows_of_extracted_mysql = [list(row) for row in cur.fetchall()]
        finally:
            cur.close()
    finally:
        # Always release the connection (the original leaked it).
        conn.close()

    print('extracting from mysql')
    df = pd.DataFrame(rows_of_extracted_mysql, columns=[
        'id', 'region', 'country', 'item_type', 'sales_channel',
        'Order Priority', 'order_date', 'order_id', 'ship_date', 'units_sold',
        'unit_price', 'unit_cost', 'total_revenue', 'total_cost',
        'total_profit'
    ])
    df.to_parquet(source_transform)
    print('task complete check,', source_transform)
def setUp(self):
    """Build a MySqlHook whose connection lookup returns a fixed mysql
    test connection."""
    super().setUp()
    self.connection = Connection(
        conn_type='mysql',
        login='******',
        password='******',
        host='host',
        schema='schema',
    )
    hook = MySqlHook()
    # Stub out connection resolution so no metadata DB is needed.
    hook.get_connection = mock.Mock(return_value=self.connection)
    self.db_hook = hook
def test_mysql_hook_test_bulk_dump(self, client):
    """bulk_dump should succeed wherever ``secure_file_priv`` permits
    writing an OUTFILE; otherwise the test is skipped."""
    with MySqlContext(client):
        hook = MySqlHook('airflow_db')
        priv = hook.get_first("SELECT @@global.secure_file_priv")
        # Use random names to allow re-running
        if priv and priv[0]:
            # Server restricts OUTFILE to a specific directory.
            target = os.path.join(priv[0], "TABLES_{}-{}".format(client, uuid.uuid1()))
            hook.bulk_dump("INFORMATION_SCHEMA.TABLES", target)
        elif priv == ("",):
            # Empty setting: any path is allowed.
            hook.bulk_dump("INFORMATION_SCHEMA.TABLES",
                           "TABLES_{}_{}".format(client, uuid.uuid1()))
        else:
            self.skipTest("Skip test_mysql_hook_test_bulk_load "
                          "since file output is not permitted")
def test_mysql_to_hive_verify_csv_special_char(self):
    """Verify a value containing CSV special characters (quotes, brackets)
    survives the MySQL -> Hive transfer when quoting is disabled and '@'
    is used as the escape character."""
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    from airflow.providers.mysql.hooks.mysql import MySqlHook

    hook = MySqlHook()

    try:
        # Second column deliberately contains quotes and brackets.
        db_record = (
            'c0',
            '["true"]'
        )
        # NOTE(review): relies on the MySQL client's Connection.__enter__
        # yielding an object with execute() — driver-specific; confirm client.
        with hook.get_conn() as conn:
            conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            conn.execute("""
                CREATE TABLE {} (
                    c0 VARCHAR(25),
                    c1 VARCHAR(25)
                )
            """.format(mysql_table))
            conn.execute("""
                INSERT INTO {} VALUES (
                    '{}', '{}'
                )
            """.format(mysql_table, *db_record))

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        import unicodecsv as csv

        op = MySqlToHiveTransfer(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            quoting=csv.QUOTE_NONE,  # no quoting — rely on escapechar
            quotechar='',
            escapechar='@',
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

        hive_hook = HiveServer2Hook()
        result = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
        # The Hive row must equal the record inserted into MySQL.
        self.assertEqual(result[0], db_record)
    finally:
        # Always drop the scratch MySQL table, pass or fail.
        with hook.get_conn() as conn:
            conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn):
    """bulk_dump must issue exactly one SELECT ... INTO OUTFILE statement
    against the cursor."""
    mock_execute = mock.MagicMock()
    mock_get_conn.return_value.cursor.return_value.execute = mock_execute

    hook = MySqlHook('airflow_db')
    table = "INFORMATION_SCHEMA.TABLES"
    tmp_file = "/path/to/output/file"
    hook.bulk_dump(table, tmp_file)

    from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces

    assert mock_execute.call_count == 1
    # The comparison ignores whitespace runs, so only the tokens matter.
    expected = """
        SELECT * INTO OUTFILE '{tmp_file}'
        FROM {table}
    """.format(tmp_file=tmp_file, table=table)
    assert_equal_ignore_multiple_spaces(self, mock_execute.call_args[0][0], expected)
def test_mysql_to_hive_type_conversion(self, mock_load_file):
    """MySQL integer/timestamp column types must map to the expected Hive
    types in the field_dict passed to load_file."""
    mysql_table = 'test_mysql_to_hive'

    hook = MySqlHook()

    try:
        # NOTE(review): relies on the MySQL client's Connection.__enter__
        # yielding an object with execute() — driver-specific; confirm client.
        with hook.get_conn() as conn:
            conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            conn.execute(
                """
                CREATE TABLE {} (
                    c0 TINYINT,
                    c1 SMALLINT,
                    c2 MEDIUMINT,
                    c3 INT,
                    c4 BIGINT,
                    c5 TIMESTAMP
                )
                """.format(
                    mysql_table
                )
            )

        op = MySqlToHiveOperator(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table='test_mysql_to_hive',
            dag=self.dag,
        )
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        assert mock_load_file.call_count == 1

        # Expected MySQL -> Hive type widening (each type maps to the next
        # larger Hive integer type; BIGINT widens to DECIMAL(38,0)).
        ordered_dict = OrderedDict()
        ordered_dict["c0"] = "SMALLINT"
        ordered_dict["c1"] = "INT"
        ordered_dict["c2"] = "INT"
        ordered_dict["c3"] = "BIGINT"
        ordered_dict["c4"] = "DECIMAL(38,0)"
        ordered_dict["c5"] = "TIMESTAMP"
        self.assertEqual(mock_load_file.call_args[1]["field_dict"], ordered_dict)
    finally:
        # Always drop the scratch MySQL table, pass or fail.
        with hook.get_conn() as conn:
            conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def test_mysql_hook_test_bulk_load(self):
    """bulk_load should insert one row per line of the source file.

    Uses an explicit cursor instead of ``with hook.get_conn() as conn``:
    the original relied on MySQLdb's deprecated, client-specific behavior
    of ``Connection.__enter__`` returning a cursor (it then called
    ``conn.execute``/``conn.fetchall``, which a DB-API connection does not
    have). This matches the explicit-cursor form used elsewhere in these
    tests.
    """
    records = ("foo", "bar", "baz")

    import tempfile

    with tempfile.NamedTemporaryFile() as f:
        f.write("\n".join(records).encode('utf8'))
        f.flush()

        hook = MySqlHook('airflow_db')
        conn = hook.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS test_airflow (
                        dummy VARCHAR(50)
                    )
                """)
                cursor.execute("TRUNCATE TABLE test_airflow")
                hook.bulk_load("test_airflow", f.name)
                cursor.execute("SELECT dummy FROM test_airflow")
                results = tuple(result[0] for result in cursor.fetchall())
                self.assertEqual(sorted(results), sorted(records))
            finally:
                cursor.close()
        finally:
            conn.close()
def execute(self, context) -> None:
    """Export the MySQL query result to S3 in the configured file format
    (CSV or parquet)."""
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

    frame = mysql_hook.get_pandas_df(self.query)
    self.log.info("Data from MySQL obtained")
    self._fix_int_dtypes(frame)

    file_options = FILE_OPTIONS_MAP[self.file_format]
    with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
        # Pick the pandas writer matching the requested format.
        writer = frame.to_csv if self.file_format == FILE_FORMAT.CSV else frame.to_parquet
        writer(tmp_file.name, **self.pd_kwargs)
        s3_hook.load_file(filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket)

    if s3_hook.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
        file_location = os.path.join(self.s3_bucket, self.s3_key)
        self.log.info("File saved correctly in %s", file_location)
def index(self):
    """Create default view"""
    sql = """
    SELECT
        a.name as db, db_location_uri as location,
        count(1) as object_count, a.desc as description
    FROM DBS a
    JOIN TBLS b ON a.DB_ID = b.DB_ID
    GROUP BY a.name, db_location_uri, a.desc
    """
    hook = MySqlHook(METASTORE_MYSQL_CONN_ID)
    frame = hook.get_pandas_df(sql)
    # Turn each database name into a link to its browser page.
    frame.db = '<a href="/metastorebrowserview/db/?db=' + frame.db + '">' + frame.db + '</a>'
    table_html = frame.to_html(
        classes="table table-striped table-bordered table-hover",
        index=False,
        escape=False,
        na_rep='',
    )
    return self.render_template("metastore_browser/dbs.html", table=Markup(table_html))
def check_data(task_instance, create_table_query_file):
    """Branch on the current state of the MySQL ``sales`` table.

    Pushes the current row count to XCom, then returns the id of the next
    task to run:

    * ``'check_dataset'``      - table already holds all 50000 rows
    * ``'csv_file_exist'``     - table exists but is incomplete
    * ``'csv_file_not_exist'`` - table missing; created from the SQL file

    Fixes two defects in the original: the ``type(total_rows) is int``
    check was always true (COUNT(*) always returns an int), making the
    "up to date" branch unreachable — so the 50000-row check now runs
    first; and exceptions were caught via ``cur.OperationalError`` even
    though DB-API exposes exception classes on the connection, not the
    cursor.
    """
    conn = MySqlHook(mysql_conn_id='mysql_localhost').get_conn()
    cur = conn.cursor()
    try:
        cur.execute('use sales_records_airflow')
        cur.execute('select count(*) from sales')
        total_rows = cur.fetchone()[0]
        task_instance.xcom_push(key='mysql_total_rows', value=total_rows)
        if total_rows == 50000:
            print('up to date')
            return 'check_dataset'
        print('appending new data')
        return 'csv_file_exist'
    except conn.OperationalError:
        # Database/table missing: create it from the SQL file.
        print('sql_file execute')
        with open(create_table_query_file, 'r') as sql_file:
            sql_query = sql_file.read()
        for query in sql_query.split(';', maxsplit=2):
            cur.execute('{}'.format(query))
        conn.commit()
        return 'csv_file_not_exist'
def execute(self, context: 'Context') -> None:
    """Stream BigQuery rows into a MySQL table in batches."""
    bq_hook = BigQueryHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        location=self.location,
        impersonation_chain=self.impersonation_chain,
    )
    mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id)

    batches = bigquery_get_data(
        self.log,
        self.dataset_id,
        self.table_id,
        bq_hook,
        self.batch_size,
        self.selected_fields,
    )
    # Insert each batch as it arrives instead of materializing everything.
    for batch in batches:
        mysql_hook.insert_rows(
            table=self.mysql_table,
            rows=batch,
            target_fields=self.selected_fields,
            replace=self.replace,
        )
def execute(self, context):
    """Move Hive query results into MySQL, either by bulk-loading a
    tab-delimited spool file or by row-wise inserts."""
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
    self.log.info("Extracting data from Hive: %s", self.sql)

    hive_conf = context_to_airflow_vars(context)
    if self.hive_conf:
        hive_conf.update(self.hive_conf)

    tmp_file = None
    hive_results = None
    if self.bulk_load:
        # Spool the result set to a local TSV for LOAD DATA.
        tmp_file = NamedTemporaryFile()
        hive.to_csv(
            self.sql,
            tmp_file.name,
            delimiter='\t',
            lineterminator='\n',
            output_header=False,
            hive_conf=hive_conf,
        )
    else:
        hive_results = hive.get_records(self.sql, hive_conf=hive_conf)

    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator")
        mysql.run(self.mysql_preoperator)

    self.log.info("Inserting rows into MySQL")
    if self.bulk_load:
        mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
        tmp_file.close()
    else:
        mysql.insert_rows(table=self.mysql_table, rows=hive_results)

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator")
        mysql.run(self.mysql_postoperator)

    self.log.info("Done.")
def execute(self, context: dict) -> None:
    """
    Executes the transfer operation from S3 to MySQL.

    :param context: The context that is being provided when executing.
    :type context: dict
    """
    self.log.info('Loading %s to MySql table %s...', self.s3_source_key, self.mysql_table)

    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
    local_file = s3_hook.download_file(key=self.s3_source_key)
    try:
        mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        mysql_hook.bulk_load_custom(
            table=self.mysql_table,
            tmp_file=local_file,
            duplicate_key_handling=self.mysql_duplicate_key_handling,
            extra_options=self.mysql_extra_options,
        )
    finally:
        # Remove file downloaded from s3 to be idempotent.
        os.remove(local_file)
def test_mysql_hook_test_bulk_load(self, client):
    """bulk_load should insert one row per line of the source file."""
    with MySqlContext(client):
        records = ("foo", "bar", "baz")

        import tempfile

        with tempfile.NamedTemporaryFile() as f:
            f.write("\n".join(records).encode('utf8'))
            f.flush()

            hook = MySqlHook('airflow_db')
            with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS test_airflow (
                        dummy VARCHAR(50)
                    )
                """)
                cursor.execute("TRUNCATE TABLE test_airflow")
                hook.bulk_load("test_airflow", f.name)
                cursor.execute("SELECT dummy FROM test_airflow")
                loaded = tuple(row[0] for row in cursor.fetchall())
                assert sorted(loaded) == sorted(records)
def execute(self, context):
    """Establish connections to both MySQL & PostgreSQL databases, open
    cursor and begin processing query, loading chunks of rows into
    PostgreSQL. Repeat loading chunks until all rows processed for query.
    """
    source_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    target_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)

    with closing(source_hook.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql, self.params)
            target_fields = [col[0] for col in cursor.description]

            row_count = 0
            chunk = cursor.fetchmany(self.rows_chunk)
            while chunk:
                row_count += len(chunk)
                target_hook.insert_rows(
                    self.postgres_table,
                    chunk,
                    target_fields=target_fields,
                    commit_every=self.rows_chunk,
                )
                chunk = cursor.fetchmany(self.rows_chunk)

    self.log.info(
        f"{row_count} row(s) inserted into {self.postgres_table}.")
def execute(self, context: Optional[Dict[str, Any]] = None) -> None:
    """Collect column statistics for one Hive table partition.

    Builds one aggregate expression per column (via ``assignment_func``
    when provided, otherwise defaults), runs them in a single Presto
    query, deletes any previous results for the same (table, partition,
    dttm), then stores the pivoted metrics as rows in the Airflow
    ``hive_stats`` MySQL table.
    """
    metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
    table = metastore.get_table(table_name=self.table)
    field_types = {col.name: col.type for col in table.sd.cols}

    # Keyed by (column, metric); the ('', 'count') entry is the row count.
    exprs: Any = {('', 'count'): 'COUNT(*)'}
    for col, col_type in list(field_types.items()):
        if self.assignment_func:
            assign_exprs = self.assignment_func(col, col_type)
            if assign_exprs is None:
                # assignment_func may decline a column; fall back to defaults.
                assign_exprs = self.get_default_exprs(col, col_type)
        else:
            assign_exprs = self.get_default_exprs(col, col_type)
        exprs.update(assign_exprs)
    exprs.update(self.extra_exprs)
    exprs = OrderedDict(exprs)
    # Alias each aggregate as col__metric so results map back to the keys.
    exprs_str = ",\n ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

    where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
    where_clause = " AND\n ".join(where_clause_)
    sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
        exprs_str=exprs_str, table=self.table, where_clause=where_clause
    )

    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info('Executing SQL check: %s', sql)
    row = presto.get_first(hql=sql)
    self.log.info("Record: %s", row)
    if not row:
        raise AirflowException("The query returned None")

    # Canonical JSON form of the partition spec, used as the stored key.
    part_json = json.dumps(self.partition, sort_keys=True)

    self.log.info("Deleting rows from previous runs if they exist")
    mysql = MySqlHook(self.mysql_conn_id)
    # NOTE(review): values are interpolated directly into the SQL; this
    # assumes table/partition/dttm are operator-configured and trusted.
    sql = """
    SELECT 1 FROM hive_stats
    WHERE
        table_name='{table}' AND
        partition_repr='{part_json}' AND
        dttm='{dttm}'
    LIMIT 1;
    """.format(
        table=self.table, part_json=part_json, dttm=self.dttm
    )
    if mysql.get_records(sql):
        sql = """
        DELETE FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}';
        """.format(
            table=self.table, part_json=part_json, dttm=self.dttm
        )
        mysql.run(sql)

    self.log.info("Pivoting and loading cells into the Airflow db")
    # zip pairs each (col, metric) key with its value from the result row:
    # r[0] is the key tuple, r[1] the computed metric value.
    rows = [
        (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
        for r in zip(exprs, row)
    ]
    mysql.insert_rows(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ],
    )
def tearDown(self):
    """Drop the scratch tables created by the tests.

    Uses an explicit cursor instead of ``with MySqlHook().get_conn() as
    conn: conn.execute(...)``: the original relied on MySQLdb's
    deprecated, client-specific behavior of ``Connection.__enter__``
    returning a cursor. The connection is always released, even if a
    DROP fails.
    """
    drop_tables = {'test_mysql_to_mysql', 'test_airflow'}
    conn = MySqlHook().get_conn()
    try:
        cursor = conn.cursor()
        try:
            for table in drop_tables:
                cursor.execute(f"DROP TABLE IF EXISTS {table}")
        finally:
            cursor.close()
    finally:
        conn.close()
def tearDown(self):
    """Remove tables left behind by the tests."""
    scratch_tables = {'test_mysql_to_mysql', 'test_airflow'}
    with closing(MySqlHook().get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            for table_name in scratch_tables:
                cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
def setUp(self):
    """Create the test DAG and seed the MySQL ``baby_names`` fixture table."""
    args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
    dag = DAG(TEST_DAG_ID, default_args=args)
    self.dag = dag

    # 1880 US baby-name frequency sample used as transfer-test input.
    rows = [
        (1880, "John", 0.081541, "boy"),
        (1880, "William", 0.080511, "boy"),
        (1880, "James", 0.050057, "boy"),
        (1880, "Charles", 0.045167, "boy"),
        (1880, "George", 0.043292, "boy"),
        (1880, "Frank", 0.02738, "boy"),
        (1880, "Joseph", 0.022229, "boy"),
        (1880, "Thomas", 0.021401, "boy"),
        (1880, "Henry", 0.020641, "boy"),
        (1880, "Robert", 0.020404, "boy"),
        (1880, "Edward", 0.019965, "boy"),
        (1880, "Harry", 0.018175, "boy"),
        (1880, "Walter", 0.014822, "boy"),
        (1880, "Arthur", 0.013504, "boy"),
        (1880, "Fred", 0.013251, "boy"),
        (1880, "Albert", 0.012609, "boy"),
        (1880, "Samuel", 0.008648, "boy"),
        (1880, "David", 0.007339, "boy"),
        (1880, "Louis", 0.006993, "boy"),
        (1880, "Joe", 0.006174, "boy"),
        (1880, "Charlie", 0.006165, "boy"),
        (1880, "Clarence", 0.006165, "boy"),
        (1880, "Richard", 0.006148, "boy"),
        (1880, "Andrew", 0.005439, "boy"),
        (1880, "Daniel", 0.00543, "boy"),
        (1880, "Ernest", 0.005194, "boy"),
        (1880, "Will", 0.004966, "boy"),
        (1880, "Jesse", 0.004805, "boy"),
        (1880, "Oscar", 0.004594, "boy"),
        (1880, "Lewis", 0.004366, "boy"),
        (1880, "Peter", 0.004189, "boy"),
        (1880, "Benjamin", 0.004138, "boy"),
        (1880, "Frederick", 0.004079, "boy"),
        (1880, "Willie", 0.00402, "boy"),
        (1880, "Alfred", 0.003961, "boy"),
        (1880, "Sam", 0.00386, "boy"),
        (1880, "Roy", 0.003716, "boy"),
        (1880, "Herbert", 0.003581, "boy"),
        (1880, "Jacob", 0.003412, "boy"),
        (1880, "Tom", 0.00337, "boy"),
        (1880, "Elmer", 0.00315, "boy"),
        (1880, "Carl", 0.003142, "boy"),
        (1880, "Lee", 0.003049, "boy"),
        (1880, "Howard", 0.003015, "boy"),
        (1880, "Martin", 0.003015, "boy"),
        (1880, "Michael", 0.00299, "boy"),
        (1880, "Bert", 0.002939, "boy"),
        (1880, "Herman", 0.002931, "boy"),
        (1880, "Jim", 0.002914, "boy"),
        (1880, "Francis", 0.002905, "boy"),
        (1880, "Harvey", 0.002905, "boy"),
        (1880, "Earl", 0.002829, "boy"),
        (1880, "Eugene", 0.00277, "boy"),
    ]

    # Airflow context environment variables the operators under test expect.
    self.env_vars = {
        'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
        'AIRFLOW_CTX_TASK_ID': 'test_task_id',
        'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
        'AIRFLOW_CTX_DAG_RUN_ID': '55',
        'AIRFLOW_CTX_DAG_OWNER': 'airflow',
        'AIRFLOW_CTX_DAG_EMAIL': '*****@*****.**',
    }

    # NOTE(review): relies on MySQLdb's Connection.__enter__ returning a
    # cursor ("as cur") — client-library-specific behavior; confirm driver.
    with MySqlHook().get_conn() as cur:
        cur.execute('''
        CREATE TABLE IF NOT EXISTS baby_names (
          org_year integer(4),
          baby_name VARCHAR(25),
          rate FLOAT(7,6),
          sex VARCHAR(4)
        )
        ''')
        for row in rows:
            cur.execute("INSERT INTO baby_names VALUES(%s, %s, %s, %s);", row)
def execute(self, context):
    """Transfer the result of ``self.sql`` from Vertica into MySQL.

    Two paths: bulk (spool rows to a tab-separated temp file, then
    LOAD DATA LOCAL INFILE it into MySQL) or row-wise ``insert_rows``.
    Optional pre/post SQL statements run around the transfer.
    """
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    tmpfile = None
    result = None
    selected_columns = []

    count = 0
    with closing(vertica.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql)
            selected_columns = [d.name for d in cursor.description]

            if self.bulk_load:
                # Stream rows to a local TSV so MySQL can bulk-load them;
                # the NamedTemporaryFile must outlive this with-block.
                tmpfile = NamedTemporaryFile("w")

                self.log.info(
                    "Selecting rows from Vertica to local file %s...", tmpfile.name)
                self.log.info(self.sql)

                csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
                for row in cursor.iterate():
                    csv_writer.writerow(row)
                    count += 1

                tmpfile.flush()
            else:
                self.log.info("Selecting rows from Vertica...")
                self.log.info(self.sql)

                result = cursor.fetchall()
                count = len(result)

            self.log.info("Selected rows from Vertica %s", count)

    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator...")
        mysql.run(self.mysql_preoperator)

    try:
        if self.bulk_load:
            self.log.info("Bulk inserting rows into MySQL...")
            with closing(mysql.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    # NOTE(review): file path, table name and column list are
                    # %-interpolated directly into the statement; assumes
                    # operator-supplied values are trusted.
                    cursor.execute(
                        "LOAD DATA LOCAL INFILE '%s' INTO "
                        "TABLE %s LINES TERMINATED BY '\r\n' (%s)"
                        % (tmpfile.name, self.mysql_table, ", ".join(selected_columns)))
                    conn.commit()
            tmpfile.close()
        else:
            self.log.info("Inserting rows into MySQL...")
            mysql.insert_rows(table=self.mysql_table, rows=result,
                              target_fields=selected_columns)
        self.log.info("Inserted rows into MySQL %s", count)
    except (MySQLdb.Error, MySQLdb.Warning):  # pylint: disable=no-member
        # Re-raise after logging; no rows are considered inserted.
        self.log.info("Inserted rows into MySQL 0")
        raise

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator...")
        mysql.run(self.mysql_postoperator)

    self.log.info("Done")
def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir):
    """Verify integer boundary values survive the MySQL -> Hive transfer
    and that the expected beeline command line is issued."""
    mock_subprocess = MockSubProcess()
    mock_popen.return_value = mock_subprocess
    mock_temp_dir.return_value = "test_mysql_to_hive"

    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    hook = MySqlHook()

    try:
        # Max of each unsigned MySQL integer type, then min of each signed one.
        minmax = (
            255,
            65535,
            16777215,
            4294967295,
            18446744073709551615,
            -128,
            -32768,
            -8388608,
            -2147483648,
            -9223372036854775808,
        )

        # NOTE(review): relies on the MySQL client's Connection.__enter__
        # yielding an object with execute() — driver-specific; confirm client.
        with hook.get_conn() as conn:
            conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
            conn.execute("""
                CREATE TABLE {} (
                    c0 TINYINT UNSIGNED,
                    c1 SMALLINT UNSIGNED,
                    c2 MEDIUMINT UNSIGNED,
                    c3 INT UNSIGNED,
                    c4 BIGINT UNSIGNED,
                    c5 TINYINT,
                    c6 SMALLINT,
                    c7 MEDIUMINT,
                    c8 INT,
                    c9 BIGINT
                )
            """.format(mysql_table))
            conn.execute("""
                INSERT INTO {} VALUES (
                    {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                )
            """.format(mysql_table, *minmax))

        # Run the operator with the Airflow context env vars in place.
        with mock.patch.dict('os.environ', self.env_vars):
            op = MySqlToHiveOperator(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql=f"SELECT * FROM {mysql_table}",
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                dag=self.dag,
            )
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            # Mocked HiveServer2 returns the same row; the round-trip must
            # preserve every boundary value exactly.
            mock_cursor = MockConnectionCursor()
            mock_cursor.iterable = [minmax]
            hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)
            result = hive_hook.get_records(f"SELECT * FROM {hive_table}")
            assert result[0] == minmax

            # The exact beeline invocation the operator is expected to make.
            hive_cmd = [
                'beeline',
                '-u',
                '"jdbc:hive2://localhost:10000/default"',
                '-hiveconf',
                'airflow.ctx.dag_id=unit_test_dag',
                '-hiveconf',
                'airflow.ctx.task_id=test_m2h',
                '-hiveconf',
                'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00',
                '-hiveconf',
                'airflow.ctx.dag_run_id=55',
                '-hiveconf',
                'airflow.ctx.dag_owner=airflow',
                '-hiveconf',
                '[email protected]',
                '-hiveconf',
                'mapreduce.job.queuename=airflow',
                '-hiveconf',
                'mapred.job.queue.name=airflow',
                '-hiveconf',
                'tez.queue.name=airflow',
                '-f',
                '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
            ]

            mock_popen.assert_called_with(
                hive_cmd,
                stdout=mock_subprocess.PIPE,
                stderr=mock_subprocess.STDOUT,
                cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
                close_fds=True,
            )
    finally:
        # Always drop the scratch MySQL table, pass or fail.
        with hook.get_conn() as conn:
            conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
def execute(self, context: Dict) -> None:
    """Run ``self.sql`` against the configured MySQL database."""
    self.log.info('Executing: %s', self.sql)
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
    mysql_hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)