def test_run_cli_with_hive_conf(self):
    hql = "set key;\n" \
          "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
          "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"

    dag_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
    task_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
    execution_date_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE']['env_var_format']
    dag_run_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID']['env_var_format']

    with mock.patch.dict('os.environ', {
        dag_id_ctx_var_name: 'test_dag_id',
        task_id_ctx_var_name: 'test_task_id',
        execution_date_ctx_var_name: 'test_execution_date',
        dag_run_id_ctx_var_name: 'test_dag_run_id',
    }):
        hook = HiveCliHook()
        output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)
def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)

    self.log.info("Dumping Vertica query results to local file")
    conn = vertica.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql)
    with NamedTemporaryFile("w") as f:
        csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
        field_dict = OrderedDict()
        col_count = 0
        for field in cursor.description:
            col_count += 1
            col_position = f"Column{col_count}"
            field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
        csv_writer.writerows(cursor.iterate())
        f.flush()
        cursor.close()
        conn.close()

        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
        )
def test_load_file_create_table(self, mock_run_cli):
    filepath = "/path/to/input/file"
    table = "output_table"
    field_dict = OrderedDict([("name", "string"), ("gender", "string")])
    fields = ",\n ".join([
        '`{k}` {v}'.format(k=k.strip('`'), v=v) for k, v in field_dict.items()
    ])

    hook = HiveCliHook()
    hook.load_file(filepath=filepath, table=table,
                   field_dict=field_dict, create=True, recreate=True)

    create_table = (
        "DROP TABLE IF EXISTS {table};\n"
        "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
        "ROW FORMAT DELIMITED\n"
        "FIELDS TERMINATED BY ','\n"
        "STORED AS textfile\n;".format(table=table, fields=fields)
    )

    load_data = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath, table=table)
    )
    calls = [mock.call(create_table), mock.call(load_data)]
    mock_run_cli.assert_has_calls(calls, any_order=True)
def execute(self, context: Dict[str, str]):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    conn = mysql.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql)
    with NamedTemporaryFile("wb") as f:
        csv_writer = csv.writer(
            f,
            delimiter=self.delimiter,
            quoting=self.quoting,
            quotechar=self.quotechar,
            escapechar=self.escapechar,
            encoding="utf-8",
        )
        field_dict = OrderedDict()
        for field in cursor.description:
            field_dict[field[0]] = self.type_map(field[1])
        csv_writer.writerows(cursor)
        f.flush()
        cursor.close()
        conn.close()

        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties,
        )
def test_load_df_with_data_types(self, mock_run_cli):
    ord_dict = OrderedDict()
    ord_dict['b'] = [True]
    ord_dict['i'] = [-1]
    ord_dict['t'] = [1]
    ord_dict['f'] = [0.0]
    ord_dict['c'] = ['c']
    ord_dict['M'] = [datetime.datetime(2018, 1, 1)]
    ord_dict['O'] = [object()]
    ord_dict['S'] = [b'STRING']
    ord_dict['U'] = ['STRING']
    ord_dict['V'] = [None]
    df = pd.DataFrame(ord_dict)

    hook = HiveCliHook()
    hook.load_df(df, 't')

    query = """
        CREATE TABLE IF NOT EXISTS t (
            `b` BOOLEAN,
            `i` BIGINT,
            `t` BIGINT,
            `f` DOUBLE,
            `c` STRING,
            `M` TIMESTAMP,
            `O` STRING,
            `S` STRING,
            `U` STRING,
            `V` STRING)
        ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ','
        STORED AS textfile
        ;
    """
    assert_equal_ignore_multiple_spaces(self, mock_run_cli.call_args_list[0][0][0], query)
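# The expected DDL in the test above encodes the dtype-to-Hive-type mapping that
# load_df relies on when no field_dict is supplied. Below is a minimal sketch of that
# mapping as a standalone helper; `infer_hive_types` and `DTYPE_KIND_TO_HIVE` are
# illustrative names, not part of the hook's public API: NumPy dtype kinds map to a
# Hive type, and anything unrecognised falls back to STRING.
from collections import OrderedDict

import pandas as pd

DTYPE_KIND_TO_HIVE = {
    'b': 'BOOLEAN',    # boolean
    'i': 'BIGINT',     # signed integer
    'u': 'BIGINT',     # unsigned integer
    'f': 'DOUBLE',     # floating point
    'M': 'TIMESTAMP',  # datetime64
}


def infer_hive_types(df: pd.DataFrame) -> OrderedDict:
    """Map each DataFrame column to a Hive column type, defaulting to STRING."""
    return OrderedDict(
        (col, DTYPE_KIND_TO_HIVE.get(dtype.kind, 'STRING'))
        for col, dtype in df.dtypes.items()
    )


# Example: a bool, an int, a float, a datetime and an object column.
_df = pd.DataFrame({
    'b': [True],
    'i': [-1],
    'f': [0.0],
    'M': [pd.Timestamp('2018-01-01')],
    'O': [object()],
})
# -> OrderedDict([('b', 'BOOLEAN'), ('i', 'BIGINT'), ('f', 'DOUBLE'),
#                 ('M', 'TIMESTAMP'), ('O', 'STRING')])
print(infer_hive_types(_df))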
def execute(self, context: Dict[str, str]):
    mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)

    self.log.info("Dumping Microsoft SQL Server query results to local file")
    with mssql.get_conn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(self.sql)
            with NamedTemporaryFile("w") as tmp_file:
                csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
                field_dict = OrderedDict()
                col_count = 0
                for field in cursor.description:
                    col_count += 1
                    col_position = "Column{position}".format(position=col_count)
                    field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
                csv_writer.writerows(cursor)
                tmp_file.flush()

                hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
                self.log.info("Loading file into Hive")
                hive.load_file(
                    tmp_file.name,
                    self.hive_table,
                    field_dict=field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties,
                )
def ddl(self):
    """Retrieve table ddl"""
    table = request.args.get("table")
    sql = "SHOW CREATE TABLE {table};".format(table=table)
    hook = HiveCliHook(HIVE_CLI_CONN_ID)
    return hook.run_cli(sql)
def test_load_file_without_create_table(self, mock_run_cli):
    filepath = "/path/to/input/file"
    table = "output_table"

    hook = HiveCliHook()
    hook.load_file(filepath=filepath, table=table, create=False)

    query = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath, table=table)
    )
    calls = [mock.call(query)]
    mock_run_cli.assert_has_calls(calls, any_order=True)
def test_get_proxy_user_value(self):
    hook = HiveCliHook()
    returner = mock.MagicMock()
    returner.extra_dejson = {'proxy_user': 'a_user_proxy'}
    hook.use_beeline = True
    hook.conn = returner

    # Run
    result = hook._prepare_cli_cmd()

    # Verify
    self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])
def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
    hook = HiveCliHook()
    bools = (True, False)
    for create, recreate in itertools.product(bools, bools):
        mock_load_file.reset_mock()
        hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
                     table="t",
                     create=create,
                     recreate=recreate)

        assert mock_load_file.call_count == 1
        kwargs = mock_load_file.call_args[1]
        self.assertEqual(kwargs["create"], create)
        self.assertEqual(kwargs["recreate"], recreate)
def get_hook(self):
    return HiveCliHook(
        hive_cli_conn_id=self.hive_cli_conn_id,
        run_as=self.run_as,
        mapred_queue=self.mapred_queue,
        mapred_queue_priority=self.mapred_queue_priority,
        mapred_job_name=self.mapred_job_name)
def get_hook(self) -> HiveCliHook:
    """Get Hive cli hook"""
    return HiveCliHook(
        hive_cli_conn_id=self.hive_cli_conn_id,
        run_as=self.run_as,
        mapred_queue=self.mapred_queue,
        mapred_queue_priority=self.mapred_queue_priority,
        mapred_job_name=self.mapred_job_name,
    )
def test_load_df(self, mock_to_csv, mock_load_file):
    df = pd.DataFrame({"c": ["foo", "bar", "baz"]})
    table = "t"
    delimiter = ","
    encoding = "utf-8"

    hook = HiveCliHook()
    hook.load_df(df=df, table=table, delimiter=delimiter, encoding=encoding)

    assert mock_to_csv.call_count == 1
    kwargs = mock_to_csv.call_args[1]
    self.assertEqual(kwargs["header"], False)
    self.assertEqual(kwargs["index"], False)
    self.assertEqual(kwargs["sep"], delimiter)

    assert mock_load_file.call_count == 1
    kwargs = mock_load_file.call_args[1]
    self.assertEqual(kwargs["delimiter"], delimiter)
    self.assertEqual(kwargs["field_dict"], {"c": "STRING"})
    self.assertTrue(isinstance(kwargs["field_dict"], OrderedDict))
    self.assertEqual(kwargs["table"], table)
def execute(self, context: Dict[str, Any]) -> None:
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.log.info("Extracting data from Hive")
    hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
    sql = self.sql.strip().strip(';')
    tblproperties = ''.join([
        ", '{}' = '{}'".format(k, v) for k, v in self.hive_tblproperties.items()
    ])
    hql = f"""\
    SET mapred.output.compress=false;
    SET hive.exec.compress.output=false;
    DROP TABLE IF EXISTS {hive_table};
    CREATE TABLE {hive_table}
    ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
    STORED AS TEXTFILE
    TBLPROPERTIES ('serialization.null.format' = ''{tblproperties})
    AS
    {sql}
    """
    self.log.info("Running command:\n %s", hql)
    hive.run_cli(hql)

    meta_hook = HiveMetastoreHook(self.metastore_conn_id)

    # Get the Hive table and extract the columns
    table = meta_hook.get_table(hive_table)
    columns = [col.name for col in table.sd.cols]

    # Get the path on hdfs
    static_path = meta_hook.get_table(hive_table).sd.location

    druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)

    try:
        index_spec = self.construct_ingest_query(
            static_path=static_path,
            columns=columns,
        )

        self.log.info("Inserting rows into Druid, hdfs path: %s", static_path)

        druid.submit_indexing_job(index_spec)

        self.log.info("Load seems to have succeeded!")
    finally:
        self.log.info("Cleaning up by dropping the temp Hive table %s", hive_table)
        hql = "DROP TABLE IF EXISTS {}".format(hive_table)
        hive.run_cli(hql)
def execute(self, context: 'Context'):
    # Downloading file from S3
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
    hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.log.info("Downloading S3 file")

    if self.wildcard_match:
        if not s3_hook.check_for_wildcard_key(self.s3_key):
            raise AirflowException(f"No key matches {self.s3_key}")
        s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
    else:
        if not s3_hook.check_for_key(self.s3_key):
            raise AirflowException(f"The key {self.s3_key} does not exist")
        s3_key_object = s3_hook.get_key(self.s3_key)

    _, file_ext = os.path.splitext(s3_key_object.key)
    if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
        raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")

    with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, NamedTemporaryFile(
        mode="wb", dir=tmp_dir, suffix=file_ext
    ) as f:
        self.log.info("Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name)
        if self.select_expression:
            option = {}
            if self.headers:
                option['FileHeaderInfo'] = 'USE'
            if self.delimiter:
                option['FieldDelimiter'] = self.delimiter

            input_serialization = {'CSV': option}
            if self.input_compressed:
                input_serialization['CompressionType'] = 'GZIP'

            content = s3_hook.select_key(
                bucket_name=s3_key_object.bucket_name,
                key=s3_key_object.key,
                expression=self.select_expression,
                input_serialization=input_serialization,
            )
            f.write(content.encode("utf-8"))
        else:
            s3_key_object.download_fileobj(f)
        f.flush()

        if self.select_expression or not self.headers:
            self.log.info("Loading file %s into Hive", f.name)
            hive_hook.load_file(
                f.name,
                self.hive_table,
                field_dict=self.field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties,
            )
        else:
            # Decompressing file
            if self.input_compressed:
                self.log.info("Uncompressing file %s", f.name)
                fn_uncompressed = uncompress_file(f.name, file_ext, tmp_dir)
                self.log.info("Uncompressed to %s", fn_uncompressed)
                # uncompressed file available now so deleting
                # compressed file to save disk space
                f.close()
            else:
                fn_uncompressed = f.name

            # Testing if header matches field_dict
            if self.check_headers:
                self.log.info("Matching file header against field_dict")
                header_list = self._get_top_row_as_list(fn_uncompressed)
                if not self._match_headers(header_list):
                    raise AirflowException("Header check failed")

            # Deleting top header row
            self.log.info("Removing header from file %s", fn_uncompressed)
            headless_file = self._delete_top_row_and_compress(fn_uncompressed, file_ext, tmp_dir)
            self.log.info("Headless file %s", headless_file)
            self.log.info("Loading file %s into Hive", headless_file)
            hive_hook.load_file(
                headless_file,
                self.hive_table,
                field_dict=self.field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties,
            )
def get_hook(self):
    if self.conn_type == 'mysql':
        from airflow.providers.mysql.hooks.mysql import MySqlHook
        return MySqlHook(mysql_conn_id=self.conn_id)
    elif self.conn_type == 'google_cloud_platform':
        from airflow.gcp.hooks.bigquery import BigQueryHook
        return BigQueryHook(bigquery_conn_id=self.conn_id)
    elif self.conn_type == 'postgres':
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=self.conn_id)
    elif self.conn_type == 'pig_cli':
        from airflow.providers.apache.pig.hooks.pig import PigCliHook
        return PigCliHook(pig_cli_conn_id=self.conn_id)
    elif self.conn_type == 'hive_cli':
        from airflow.providers.apache.hive.hooks.hive import HiveCliHook
        return HiveCliHook(hive_cli_conn_id=self.conn_id)
    elif self.conn_type == 'presto':
        from airflow.providers.presto.hooks.presto import PrestoHook
        return PrestoHook(presto_conn_id=self.conn_id)
    elif self.conn_type == 'hiveserver2':
        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
        return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
    elif self.conn_type == 'sqlite':
        from airflow.providers.sqlite.hooks.sqlite import SqliteHook
        return SqliteHook(sqlite_conn_id=self.conn_id)
    elif self.conn_type == 'jdbc':
        from airflow.providers.jdbc.hooks.jdbc import JdbcHook
        return JdbcHook(jdbc_conn_id=self.conn_id)
    elif self.conn_type == 'mssql':
        from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
        return MsSqlHook(mssql_conn_id=self.conn_id)
    elif self.conn_type == 'odbc':
        from airflow.providers.odbc.hooks.odbc import OdbcHook
        return OdbcHook(odbc_conn_id=self.conn_id)
    elif self.conn_type == 'oracle':
        from airflow.providers.oracle.hooks.oracle import OracleHook
        return OracleHook(oracle_conn_id=self.conn_id)
    elif self.conn_type == 'vertica':
        from airflow.providers.vertica.hooks.vertica import VerticaHook
        return VerticaHook(vertica_conn_id=self.conn_id)
    elif self.conn_type == 'cloudant':
        from airflow.providers.cloudant.hooks.cloudant import CloudantHook
        return CloudantHook(cloudant_conn_id=self.conn_id)
    elif self.conn_type == 'jira':
        from airflow.providers.jira.hooks.jira import JiraHook
        return JiraHook(jira_conn_id=self.conn_id)
    elif self.conn_type == 'redis':
        from airflow.providers.redis.hooks.redis import RedisHook
        return RedisHook(redis_conn_id=self.conn_id)
    elif self.conn_type == 'wasb':
        from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
        return WasbHook(wasb_conn_id=self.conn_id)
    elif self.conn_type == 'docker':
        from airflow.providers.docker.hooks.docker import DockerHook
        return DockerHook(docker_conn_id=self.conn_id)
    elif self.conn_type == 'azure_data_lake':
        from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
        return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
    elif self.conn_type == 'azure_cosmos':
        from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
    elif self.conn_type == 'cassandra':
        from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
        return CassandraHook(cassandra_conn_id=self.conn_id)
    elif self.conn_type == 'mongo':
        from airflow.providers.mongo.hooks.mongo import MongoHook
        return MongoHook(conn_id=self.conn_id)
    elif self.conn_type == 'gcpcloudsql':
        from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
        return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
    elif self.conn_type == 'grpc':
        from airflow.providers.grpc.hooks.grpc import GrpcHook
        return GrpcHook(grpc_conn_id=self.conn_id)
    raise AirflowException("Unknown hook type {}".format(self.conn_type))
def test_run_cli(self):
    hook = HiveCliHook()
    hook.run_cli("SHOW DATABASES")
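# Outside the test suite, run_cli is used the same way: build the hook from a Hive CLI
# connection and hand it an HQL string, optionally with extra Hive configuration, which
# the first test above exercises via hive_conf. A minimal sketch, assuming a working
# 'hive_cli_default' connection and a local hive/beeline binary (the conn_id is illustrative):
from airflow.providers.apache.hive.hooks.hive import HiveCliHook

hook = HiveCliHook(hive_cli_conn_id='hive_cli_default')
# hive_conf entries are passed to the CLI as configuration key/value pairs,
# so `set key;` in the script echoes the injected value back in the output.
output = hook.run_cli(hql="set key;\nSHOW DATABASES;", hive_conf={'key': 'value'})
print(output)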