def test_get_conn_from_connection_with_schema(self, mock_connect):
    """A ``schema`` passed to the hook must override the Connection's schema."""
    connection = Connection(login='******', password='******', host='host', schema='schema')
    hook = MySqlHook(connection=connection, schema='schema-override')

    hook.get_conn()

    # The driver must receive the overriding schema, not the one on the Connection.
    mock_connect.assert_called_once_with(
        user='******',
        passwd='password-conn',
        host='host',
        db='schema-override',
        port=3306,
    )
def test_mysql_to_hive_verify_loaded_values(self):
    """The min/max value of every MySQL integer type must round-trip through Hive intact."""
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'
    hook = MySqlHook()
    try:
        # Unsigned maxima first (c0..c4), then signed minima (c5..c9).
        minmax = (
            255,
            65535,
            16777215,
            4294967295,
            18446744073709551615,
            -128,
            -32768,
            -8388608,
            -2147483648,
            -9223372036854775808,
        )
        with hook.get_conn() as db_conn:
            db_conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            db_conn.execute(
                """
                CREATE TABLE {} (
                    c0 TINYINT UNSIGNED,
                    c1 SMALLINT UNSIGNED,
                    c2 MEDIUMINT UNSIGNED,
                    c3 INT UNSIGNED,
                    c4 BIGINT UNSIGNED,
                    c5 TINYINT,
                    c6 SMALLINT,
                    c7 MEDIUMINT,
                    c8 INT,
                    c9 BIGINT
                )
                """.format(mysql_table)
            )
            db_conn.execute(
                """
                INSERT INTO {} VALUES (
                    {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                )
                """.format(mysql_table, *minmax)
            )

        transfer_op = MySqlToHiveTransferOperator(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            dag=self.dag,
        )
        transfer_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        hive_hook = HiveServer2Hook()
        hive_rows = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(hive_rows[0], minmax)
    finally:
        # Always drop the scratch table so reruns start clean.
        with hook.get_conn() as db_conn:
            db_conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def test_mysql_to_hive_verify_csv_special_char(self):
    """Values containing CSV-special characters must survive the transfer unescaped."""
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    from airflow.providers.mysql.hooks.mysql import MySqlHook

    hook = MySqlHook()
    try:
        # The second field holds quote/bracket characters that would normally
        # need CSV quoting.
        db_record = ('c0', '["true"]')
        with hook.get_conn() as db_conn:
            db_conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            db_conn.execute(
                """
                CREATE TABLE {} (
                    c0 VARCHAR(25),
                    c1 VARCHAR(25)
                )
                """.format(mysql_table)
            )
            db_conn.execute(
                """
                INSERT INTO {} VALUES (
                    '{}', '{}'
                )
                """.format(mysql_table, *db_record)
            )

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        import unicodecsv as csv

        transfer_op = MySqlToHiveTransfer(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            quoting=csv.QUOTE_NONE,
            quotechar='',
            escapechar='@',
            dag=self.dag,
        )
        transfer_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook

        hive_hook = HiveServer2Hook()
        hive_rows = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(hive_rows[0], db_record)
    finally:
        # Always drop the scratch table so reruns start clean.
        with hook.get_conn() as db_conn:
            db_conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def execute(self, context: Dict[str, str]):
    """Dump the MySQL query result to a local CSV file, then load it into Hive.

    The DB-API cursor description drives the Hive field types via
    ``self.type_map``. The MySQL connection is released before the (potentially
    slow) Hive load starts.

    :param context: Airflow task execution context.
    """
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    with NamedTemporaryFile("wb") as f:
        # try/finally (rather than close() on the success path only) ensures the
        # cursor and connection are released even if the query or dump raises.
        conn = mysql.get_conn()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(self.sql)
                csv_writer = csv.writer(
                    f,
                    delimiter=self.delimiter,
                    quoting=self.quoting,
                    quotechar=self.quotechar,
                    escapechar=self.escapechar,
                    encoding="utf-8",
                )
                # Map each result column name to its Hive type.
                field_dict = OrderedDict(
                    (field[0], self.type_map(field[1])) for field in cursor.description
                )
                csv_writer.writerows(cursor)
                f.flush()
            finally:
                cursor.close()
        finally:
            conn.close()
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties,
        )
def test_mysql_hook_test_bulk_load(self, client):
    """bulk_load() must import every line of a local file into the target table."""
    with MySqlContext(client):
        records = ("foo", "bar", "baz")

        import tempfile

        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write("\n".join(records).encode('utf8'))
            tmp.flush()

            hook = MySqlHook('airflow_db')
            with closing(hook.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        """
                        CREATE TABLE IF NOT EXISTS test_airflow (
                            dummy VARCHAR(50)
                        )
                        """
                    )
                    cursor.execute("TRUNCATE TABLE test_airflow")
                    hook.bulk_load("test_airflow", tmp.name)
                    cursor.execute("SELECT dummy FROM test_airflow")
                    loaded = tuple(row[0] for row in cursor.fetchall())
                    # Order of LOAD DATA is not guaranteed; compare sorted.
                    assert sorted(loaded) == sorted(records)
def test_mysql_to_hive_type_conversion(self, mock_load_file):
    """MySQL column types must be widened to the expected Hive field types."""
    mysql_table = 'test_mysql_to_hive'
    hook = MySqlHook()
    try:
        with hook.get_conn() as db_conn:
            db_conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            db_conn.execute(
                """
                CREATE TABLE {} (
                    c0 TINYINT,
                    c1 SMALLINT,
                    c2 MEDIUMINT,
                    c3 INT,
                    c4 BIGINT,
                    c5 TIMESTAMP
                )
                """.format(mysql_table)
            )

        op = MySqlToHiveOperator(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table='test_mysql_to_hive',
            dag=self.dag,
        )
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        assert mock_load_file.call_count == 1
        # Each MySQL integer type maps to the next-larger Hive type so the
        # full value range fits (e.g. TINYINT -> SMALLINT).
        expected_field_dict = OrderedDict(
            [
                ("c0", "SMALLINT"),
                ("c1", "INT"),
                ("c2", "INT"),
                ("c3", "BIGINT"),
                ("c4", "DECIMAL(38,0)"),
                ("c5", "TIMESTAMP"),
            ]
        )
        self.assertEqual(mock_load_file.call_args[1]["field_dict"], expected_field_dict)
    finally:
        with hook.get_conn() as db_conn:
            db_conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def query(self):
    """Open a MySQL connection, run ``self.sql`` and return the live cursor."""
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    cursor = hook.get_conn().cursor()
    if self.ensure_utc:
        # Force the session to UTC so TIMESTAMP columns come back unshifted.
        set_tz_sql = "SET time_zone = '+00:00'"
        self.log.info('Executing: %s', set_tz_sql)
        cursor.execute(set_tz_sql)
    self.log.info('Executing: %s', self.sql)
    cursor.execute(self.sql)
    return cursor
def test_mysql_hook_test_bulk_load(self):
    """bulk_load() must import every line of a local file into the target table."""
    records = ("foo", "bar", "baz")

    import tempfile

    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write("\n".join(records).encode('utf8'))
        tmp.flush()

        hook = MySqlHook('airflow_db')
        with hook.get_conn() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS test_airflow (
                    dummy VARCHAR(50)
                )
                """
            )
            conn.execute("TRUNCATE TABLE test_airflow")
            hook.bulk_load("test_airflow", tmp.name)
            conn.execute("SELECT dummy FROM test_airflow")
            loaded = tuple(row[0] for row in conn.fetchall())
            # Order of LOAD DATA is not guaranteed; compare sorted.
            self.assertEqual(sorted(loaded), sorted(records))
class TestMySqlHookConnMySqlConnectorPython(unittest.TestCase):
    """Connection wiring of MySqlHook when the mysql-connector-python client is selected."""

    def setUp(self):
        super().setUp()
        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
            extra='{"client": "mysql-connector-python"}',
        )
        self.db_hook = MySqlHook()
        # Bypass the metastore: hand the hook our in-memory Connection directly.
        self.db_hook.get_connection = mock.Mock(return_value=self.connection)

    @mock.patch('mysql.connector.connect')
    def test_get_conn(self, mock_connect):
        """Credentials and schema must be forwarded as connector keyword args."""
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['user'] == 'login'
        assert kwargs['password'] == 'password'
        assert kwargs['host'] == 'host'
        assert kwargs['database'] == 'schema'

    @mock.patch('mysql.connector.connect')
    def test_get_conn_port(self, mock_connect):
        """A port set on the Connection must be passed through."""
        self.connection.port = 3307
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['port'] == 3307

    @mock.patch('mysql.connector.connect')
    def test_get_conn_allow_local_infile(self, mock_connect):
        """The allow_local_infile extra must be normalized to 1 for the driver."""
        extras = self.connection.extra_dejson
        extras.update(allow_local_infile=True)
        self.connection.extra = json.dumps(extras)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['allow_local_infile'] == 1
def execute(self, context):
    """Establish connections to both MySQL & PostgreSQL databases, open
    cursor and begin processing query, loading chunks of rows into
    PostgreSQL. Repeat loading chunks until all rows processed for query.

    :param context: Airflow task execution context.
    """
    source = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    target = PostgresHook(postgres_conn_id=self.postgres_conn_id)
    with closing(source.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql, self.params)
            # Column names from the cursor description become the insert targets.
            target_fields = [x[0] for x in cursor.description]
            row_count = 0
            rows = cursor.fetchmany(self.rows_chunk)
            # Stream the result set chunk by chunk to bound memory usage.
            while rows:
                row_count += len(rows)
                target.insert_rows(
                    self.postgres_table,
                    rows,
                    target_fields=target_fields,
                    commit_every=self.rows_chunk,
                )
                rows = cursor.fetchmany(self.rows_chunk)
            # Lazy %-style args: the message is only built if the record is emitted.
            self.log.info(
                "%s row(s) inserted into %s.", row_count, self.postgres_table
            )
class TestMySqlHookConn(unittest.TestCase):
    """Connection wiring of MySqlHook for the default MySQLdb client."""

    def setUp(self):
        super().setUp()
        self.connection = Connection(
            conn_type='mysql',
            login='******',
            password='******',
            host='host',
            schema='schema',
        )
        self.db_hook = MySqlHook()
        # Bypass the metastore: hand the hook our in-memory Connection directly.
        self.db_hook.get_connection = mock.Mock(return_value=self.connection)

    @mock.patch('MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        """Credentials and schema must be forwarded as MySQLdb keyword args."""
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['user'] == 'login'
        assert kwargs['passwd'] == 'password'
        assert kwargs['host'] == 'host'
        assert kwargs['db'] == 'schema'

    @mock.patch('MySQLdb.connect')
    def test_get_uri(self, mock_connect):
        """Extras such as charset must be reflected in the generated URI."""
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert self.db_hook.get_uri() == "mysql://*****:*****@host/schema?charset=utf-8"

    @mock.patch('MySQLdb.connect')
    def test_get_conn_from_connection(self, mock_connect):
        """A Connection object passed directly to the hook must be honored."""
        connection = Connection(login='******', password='******', host='host', schema='schema')
        hook = MySqlHook(connection=connection)
        hook.get_conn()
        mock_connect.assert_called_once_with(
            user='******', passwd='password-conn', host='host', db='schema', port=3306
        )

    @mock.patch('MySQLdb.connect')
    def test_get_conn_from_connection_with_schema(self, mock_connect):
        """A ``schema`` passed to the hook must override the Connection's schema."""
        connection = Connection(login='******', password='******', host='host', schema='schema')
        hook = MySqlHook(connection=connection, schema='schema-override')
        hook.get_conn()
        mock_connect.assert_called_once_with(
            user='******', passwd='password-conn', host='host', db='schema-override', port=3306
        )

    @mock.patch('MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        """A port set on the Connection must be passed through."""
        self.connection.port = 3307
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['port'] == 3307

    @mock.patch('MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        """A charset extra must enable unicode mode on the driver."""
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['charset'] == 'utf-8'
        assert kwargs['use_unicode'] is True

    @mock.patch('MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        """The cursor extra must be resolved to the matching MySQLdb cursor class."""
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['cursorclass'] == MySQLdb.cursors.SSCursor

    @mock.patch('MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        """The local_infile extra must be normalized to 1 for the driver."""
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['local_infile'] == 1

    @mock.patch('MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        """A unix_socket extra must be forwarded unchanged."""
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['unix_socket'] == '/tmp/socket'

    @mock.patch('MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        """An ssl extra given as a dict must be forwarded as-is."""
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['ssl'] == SSL_DICT

    @mock.patch('MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        """An ssl extra given as a JSON string must be decoded into a dict."""
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        assert args == ()
        assert kwargs['ssl'] == SSL_DICT

    @mock.patch('MySQLdb.connect')
    @mock.patch('airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_client_type')
    def test_get_conn_rds_iam(self, mock_client, mock_connect):
        """With iam=true, an RDS auth token must be used as the password."""
        self.connection.extra = '{"iam":true}'
        mock_client.return_value.generate_db_auth_token.return_value = 'aws_token'
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user='******',
            passwd='aws_token',
            host='host',
            db='schema',
            port=3306,
            read_default_group='enable-cleartext-plugin',
        )
def execute(self, context):
    """Move the result of ``self.sql`` from Vertica into MySQL.

    With ``bulk_load`` the rows are staged in a local tab-separated temp file
    and imported via ``LOAD DATA LOCAL INFILE``; otherwise they are inserted
    row by row through ``MySqlHook.insert_rows``.

    :param context: Airflow task execution context.
    """
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    tmpfile = None
    result = None
    selected_columns = []
    count = 0
    with closing(vertica.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql)
            selected_columns = [d.name for d in cursor.description]
            if self.bulk_load:
                tmpfile = NamedTemporaryFile("w")
                self.log.info(
                    "Selecting rows from Vertica to local file %s...", tmpfile.name)
                self.log.info(self.sql)
                csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
                for row in cursor.iterate():
                    csv_writer.writerow(row)
                    count += 1
                tmpfile.flush()
            else:
                self.log.info("Selecting rows from Vertica...")
                self.log.info(self.sql)
                result = cursor.fetchall()
                count = len(result)
            self.log.info("Selected rows from Vertica %s", count)
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator...")
        mysql.run(self.mysql_preoperator)
    try:
        if self.bulk_load:
            self.log.info("Bulk inserting rows into MySQL...")
            with closing(mysql.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    # NOTE(review): table/column names are interpolated into the
                    # SQL text. They come from operator arguments (trusted DAG
                    # code), and identifiers cannot be driver-parameterized, but
                    # this statement must never be fed user-controlled input.
                    cursor.execute(
                        "LOAD DATA LOCAL INFILE '%s' INTO "
                        "TABLE %s LINES TERMINATED BY '\r\n' (%s)" %
                        (tmpfile.name, self.mysql_table, ", ".join(selected_columns)))
                    conn.commit()
        else:
            self.log.info("Inserting rows into MySQL...")
            mysql.insert_rows(table=self.mysql_table, rows=result,
                              target_fields=selected_columns)
        self.log.info("Inserted rows into MySQL %s", count)
    except (MySQLdb.Error, MySQLdb.Warning):  # pylint: disable=no-member
        self.log.info("Inserted rows into MySQL 0")
        raise
    finally:
        # Close the staging file unconditionally: the original only closed it
        # on the bulk-load success path, leaking the temp file on failure.
        # NamedTemporaryFile deletes its backing file on close.
        if tmpfile:
            tmpfile.close()
    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator...")
        mysql.run(self.mysql_postoperator)
    self.log.info("Done")
def test_mysql_to_hive_verify_loaded_values(self, mock_popen, mock_temp_dir):
    """Integer extremes must round-trip and the expected beeline command must be spawned."""
    mock_subprocess = MockSubProcess()
    mock_popen.return_value = mock_subprocess
    mock_temp_dir.return_value = "test_mysql_to_hive"

    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'
    hook = MySqlHook()
    try:
        # Unsigned maxima first (c0..c4), then signed minima (c5..c9).
        minmax = (
            255,
            65535,
            16777215,
            4294967295,
            18446744073709551615,
            -128,
            -32768,
            -8388608,
            -2147483648,
            -9223372036854775808,
        )
        with hook.get_conn() as db_conn:
            db_conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
            db_conn.execute(
                """
                CREATE TABLE {} (
                    c0 TINYINT UNSIGNED,
                    c1 SMALLINT UNSIGNED,
                    c2 MEDIUMINT UNSIGNED,
                    c3 INT UNSIGNED,
                    c4 BIGINT UNSIGNED,
                    c5 TINYINT,
                    c6 SMALLINT,
                    c7 MEDIUMINT,
                    c8 INT,
                    c9 BIGINT
                )
                """.format(mysql_table)
            )
            db_conn.execute(
                """
                INSERT INTO {} VALUES (
                    {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                )
                """.format(mysql_table, *minmax)
            )

        with mock.patch.dict('os.environ', self.env_vars):
            transfer_op = MySqlToHiveOperator(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql=f"SELECT * FROM {mysql_table}",
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                dag=self.dag,
            )
            transfer_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            mock_cursor = MockConnectionCursor()
            mock_cursor.iterable = [minmax]
            hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)
            hive_rows = hive_hook.get_records(f"SELECT * FROM {hive_table}")
            assert hive_rows[0] == minmax

            # Exact beeline invocation expected for this task/DAG context.
            hive_cmd = [
                'beeline',
                '-u',
                '"jdbc:hive2://localhost:10000/default"',
                '-hiveconf',
                'airflow.ctx.dag_id=unit_test_dag',
                '-hiveconf',
                'airflow.ctx.task_id=test_m2h',
                '-hiveconf',
                'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00',
                '-hiveconf',
                'airflow.ctx.dag_run_id=55',
                '-hiveconf',
                'airflow.ctx.dag_owner=airflow',
                '-hiveconf',
                '[email protected]',
                '-hiveconf',
                'mapreduce.job.queuename=airflow',
                '-hiveconf',
                'mapred.job.queue.name=airflow',
                '-hiveconf',
                'tez.queue.name=airflow',
                '-f',
                '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
            ]
            mock_popen.assert_called_with(
                hive_cmd,
                stdout=mock_subprocess.PIPE,
                stderr=mock_subprocess.STDOUT,
                cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
                close_fds=True,
            )
    finally:
        # Always drop the scratch table so reruns start clean.
        with hook.get_conn() as db_conn:
            db_conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
def test_mysql_to_hive_verify_csv_special_char(self, mock_popen, mock_temp_dir):
    """CSV-special characters must survive the transfer and the expected beeline command must run."""
    mock_subprocess = MockSubProcess()
    mock_popen.return_value = mock_subprocess
    mock_temp_dir.return_value = "test_mysql_to_hive"

    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'
    hook = MySqlHook()
    try:
        # The second field holds quote/bracket characters that would normally
        # need CSV quoting.
        db_record = ('c0', '["true"]')
        with hook.get_conn() as db_conn:
            db_conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")
            db_conn.execute(
                """
                CREATE TABLE {} (
                    c0 VARCHAR(25),
                    c1 VARCHAR(25)
                )
                """.format(mysql_table)
            )
            db_conn.execute(
                """
                INSERT INTO {} VALUES (
                    '{}', '{}'
                )
                """.format(mysql_table, *db_record)
            )

        with mock.patch.dict('os.environ', self.env_vars):
            import unicodecsv as csv

            transfer_op = MySqlToHiveOperator(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql=f"SELECT * FROM {mysql_table}",
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                quoting=csv.QUOTE_NONE,
                quotechar='',
                escapechar='@',
                dag=self.dag,
            )
            transfer_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            mock_cursor = MockConnectionCursor()
            mock_cursor.iterable = [('c0', '["true"]'), (2, 2)]
            hive_hook = MockHiveServer2Hook(connection_cursor=mock_cursor)
            hive_rows = hive_hook.get_records(f"SELECT * FROM {hive_table}")
            assert hive_rows[0] == db_record

            # Exact beeline invocation expected for this task/DAG context.
            hive_cmd = [
                'beeline',
                '-u',
                '"jdbc:hive2://localhost:10000/default"',
                '-hiveconf',
                'airflow.ctx.dag_id=unit_test_dag',
                '-hiveconf',
                'airflow.ctx.task_id=test_m2h',
                '-hiveconf',
                'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00',
                '-hiveconf',
                'airflow.ctx.dag_run_id=55',
                '-hiveconf',
                'airflow.ctx.dag_owner=airflow',
                '-hiveconf',
                '[email protected]',
                '-hiveconf',
                'mapreduce.job.queuename=airflow',
                '-hiveconf',
                'mapred.job.queue.name=airflow',
                '-hiveconf',
                'tez.queue.name=airflow',
                '-f',
                '/tmp/airflow_hiveop_test_mysql_to_hive/tmptest_mysql_to_hive',
            ]
            mock_popen.assert_called_with(
                hive_cmd,
                stdout=mock_subprocess.PIPE,
                stderr=mock_subprocess.STDOUT,
                cwd="/tmp/airflow_hiveop_test_mysql_to_hive",
                close_fds=True,
            )
    finally:
        # Always drop the scratch table so reruns start clean.
        with hook.get_conn() as db_conn:
            db_conn.execute(f"DROP TABLE IF EXISTS {mysql_table}")