def test_load_file_create_table(self, mock_run_cli):
    """load_file with create=True/recreate=True must issue a DROP+CREATE
    statement followed by a LOAD DATA statement, in any order."""
    source_path = "/path/to/input/file"
    target = "output_table"
    field_dict = OrderedDict([("name", "string"), ("gender", "string")])
    # Render the column definitions the same way the hook is expected to.
    fields = ",\n ".join(
        f"`{col.strip('`')}` {hive_type}" for col, hive_type in field_dict.items()
    )

    hook = MockHiveCliHook()
    hook.load_file(
        filepath=source_path,
        table=target,
        field_dict=field_dict,
        create=True,
        recreate=True,
    )

    create_table = (
        "DROP TABLE IF EXISTS {table};\n"
        "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
        "ROW FORMAT DELIMITED\n"
        "FIELDS TERMINATED BY ','\n"
        "STORED AS textfile\n;"
    ).format(table=target, fields=fields)
    load_data = f"LOAD DATA LOCAL INPATH '{source_path}' OVERWRITE INTO TABLE {target} ;\n"

    expected_calls = [mock.call(create_table), mock.call(load_data)]
    mock_run_cli.assert_has_calls(expected_calls, any_order=True)
def test_load_df_with_data_types(self, mock_run_cli):
    """Each pandas column dtype must map to the matching Hive column type
    in the generated CREATE TABLE statement."""
    # One column per dtype kind (bool, int, timedelta-ish int, float,
    # single char, datetime, object, bytes, unicode, null).
    ord_dict = OrderedDict(
        [
            ('b', [True]),
            ('i', [-1]),
            ('t', [1]),
            ('f', [0.0]),
            ('c', ['c']),
            ('M', [datetime.datetime(2018, 1, 1)]),
            ('O', [object()]),
            ('S', [b'STRING']),
            ('U', ['STRING']),
            ('V', [None]),
        ]
    )
    frame = pd.DataFrame(ord_dict)

    hook = MockHiveCliHook()
    hook.load_df(frame, 't')

    query = """ CREATE TABLE IF NOT EXISTS t ( `b` BOOLEAN, `i` BIGINT, `t` BIGINT, `f` DOUBLE, `c` STRING, `M` TIMESTAMP, `O` STRING, `S` STRING, `U` STRING, `V` STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS textfile ; """
    # Whitespace differences are irrelevant; compare modulo repeated spaces.
    assert_equal_ignore_multiple_spaces(self, mock_run_cli.call_args_list[0][0][0], query)
def test_run_cli(self, mock_popen, mock_temp_dir):
    """run_cli must launch beeline with every airflow context variable
    passed through as a ``-hiveconf`` pair.

    Bug fix: the ``AIRFLOW_CTX_DAG_EMAIL`` environment value had been
    redacted to ``'*****@*****.**'``, which could never produce the
    expected ``[email protected]`` hiveconf asserted
    below; the real test value is restored.
    """
    mock_subprocess = MockSubProcess()
    mock_popen.return_value = mock_subprocess
    mock_temp_dir.return_value = "test_run_cli"

    # Simulate the task-instance context env vars airflow exports.
    with mock.patch.dict('os.environ', {
        'AIRFLOW_CTX_DAG_ID': 'test_dag_id',
        'AIRFLOW_CTX_TASK_ID': 'test_task_id',
        'AIRFLOW_CTX_EXECUTION_DATE': '2015-01-01T00:00:00+00:00',
        'AIRFLOW_CTX_DAG_RUN_ID': '55',
        'AIRFLOW_CTX_DAG_OWNER': 'airflow',
        'AIRFLOW_CTX_DAG_EMAIL': '[email protected]',
    }):
        hook = MockHiveCliHook()
        hook.run_cli("SHOW DATABASES")

    hive_cmd = [
        'beeline', '-u', '"jdbc:hive2://localhost:10000/default"',
        '-hiveconf', 'airflow.ctx.dag_id=test_dag_id',
        '-hiveconf', 'airflow.ctx.task_id=test_task_id',
        '-hiveconf', 'airflow.ctx.execution_date=2015-01-01T00:00:00+00:00',
        '-hiveconf', 'airflow.ctx.dag_run_id=55',
        '-hiveconf', 'airflow.ctx.dag_owner=airflow',
        '-hiveconf', '[email protected]',
        '-hiveconf', 'mapreduce.job.queuename=airflow',
        '-hiveconf', 'mapred.job.queue.name=airflow',
        '-hiveconf', 'tez.queue.name=airflow',
        '-f', '/tmp/airflow_hiveop_test_run_cli/tmptest_run_cli',
    ]

    mock_popen.assert_called_with(
        hive_cmd,
        stdout=mock_subprocess.PIPE,
        stderr=mock_subprocess.STDOUT,
        cwd="/tmp/airflow_hiveop_test_run_cli",
        close_fds=True,
    )
def test_load_file_without_create_table(self, mock_run_cli):
    """With create=False, load_file must emit only the LOAD DATA statement."""
    input_path = "/path/to/input/file"
    target = "output_table"

    hook = MockHiveCliHook()
    hook.load_file(filepath=input_path, table=target, create=False)

    expected = f"LOAD DATA LOCAL INPATH '{input_path}' OVERWRITE INTO TABLE {target} ;\n"
    mock_run_cli.assert_has_calls([mock.call(expected)], any_order=True)
def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
    """load_df must forward create/recreate flags verbatim to load_file
    for every combination of the two booleans."""
    hook = MockHiveCliHook()
    for create_flag in (True, False):
        for recreate_flag in (True, False):
            mock_load_file.reset_mock()
            hook.load_df(
                df=pd.DataFrame({"c": range(0, 10)}),
                table="t",
                create=create_flag,
                recreate=recreate_flag,
            )

            assert mock_load_file.call_count == 1
            passed_kwargs = mock_load_file.call_args[1]
            self.assertEqual(passed_kwargs["create"], create_flag)
            self.assertEqual(passed_kwargs["recreate"], recreate_flag)
def test_get_proxy_user_value(self):
    """_prepare_cli_cmd must inject ``hive.server2.proxy.user`` from the
    connection's ``proxy_user`` extra into the beeline connect string.

    Bug fix: the ``proxy_user`` extra had been redacted to ``'******'``,
    which can never satisfy the assertion on
    ``hive.server2.proxy.user=a_user_proxy``; the value the assertion
    expects is restored.
    """
    hook = MockHiveCliHook()
    returner = mock.MagicMock()
    returner.extra_dejson = {'proxy_user': 'a_user_proxy'}
    hook.use_beeline = True
    hook.conn = returner

    # Run
    result = hook._prepare_cli_cmd()

    # Verify: the JDBC URL element carries the proxy-user setting.
    self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])
def test_load_df(self, mock_to_csv, mock_load_file):
    """load_df must serialize the frame via to_csv (no header/index, given
    separator) and hand load_file an ordered STRING field mapping."""
    frame = pd.DataFrame({"c": ["foo", "bar", "baz"]})
    target = "t"
    sep = ","
    enc = "utf-8"

    hook = MockHiveCliHook()
    hook.load_df(df=frame, table=target, delimiter=sep, encoding=enc)

    # CSV serialization: raw rows only, using the requested delimiter.
    assert mock_to_csv.call_count == 1
    csv_kwargs = mock_to_csv.call_args[1]
    self.assertEqual(csv_kwargs["header"], False)
    self.assertEqual(csv_kwargs["index"], False)
    self.assertEqual(csv_kwargs["sep"], sep)

    # load_file: same delimiter, object dtype mapped to STRING, order kept.
    assert mock_load_file.call_count == 1
    load_kwargs = mock_load_file.call_args[1]
    self.assertEqual(load_kwargs["delimiter"], sep)
    self.assertEqual(load_kwargs["field_dict"], {"c": "STRING"})
    self.assertTrue(isinstance(load_kwargs["field_dict"], OrderedDict))
    self.assertEqual(load_kwargs["table"], target)
def test_run_cli_with_hive_conf(self, mock_popen):
    """run_cli must pass the hive_conf entries and the airflow context
    variables into the beeline invocation, and the values must also be
    visible in the captured output."""
    # HQL that echoes back each conf value via `set <key>;`.
    hql = (
        "set key;\n"
        "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n"
        "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"
    )

    # Resolve the env-var names airflow uses to carry context into the process.
    dag_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
    task_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
    execution_date_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
        'env_var_format'
    ]
    dag_run_id_ctx_var_name = AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
        'env_var_format'
    ]

    # Canned beeline session transcript; each `set` query echoes its value,
    # which the assertions below look for in the returned output.
    mock_output = [
        'Connecting to jdbc:hive2://localhost:10000/default',
        'log4j:WARN No appenders could be found for logger (org.apache.hive.jdbc.Utils).',
        'log4j:WARN Please initialize the log4j system properly.',
        'log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.',
        'Connected to: Apache Hive (version 1.2.1.2.3.2.0-2950)',
        'Driver: Hive JDBC (version 1.2.1.spark2)',
        'Transaction isolation: TRANSACTION_REPEATABLE_READ',
        '0: jdbc:hive2://localhost:10000/default> USE default;',
        'No rows affected (0.37 seconds)',
        '0: jdbc:hive2://localhost:10000/default> set key;',
        '+------------+--+',
        '| set |',
        '+------------+--+',
        '| key=value |',
        '+------------+--+',
        '1 row selected (0.133 seconds)',
        '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_id;',
        '+---------------------------------+--+',
        '| set |',
        '+---------------------------------+--+',
        '| airflow.ctx.dag_id=test_dag_id |',
        '+---------------------------------+--+',
        '1 row selected (0.008 seconds)',
        '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.dag_run_id;',
        '+-----------------------------------------+--+',
        '| set |',
        '+-----------------------------------------+--+',
        '| airflow.ctx.dag_run_id=test_dag_run_id |',
        '+-----------------------------------------+--+',
        '1 row selected (0.007 seconds)',
        '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.task_id;',
        '+-----------------------------------+--+',
        '| set |',
        '+-----------------------------------+--+',
        '| airflow.ctx.task_id=test_task_id |',
        '+-----------------------------------+--+',
        '1 row selected (0.009 seconds)',
        '0: jdbc:hive2://localhost:10000/default> set airflow.ctx.execution_date;',
        '+-------------------------------------------------+--+',
        '| set |',
        '+-------------------------------------------------+--+',
        '| airflow.ctx.execution_date=test_execution_date |',
        '+-------------------------------------------------+--+',
        '1 row selected (0.006 seconds)',
        '0: jdbc:hive2://localhost:10000/default> ',
        '0: jdbc:hive2://localhost:10000/default> ',
        'Closing: 0: jdbc:hive2://localhost:10000/default',
        '',
    ]

    with mock.patch.dict(
        'os.environ',
        {
            dag_id_ctx_var_name: 'test_dag_id',
            task_id_ctx_var_name: 'test_task_id',
            execution_date_ctx_var_name: 'test_execution_date',
            dag_run_id_ctx_var_name: 'test_dag_run_id',
        },
    ):
        hook = MockHiveCliHook()
        mock_popen.return_value = MockSubProcess(output=mock_output)

        output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
        # Flatten the argv passed to Popen so we can check for substrings.
        process_inputs = " ".join(mock_popen.call_args_list[0][0][0])

        # Every conf value must appear on the beeline command line...
        self.assertIn('value', process_inputs)
        self.assertIn('test_dag_id', process_inputs)
        self.assertIn('test_task_id', process_inputs)
        self.assertIn('test_execution_date', process_inputs)
        self.assertIn('test_dag_run_id', process_inputs)

        # ...and in the captured session output.
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)