def test_get_results_with_schema(self):
    from airflow.hooks.hive_hooks import HiveServer2Hook
    from unittest.mock import MagicMock

    # Configure
    sql = "select 1"
    schema = "notdefault"
    hook = HiveServer2Hook()

    # Cursor mock that supports the context-manager protocol.
    cursor_mock = MagicMock()
    cursor_mock.__enter__.return_value = cursor_mock
    cursor_mock.__exit__.return_value = None
    cursor_mock.fetchall.return_value = []

    # Connection mock whose cursor() returns the cursor mock.
    conn_mock = MagicMock()
    conn_mock.__enter__.return_value = conn_mock
    conn_mock.__exit__.return_value = None
    conn_mock.cursor.return_value = cursor_mock

    get_conn_mock = MagicMock(return_value=conn_mock)
    hook.get_conn = get_conn_mock

    # Run
    hook.get_results(sql, schema)

    # Verify
    get_conn_mock.assert_called_with(schema)
def test_get_pandas_df(self):
    hook = HiveServer2Hook()
    query = "SELECT * FROM {}".format(self.table)
    df = hook.get_pandas_df(query, schema=self.database)

    self.assertEqual(len(df), 2)
    self.assertListEqual(df.columns.tolist(), self.columns)
    self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
def add_partition(v_execution_date, v_execution_day, v_execution_hour,
                  target_table_db_name, target_table_name, conn_id,
                  hive_table_name, server_name, hive_db, is_must_have_data,
                  **kwargs):
    # Write the _SUCCESS marker file.
    """
    First flag "true": the data directory is partitioned by country_code; "false": it is not.
    Second flag "true": write _SUCCESS only when data exists; "false": write _SUCCESS even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(
        v_execution_day,
        hive_db,
        hive_table_name,
        OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
            server_name=server_name,
            db_name=target_table_db_name,
            table_name=target_table_name)),
        "false",
        is_must_have_data,
        v_execution_hour)

    sql = '''
        ALTER TABLE {hive_db}.{table} ADD IF NOT EXISTS PARTITION (dt = '{ds}', hour = '{hour}')
    '''.format(
        hive_db=hive_db,
        table=hive_table_name,
        ds=v_execution_day,
        hour=v_execution_hour)

    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(sql)
    return
def execute(self, context):
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
    logging.info("Extracting data from Hive")
    logging.info(self.sql)

    if self.bulk_load:
        tmpfile = NamedTemporaryFile()
        hive.to_csv(self.sql, tmpfile.name, delimiter='\t',
                    lineterminator='\n', output_header=False)
    else:
        results = hive.get_records(self.sql)

    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        logging.info("Running MySQL preoperator")
        mysql.run(self.mysql_preoperator)

    logging.info("Inserting rows into MySQL")

    if self.bulk_load:
        mysql.bulk_load(table=self.mysql_table, tmp_file=tmpfile.name)
        tmpfile.close()
    else:
        mysql.insert_rows(table=self.mysql_table, rows=results)

    if self.mysql_postoperator:
        logging.info("Running MySQL postoperator")
        mysql.run(self.mysql_postoperator)

    logging.info("Done.")
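# Usage sketch for the execute() above, assuming it belongs to Airflow 1.x's
# HiveToMySqlTransfer (airflow.operators.hive_to_mysql) and that the
# 'hiveserver2_default' / 'mysql_default' connections exist; the DAG id, SQL,
# and target table below are placeholders, not values from the original code.
from datetime import datetime

from airflow import DAG
from airflow.operators.hive_to_mysql import HiveToMySqlTransfer

dag = DAG('hive_to_mysql_example_dag', start_date=datetime(2020, 1, 1),
          schedule_interval=None)

hive_to_mysql = HiveToMySqlTransfer(
    task_id='hive_to_mysql_example',
    sql="SELECT * FROM my_hive_table",      # hypothetical Hive query
    mysql_table='my_mysql_table',           # hypothetical MySQL target table
    hiveserver2_conn_id='hiveserver2_default',
    mysql_conn_id='mysql_default',
    bulk_load=False,
    dag=dag)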
def test_get_results_with_hive_conf(self):
    hql = [
        "set key",
        "set airflow.ctx.dag_id",
        "set airflow.ctx.dag_run_id",
        "set airflow.ctx.task_id",
        "set airflow.ctx.execution_date",
    ]

    dag_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
    task_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
    execution_date_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
            'env_var_format']
    dag_run_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
            'env_var_format']

    os.environ[dag_id_ctx_var_name] = 'test_dag_id'
    os.environ[task_id_ctx_var_name] = 'test_task_id'
    os.environ[execution_date_ctx_var_name] = 'test_execution_date'
    os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'

    hook = HiveServer2Hook()
    output = '\n'.join(res_tuple[0] for res_tuple in hook.get_results(
        hql=hql, hive_conf={'key': 'value'})['data'])

    self.assertIn('value', output)
    self.assertIn('test_dag_id', output)
    self.assertIn('test_task_id', output)
    self.assertIn('test_execution_date', output)
    self.assertIn('test_dag_run_id', output)

    del os.environ[dag_id_ctx_var_name]
    del os.environ[task_id_ctx_var_name]
    del os.environ[execution_date_ctx_var_name]
    del os.environ[dag_run_id_ctx_var_name]
def test_multi_statements(self):
    from airflow.hooks.hive_hooks import HiveServer2Hook
    sqls = [
        "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)",
        "DROP TABLE test_multi_statements",
    ]
    hook = HiveServer2Hook()
    hook.get_records(sqls)
def execute(self, context):
    samba = SambaHook(samba_conn_id=self.samba_conn_id)
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
    tmpfile = tempfile.NamedTemporaryFile()

    self.log.info("Fetching file from Hive")
    hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name)

    self.log.info("Pushing to samba")
    samba.push_from_local(self.destination_filepath, tmpfile.name)
def validate_all_hi_table_exist_task(hive_all_hi_table_name, mysql_table_name, **kwargs):
    check_sql = 'show partitions %s.%s' % (HIVE_DB, hive_all_hi_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)

    if len(cursor.fetchall()) == 0:
        return 'import_table_{}'.format(mysql_table_name)
    else:
        return 'add_partitions_{}'.format(hive_all_hi_table_name)
def execute(self, context):
    samba = SambaHook(samba_conn_id=self.samba_conn_id)
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

    with NamedTemporaryFile() as tmp_file:
        self.log.info("Fetching file from Hive")
        hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name,
                    hive_conf=context_to_airflow_vars(context))
        self.log.info("Pushing to samba")
        samba.push_from_local(self.destination_filepath, tmp_file.name)
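# Usage sketch for the execute() above, assuming it belongs to Airflow 1.x's
# Hive2SambaOperator (airflow.operators.hive_to_samba_operator) and that the
# 'hiveserver2_default' / 'samba_default' connections exist; the DAG id, HQL,
# and destination path below are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.operators.hive_to_samba_operator import Hive2SambaOperator

dag = DAG('hive_to_samba_example_dag', start_date=datetime(2020, 1, 1),
          schedule_interval=None)

hive_to_samba = Hive2SambaOperator(
    task_id='hive_to_samba_example',
    hql="SELECT * FROM my_hive_table",          # hypothetical Hive query
    destination_filepath='reports/output.csv',  # hypothetical Samba path
    samba_conn_id='samba_default',
    hiveserver2_conn_id='hiveserver2_default',
    dag=dag)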
def test_multi_statements(self):
    sqls = [
        "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)",
        "SELECT * FROM {}".format(self.table),
        "DROP TABLE test_multi_statements",
    ]
    hook = HiveServer2Hook()
    results = hook.get_records(sqls, schema=self.database)
    self.assertListEqual(results, [(1, 1), (2, 2)])
def test_to_csv(self):
    hook = HiveServer2Hook()
    query = "SELECT * FROM {}".format(self.table)
    csv_filepath = 'query_results.csv'
    hook.to_csv(query, csv_filepath, schema=self.database,
                delimiter=',', lineterminator='\n', output_header=True)

    df = pd.read_csv(csv_filepath, sep=',')
    self.assertListEqual(df.columns.tolist(), self.columns)
    self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
    self.assertEqual(len(df), 2)
def test_mysql_to_hive_verify_loaded_values(self):
    mysql_conn_id = 'airflow_ci'
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    from airflow.hooks.mysql_hook import MySqlHook
    m = MySqlHook(mysql_conn_id)

    try:
        minmax = (
            255,
            65535,
            16777215,
            4294967295,
            18446744073709551615,
            -128,
            -32768,
            -8388608,
            -2147483648,
            -9223372036854775808,
        )

        with m.get_conn() as c:
            c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            c.execute("""
                CREATE TABLE {} (
                    c0 TINYINT UNSIGNED,
                    c1 SMALLINT UNSIGNED,
                    c2 MEDIUMINT UNSIGNED,
                    c3 INT UNSIGNED,
                    c4 BIGINT UNSIGNED,
                    c5 TINYINT,
                    c6 SMALLINT,
                    c7 MEDIUMINT,
                    c8 INT,
                    c9 BIGINT
                )
            """.format(mysql_table))
            c.execute("""
                INSERT INTO {} VALUES (
                    {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                )
            """.format(mysql_table, *minmax))

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        t = MySqlToHiveTransfer(
            task_id='test_m2h',
            mysql_conn_id=mysql_conn_id,
            hive_cli_conn_id='beeline_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        from airflow.hooks.hive_hooks import HiveServer2Hook
        h = HiveServer2Hook()
        r = h.get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(r[0], minmax)
    finally:
        with m.get_conn() as c:
            c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def add_hi_partition(v_execution_date, v_execution_day, v_execution_hour, db_name,
                     table_name, conn_id, hive_table_name, server_name, hive_db,
                     is_must_have_data, **kwargs):
    sql = '''
        ALTER TABLE {hive_db}.{table} ADD IF NOT EXISTS PARTITION (dt = '{ds}', hour = '{hour}')
    '''.format(
        hive_db=hive_db,
        table=hive_table_name,
        ds=v_execution_day,
        hour=v_execution_hour)

    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(sql)
    return
def get_hook(self):
    try:
        if self.conn_type == 'mysql':
            from airflow.hooks.mysql_hook import MySqlHook
            return MySqlHook(mysql_conn_id=self.conn_id)
        elif self.conn_type == 'google_cloud_platform':
            from airflow.contrib.hooks.bigquery_hook import BigQueryHook
            return BigQueryHook(bigquery_conn_id=self.conn_id)
        elif self.conn_type == 'postgres':
            from airflow.hooks.postgres_hook import PostgresHook
            return PostgresHook(postgres_conn_id=self.conn_id)
        elif self.conn_type == 'hive_cli':
            from airflow.hooks.hive_hooks import HiveCliHook
            return HiveCliHook(hive_cli_conn_id=self.conn_id)
        elif self.conn_type == 'presto':
            from airflow.hooks.presto_hook import PrestoHook
            return PrestoHook(presto_conn_id=self.conn_id)
        elif self.conn_type == 'hiveserver2':
            from airflow.hooks.hive_hooks import HiveServer2Hook
            return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
        elif self.conn_type == 'sqlite':
            from airflow.hooks.sqlite_hook import SqliteHook
            return SqliteHook(sqlite_conn_id=self.conn_id)
        elif self.conn_type == 'jdbc':
            from airflow.hooks.jdbc_hook import JdbcHook
            return JdbcHook(jdbc_conn_id=self.conn_id)
        elif self.conn_type == 'mssql':
            from airflow.hooks.mssql_hook import MsSqlHook
            return MsSqlHook(mssql_conn_id=self.conn_id)
        elif self.conn_type == 'oracle':
            from airflow.hooks.oracle_hook import OracleHook
            return OracleHook(oracle_conn_id=self.conn_id)
        elif self.conn_type == 'vertica':
            from airflow.contrib.hooks.vertica_hook import VerticaHook
            return VerticaHook(vertica_conn_id=self.conn_id)
        elif self.conn_type == 'cloudant':
            from airflow.contrib.hooks.cloudant_hook import CloudantHook
            return CloudantHook(cloudant_conn_id=self.conn_id)
        elif self.conn_type == 'jira':
            from airflow.contrib.hooks.jira_hook import JiraHook
            return JiraHook(jira_conn_id=self.conn_id)
        elif self.conn_type == 'redis':
            from airflow.contrib.hooks.redis_hook import RedisHook
            return RedisHook(redis_conn_id=self.conn_id)
        elif self.conn_type == 'wasb':
            from airflow.contrib.hooks.wasb_hook import WasbHook
            return WasbHook(wasb_conn_id=self.conn_id)
        elif self.conn_type == 'docker':
            from airflow.hooks.docker_hook import DockerHook
            return DockerHook(docker_conn_id=self.conn_id)
    except:
        pass
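# Usage sketch for the dispatcher above, assuming it is Airflow 1.x's
# Connection.get_hook(); the connection id 'hiveserver2_default' is a
# placeholder that must exist in the metadata DB or as an environment variable.
from airflow.hooks.base_hook import BaseHook

conn = BaseHook.get_connection('hiveserver2_default')
hook = conn.get_hook()  # a HiveServer2Hook when conn_type == 'hiveserver2'
records = hook.get_records("SHOW DATABASES")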
def test_to_csv(self):
    hook = HiveServer2Hook()
    query = "SELECT * FROM {}".format(self.table)
    csv_filepath = 'query_results.csv'

    with self.assertLogs() as cm:
        hook.to_csv(query, csv_filepath, schema=self.database,
                    delimiter=',', lineterminator='\n', output_header=True,
                    fetch_size=2)
        df = pd.read_csv(csv_filepath, sep=',')
        self.assertListEqual(df.columns.tolist(), self.columns)
        self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
        self.assertEqual(len(df), 2)
        self.assertIn('INFO:airflow.hooks.hive_hooks.HiveServer2Hook:'
                      'Written 2 rows so far.', cm.output)
def test_select_conn_with_schema(self, connect_mock):
    from airflow.hooks.hive_hooks import HiveServer2Hook

    # Configure
    hook = HiveServer2Hook()

    # Run
    hook.get_conn(self.nondefault_schema)

    # Verify
    assert connect_mock.called
    (args, kwargs) = connect_mock.call_args_list[0]
    assert kwargs['database'] == self.nondefault_schema
def count_data_rows(templates_dict, **kwargs):
    hook = HiveServer2Hook()
    query = """
    SELECT count(*)
    FROM mydata
    """
    result = hook.get_results(schema=templates_dict['schema'], hql=query)

    if result['data'][0][0] > 100:
        return 'clean_data'
    else:
        return 'stop_flow'
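# Wiring sketch for the branch callable above, assuming Airflow 1.x
# (airflow.operators.python_operator); the DAG id, schema value, and the
# downstream tasks below are placeholders, and the task ids must match the
# strings returned by count_data_rows.
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator

dag = DAG('count_data_rows_example_dag', start_date=datetime(2020, 1, 1),
          schedule_interval=None)

check_rows = BranchPythonOperator(
    task_id='count_data_rows',
    python_callable=count_data_rows,
    provide_context=True,
    templates_dict={'schema': 'my_schema'},  # hypothetical schema name
    dag=dag)

clean_data = DummyOperator(task_id='clean_data', dag=dag)
stop_flow = DummyOperator(task_id='stop_flow', dag=dag)
check_rows >> [clean_data, stop_flow]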
def test_get_records_with_schema(self, get_results_mock):
    from airflow.hooks.hive_hooks import HiveServer2Hook

    # Configure
    sql = "select 1"
    hook = HiveServer2Hook()

    # Run
    hook.get_records(sql, self.nondefault_schema)

    # Verify
    self.assertTrue(self.connect_mock.called)
    (args, kwargs) = self.connect_mock.call_args_list[0]
    self.assertEqual(sql, args[0])
    self.assertEqual(self.nondefault_schema, kwargs['schema'])
def test_get_pandas_df_with_schema(self, get_results_mock):
    from airflow.hooks.hive_hooks import HiveServer2Hook

    # Configure
    sql = "select 1"
    hook = HiveServer2Hook()

    # Run
    hook.get_pandas_df(sql, self.nondefault_schema)

    # Verify
    assert self.connect_mock.called
    (args, kwargs) = self.connect_mock.call_args_list[0]
    assert args[0] == sql
    assert kwargs['schema'] == self.nondefault_schema
def test_mysql_to_hive_verify_csv_special_char(self):
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    from airflow.hooks.mysql_hook import MySqlHook
    hook = MySqlHook()

    try:
        db_record = (
            'c0',
            '["true"]',
        )
        with hook.get_conn() as conn:
            conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            conn.execute("""
                CREATE TABLE {} (
                    c0 VARCHAR(25),
                    c1 VARCHAR(25)
                )
            """.format(mysql_table))
            conn.execute("""
                INSERT INTO {} VALUES (
                    '{}', '{}'
                )
            """.format(mysql_table, *db_record))

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        import unicodecsv as csv
        op = MySqlToHiveTransfer(
            task_id='test_m2h',
            hive_cli_conn_id='hive_cli_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            quoting=csv.QUOTE_NONE,
            quotechar='',
            escapechar='@',
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        from airflow.hooks.hive_hooks import HiveServer2Hook
        hive_hook = HiveServer2Hook()
        result = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(result[0], db_record)
    finally:
        with hook.get_conn() as conn:
            conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def execute(self, context):
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
    tmpfile = tempfile.NamedTemporaryFile()

    self.log.info("Fetching file from Hive")
    hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name)

    self.log.info("Pushing to localfile")
    if not os.path.exists(self.dst_path):
        os.makedirs(self.dst_path)
    target = os.path.join(
        self.dst_path, self.dst_filename) if self.dst_filename else self.dst_path
    shutil.copy(tmpfile.name, target)
def test_get_conn_with_password(self, mock_connect):
    conn_id = "conn_with_password"
    conn_env = CONN_ENV_PREFIX + conn_id.upper()

    with patch.dict(
        'os.environ',
        {conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"}
    ):
        HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()

    mock_connect.assert_called_once_with(
        host='localhost',
        port=10000,
        auth='LDAP',
        kerberos_service_name=None,
        username='******',
        password='******',
        database='default')
def test_get_conn_with_password(self, mock_connect):
    from airflow.hooks.base_hook import CONN_ENV_PREFIX

    conn_id = "conn_with_password"
    conn_env = CONN_ENV_PREFIX + conn_id.upper()
    conn_value = os.environ.get(conn_env)
    os.environ[conn_env] = "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"

    HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()
    mock_connect.assert_called_with(
        host='localhost',
        port=10000,
        auth='LDAP',
        kerberos_service_name=None,
        username='******',
        password='******',
        database='default')

    if conn_value:
        os.environ[conn_env] = conn_value
def get_records(self):
    """Executes a query to obtain a count of records on a table.

    Returns:
        int -- quantity of records from a count query
    """
    if self.query_engine == 'hive':
        hook = HiveServer2Hook(self.query_engine_conn_id)
    elif self.query_engine == 'presto':
        hook = PrestoHook(self.query_engine_conn_id)

    # Execute the query against Hive or Presto.
    res = hook.get_records(self.records_query)

    if len(res) > 1:
        raise ValueError("Expected a single row from the count query, got %d" % len(res))
    else:
        return res[0]
def execute(self, context):
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
    logging.info('Extracting data from Hive')
    logging.info(self.hql)
    data = hive.get_pandas_df(self.hql, schema=self.schema)

    gcp_hook = GoogleCloudStorageHook(
        google_cloud_storage_conn_id=self.google_cloud_storage_conn_id)

    logging.info('Inserting rows onto google cloud storage')
    with tempfile.NamedTemporaryFile(suffix='.json', prefix='tmp') as tmp_file:
        data = data.to_json(orient='records')
        recs = json.loads(data)
        for record in recs:
            tmp_file.write(json.dumps(record))
            tmp_file.write("\n")
        tmp_file.flush()

        remote_file_name = self.file_pattern.format('aa')
        remote_name = os.path.join(self.subdir, remote_file_name)
        gcp_hook.upload(self.bucket, remote_name, tmp_file.name)

    logging.info('Done.')
def execute(self, context):
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

    self.log.info('Extracting data from Hive')
    self.log.info(self.sql)
    data = hive.get_pandas_df(self.sql, schema=self.schema)

    dynamodb = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id,
                               table_name=self.table_name,
                               table_keys=self.table_keys,
                               region_name=self.region_name)

    self.log.info('Inserting rows into dynamodb')
    if self.pre_process is None:
        dynamodb.write_batch_data(
            json.loads(data.to_json(orient='records')))
    else:
        dynamodb.write_batch_data(
            self.pre_process(data=data,
                             args=self.pre_process_args,
                             kwargs=self.pre_process_kwargs))

    self.log.info('Done.')
def run_check_table(schema_table_db_name, schema_table_name, target_table_db_name,
                    target_table_name, conn_id, hive_table_name, server_name, **kwargs):
    # e.g. SHOW TABLES in oride_db LIKE 'data_aa'
    check_sql = 'SHOW TABLES in %s LIKE \'%s\'' % (HIVE_DB, hive_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)

    if len(cursor.fetchall()) == 0:
        logging.info('Create Hive Table: %s.%s', HIVE_DB, hive_table_name)
        # Fetch the table's column definitions from MySQL.
        column_sql = '''
            SELECT
                COLUMN_NAME,
                DATA_TYPE,
                NUMERIC_PRECISION,
                NUMERIC_SCALE,
                COLUMN_COMMENT
            FROM information_schema.columns
            WHERE table_schema='{db_name}' and table_name='{table_name}'
        '''.format(db_name=schema_table_db_name, table_name=schema_table_name)

        mysql_hook = MySqlHook(conn_id)
        mysql_conn = mysql_hook.get_conn()
        mysql_cursor = mysql_conn.cursor()
        mysql_cursor.execute(column_sql)
        results = mysql_cursor.fetchall()

        rows = []
        for result in results:
            if result[0] == 'dt':
                col_name = '_dt'
            else:
                col_name = result[0]

            if result[1] == 'timestamp' or result[1] == 'varchar' or result[1] == 'char' or result[1] == 'text' or \
                    result[1] == 'longtext' or \
                    result[1] == 'mediumtext' or \
                    result[1] == 'json' or \
                    result[1] == 'datetime':
                data_type = 'string'
            elif result[1] == 'decimal':
                data_type = result[1] + "(" + str(result[2]) + "," + str(result[3]) + ")"
            else:
                data_type = result[1]

            rows.append("`%s` %s comment '%s'" % (
                col_name,
                data_type,
                str(result[4]).replace('\n', '').replace('\r', '')))
        mysql_conn.close()

        # Create the Hive table.
        hive_hook = HiveCliHook()
        sql = ODS_CREATE_TABLE_SQL.format(
            db_name=HIVE_DB,
            table_name=hive_table_name,
            columns=",\n".join(rows),
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=target_table_db_name,
                table_name=target_table_name)))
        logging.info('Executing: %s', sql)
        hive_hook.run_cli(sql)
    else:
        sqoopSchema = SqoopSchemaUpdate()
        response = sqoopSchema.append_hive_schema(
            hive_db=HIVE_DB,
            hive_table=hive_table_name,
            mysql_db=schema_table_db_name,
            mysql_table=schema_table_name,
            mysql_conn=conn_id,
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=target_table_db_name,
                table_name=target_table_name)))
        if response:
            return True
    return
def test_get_conn(self):
    hook = HiveServer2Hook()
    hook.get_conn()
def test_get_results_header(self):
    hook = HiveServer2Hook()
    query = "SELECT * FROM {}".format(self.table)
    results = hook.get_results(query, schema=self.database)
    self.assertListEqual([col[0] for col in results['header']], self.columns)
def test_get_results_data(self):
    hook = HiveServer2Hook()
    query = "SELECT * FROM {}".format(self.table)
    results = hook.get_results(query, schema=self.database)
    self.assertListEqual(results['data'], [(1, 1), (2, 2)])
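# For reference, a minimal sketch of the get_results() return shape consumed by
# the two tests above: a dict with 'header' (the DB-API cursor description
# tuples) and 'data' (a list of row tuples). Table and schema names here are
# placeholders.
from airflow.hooks.hive_hooks import HiveServer2Hook

hook = HiveServer2Hook()
results = hook.get_results("SELECT * FROM my_table", schema="my_db")
column_names = [col[0] for col in results['header']]  # first element of each description tuple
rows = results['data']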