def test_check_for_named_partition(self):
    # Check for existing partition.
    partition = "{p_by}={date}".format(date=DEFAULT_DATE_DS, p_by=self.partition_by)

    self.hook.metastore.__enter__().check_for_named_partition = MagicMock(return_value=True)

    self.assertTrue(
        self.hook.check_for_named_partition(self.database, self.table, partition))
    self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
        self.database, self.table, partition)

    # Check for non-existent partition.
    missing_partition = "{p_by}={date}".format(date=self.next_day, p_by=self.partition_by)

    self.hook.metastore.__enter__().check_for_named_partition = MagicMock(return_value=False)

    self.assertFalse(
        self.hook.check_for_named_partition(self.database, self.table, missing_partition))
    self.hook.metastore.__enter__().check_for_named_partition.assert_called_with(
        self.database, self.table, missing_partition)
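# A minimal sketch (illustrative only, not part of the suite) of the
# metastore.__enter__() idiom used throughout these tests: the hook uses the
# metastore as a context manager, and MagicMock returns the *same* object
# from every __enter__() call, so attributes patched on it are the ones the
# `with` block hands back.
from unittest.mock import MagicMock

cm = MagicMock()
cm.__enter__().ping = MagicMock(return_value='pong')
with cm as client:
    assert client.ping() == 'pong'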
def test_dynamodb_to_s3_success(self, mock_aws_dynamodb_hook, mock_s3_hook):
    responses = [
        {
            'Items': [{'a': 1}, {'b': 2}],
            'LastEvaluatedKey': '123',
        },
        {
            'Items': [{'c': 3}],
        },
    ]
    table = MagicMock()
    table.return_value.scan.side_effect = responses
    mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table

    s3_client = MagicMock()
    s3_client.return_value.upload_file = self.mock_upload_file
    mock_s3_hook.return_value.get_conn = s3_client

    dynamodb_to_s3_operator = DynamoDBToS3Operator(
        task_id='dynamodb_to_s3',
        dynamodb_table_name='airflow_rocks',
        s3_bucket_name='airflow-bucket',
        file_size=4000,
    )

    dynamodb_to_s3_operator.execute(context={})

    self.assertEqual([{'a': 1}, {'b': 2}, {'c': 3}], self.output_queue)
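# A self-contained sketch (not from the suite above) of the side_effect-list
# behavior the scan mock relies on: each call to the mock returns the next
# element, which is how two "pages" of a DynamoDB scan are simulated, with
# only the first page carrying a LastEvaluatedKey.
from unittest.mock import MagicMock

scan = MagicMock(side_effect=[
    {'Items': [{'a': 1}], 'LastEvaluatedKey': '123'},  # first page: more to read
    {'Items': [{'b': 2}]},                             # last page: no key
])
first, second = scan(), scan()
assert 'LastEvaluatedKey' in first and 'LastEvaluatedKey' not in second
# A third scan() would raise StopIteration once the list is exhausted.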
def test_check_for_partition(self):
    # Check for existent partition.
    FakePartition = namedtuple('FakePartition', ['values'])
    fake_partition = FakePartition(['2015-01-01'])

    metastore = self.hook.metastore.__enter__()

    partition = "{p_by}='{date}'".format(date=DEFAULT_DATE_DS, p_by=self.partition_by)

    metastore.get_partitions_by_filter = MagicMock(return_value=[fake_partition])

    self.assertTrue(
        self.hook.check_for_partition(self.database, self.table, partition))

    # assert_called_with (not a bare call, which would only record another
    # call) verifies the hook forwarded the filter; the trailing 1 is the
    # partition-count limit.
    metastore.get_partitions_by_filter.assert_called_with(
        self.database, self.table, partition, 1)

    # Check for non-existent partition.
    missing_partition = "{p_by}='{date}'".format(date=self.next_day, p_by=self.partition_by)

    metastore.get_partitions_by_filter = MagicMock(return_value=[])

    self.assertFalse(
        self.hook.check_for_partition(self.database, self.table, missing_partition))

    metastore.get_partitions_by_filter.assert_called_with(
        self.database, self.table, missing_partition, 1)
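# Sketch (illustrative only): calling a MagicMock merely *records* a call,
# so a bare metastore.get_partitions_by_filter(...) line can never fail;
# assert_called_with is what actually checks the recorded arguments.
from unittest.mock import MagicMock

m = MagicMock()
m('right')                      # records the call, always "passes"
m.assert_called_with('right')   # passes
try:
    m.assert_called_with('wrong')
except AssertionError:
    pass                        # the check a bare call silently skips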
def __init__(self, *args, **kwargs):
    super(MockHiveServer2Hook, self).__init__()
    self.mock_cursor = kwargs.get('connection_cursor', MockConnectionCursor())
    self.mock_cursor.execute = MagicMock()
    self.get_conn = MagicMock(return_value=self.mock_cursor)
    self.get_connection = MagicMock(return_value=MockDBConnection({}))
def test_max_partition(self):
    FakeFieldSchema = namedtuple('FakeFieldSchema', ['name'])
    fake_schema = FakeFieldSchema('ds')
    FakeTable = namedtuple('FakeTable', ['partitionKeys'])
    fake_table = FakeTable([fake_schema])

    metastore = self.hook.metastore.__enter__()
    metastore.get_table = MagicMock(return_value=fake_table)
    metastore.get_partition_names = MagicMock(return_value=['ds=2015-01-01'])
    metastore.partition_name_to_spec = MagicMock(return_value={'ds': '2015-01-01'})

    filter_map = {self.partition_by: DEFAULT_DATE_DS}
    partition = self.hook.max_partition(
        schema=self.database,
        table_name=self.table,
        field=self.partition_by,
        filter_map=filter_map)
    self.assertEqual(partition, DEFAULT_DATE_DS)

    metastore.get_table.assert_called_with(dbname=self.database, tbl_name=self.table)
    metastore.get_partition_names.assert_called_with(
        self.database, self.table, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
    metastore.partition_name_to_spec.assert_called_with('ds=2015-01-01')
def __init__(self, extra_dejson=None, *args, **kwargs):
    self.extra_dejson = extra_dejson
    self.get_records = MagicMock(return_value=[['test_record']])

    output = kwargs.get('output', ['' for _ in range(10)])
    self.readline = MagicMock(side_effect=[line.encode() for line in output])
def test_hive_to_mysql(self):
    test_hive_results = 'test_hive_results'

    mock_hive_hook = MockHiveServer2Hook()
    mock_hive_hook.get_records = MagicMock(return_value=test_hive_results)

    mock_mysql_hook = MockMySqlHook()
    mock_mysql_hook.run = MagicMock()
    mock_mysql_hook.insert_rows = MagicMock()

    with patch('airflow.operators.hive_to_mysql.HiveServer2Hook',
               return_value=mock_hive_hook):
        with patch('airflow.operators.hive_to_mysql.MySqlHook',
                   return_value=mock_mysql_hook):
            op = HiveToMySqlTransfer(
                mysql_conn_id='airflow_db',
                task_id='hive_to_mysql_check',
                sql="""
                    SELECT name
                    FROM airflow.static_babynames
                    LIMIT 100
                    """,
                mysql_table='test_static_babynames',
                mysql_preoperator=[
                    'DROP TABLE IF EXISTS test_static_babynames;',
                    'CREATE TABLE test_static_babynames (name VARCHAR(500))',
                ],
                dag=self.dag)
            op.clear(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                   ignore_ti_state=True)

    # call_args_list[0][0] holds the positional args of the first call;
    # [0][1] holds its keyword args.
    raw_select_name_query = mock_hive_hook.get_records.call_args_list[0][0][0]
    actual_select_name_query = re.sub(r'\s{2,}', ' ', raw_select_name_query).strip()
    expected_select_name_query = 'SELECT name FROM airflow.static_babynames LIMIT 100'
    self.assertEqual(expected_select_name_query, actual_select_name_query)

    actual_hive_conf = mock_hive_hook.get_records.call_args_list[0][1]['hive_conf']
    expected_hive_conf = {
        'airflow.ctx.dag_owner': 'airflow',
        'airflow.ctx.dag_id': 'test_dag_id',
        'airflow.ctx.task_id': 'hive_to_mysql_check',
        'airflow.ctx.execution_date': '2015-01-01T00:00:00+00:00',
    }
    self.assertEqual(expected_hive_conf, actual_hive_conf)

    expected_mysql_preoperator = [
        'DROP TABLE IF EXISTS test_static_babynames;',
        'CREATE TABLE test_static_babynames (name VARCHAR(500))',
    ]
    mock_mysql_hook.run.assert_called_with(expected_mysql_preoperator)

    mock_mysql_hook.insert_rows.assert_called_with(
        table='test_static_babynames', rows=test_hive_results)
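# A self-contained sketch (names invented) of the patching pattern above:
# patch() replaces the class *as looked up in the operator's module*, and
# return_value= makes every "construction" hand back our prebuilt mock.
from unittest.mock import MagicMock, patch

class Hook:                       # stand-in for HiveServer2Hook
    def get_records(self):
        return 'real'

fake = MagicMock()
fake.get_records.return_value = 'fake'

with patch('__main__.Hook', return_value=fake):
    assert Hook().get_records() == 'fake'   # code under test sees the mock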
def __init__(self, *args, **kwargs):
    self.conn = MockConnectionCursor()
    self.conn.execute = MagicMock()
    self.get_conn = MagicMock(return_value=self.conn)
    self.get_first = MagicMock(return_value=[['val_0', 'val_1'], 'val_2'])
    super(MockPrestoHook, self).__init__(*args, **kwargs)
def __init__(self, *args, **kwargs):
    self.conn = MockConnectionCursor()
    self.conn.execute = MagicMock()
    self.get_conn = MagicMock(return_value=self.conn)
    self.get_records = MagicMock(return_value=[])
    self.insert_rows = MagicMock(return_value=True)
    super(MockMySqlHook, self).__init__(*args, **kwargs)
def test_sync(self, run_task_mock):
    run_task_mock.return_value = True

    executor = DebugExecutor()
    ti1 = MagicMock(key="t1")
    ti2 = MagicMock(key="t2")
    executor.tasks_to_run = [ti1, ti2]

    executor.sync()

    assert not executor.tasks_to_run
    run_task_mock.assert_has_calls([mock.call(ti1), mock.call(ti2)])
def __init__(self, *args, **kwargs):
    super(MockHiveCliHook, self).__init__()
    self.conn = MockConnectionCursor()
    self.conn.schema = 'default'
    self.conn.host = 'localhost'
    self.conn.port = 10000
    self.conn.login = None
    self.conn.password = None

    self.conn.execute = MagicMock()
    self.get_conn = MagicMock(return_value=self.conn)
    self.get_connection = MagicMock(return_value=MockDBConnection({}))
def test_get_proxy_user_value(self):
    hook = MockHiveCliHook()
    returner = MagicMock()
    returner.extra_dejson = {'proxy_user': 'a_user_proxy'}  # must match the assertion below
    hook.use_beeline = True
    hook.conn = returner

    # Run
    result = hook._prepare_cli_cmd()

    # Verify
    self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])
def test_table_exists(self):
    # Test with existent table.
    self.hook.metastore.__enter__().get_table = MagicMock(return_value=True)

    self.assertTrue(self.hook.table_exists(self.table, db=self.database))
    self.hook.metastore.__enter__().get_table.assert_called_with(
        dbname='airflow', tbl_name='static_babynames_partitioned')

    # Test with non-existent table.
    self.hook.metastore.__enter__().get_table = MagicMock(side_effect=Exception())

    self.assertFalse(self.hook.table_exists("does-not-exist"))
def test_write_temp_file(self):
    task_id = "some_test_id"
    sql = "some_sql"
    sql_params = {':p_data': "2018-01-01"}
    oracle_conn_id = "oracle_conn_id"
    filename = "some_filename"
    azure_data_lake_conn_id = 'azure_data_lake_conn_id'
    azure_data_lake_path = 'azure_data_lake_path'
    delimiter = '|'
    encoding = 'utf-8'
    cursor_description = [
        ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
        ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
    ]
    cursor_rows = [[1, 'description 1'], [2, 'description 2']]

    mock_cursor = MagicMock()
    mock_cursor.description = cursor_description
    mock_cursor.__iter__.return_value = cursor_rows

    op = OracleToAzureDataLakeTransfer(
        task_id=task_id,
        filename=filename,
        oracle_conn_id=oracle_conn_id,
        sql=sql,
        sql_params=sql_params,
        azure_data_lake_conn_id=azure_data_lake_conn_id,
        azure_data_lake_path=azure_data_lake_path,
        delimiter=delimiter,
        encoding=encoding)

    with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp:
        op._write_temp_file(mock_cursor, os.path.join(temp, filename))

        assert os.path.exists(os.path.join(temp, filename))

        # `csv` is assumed to be `unicodecsv` (imported as csv), which,
        # unlike the stdlib module, reads binary files and takes an
        # encoding argument.
        with open(os.path.join(temp, filename), 'rb') as csvfile:
            temp_file = csv.reader(csvfile, delimiter=delimiter, encoding=encoding)

            rownum = 0
            for row in temp_file:
                if rownum == 0:
                    self.assertEqual(row[0], 'id')
                    self.assertEqual(row[1], 'description')
                else:
                    self.assertEqual(row[0], str(cursor_rows[rownum - 1][0]))
                    self.assertEqual(row[1], cursor_rows[rownum - 1][1])
                rownum += 1
def test_trigger_tasks(self):
    execute_async_mock = MagicMock()
    executor = DebugExecutor()
    executor.execute_async = execute_async_mock
    executor.queued_tasks = {
        "t1": (None, 1, None, MagicMock(key="t1")),
        "t2": (None, 2, None, MagicMock(key="t2")),
    }

    executor.trigger_tasks(open_slots=4)

    assert not executor.queued_tasks
    assert len(executor.running) == 2
    assert len(executor.tasks_to_run) == 2
    # DebugExecutor runs tasks in-process during sync(), so execute_async
    # must never be used.
    assert not execute_async_mock.called
def test_fake_socket_passes_through_fileno():
    import socket
    with httpretty.enabled():
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.truesock = MagicMock()

        expect(s.fileno).called_with().should_not.throw(AttributeError)

        s.truesock.fileno.assert_called_with()
def setUp(self):
    self._upload_dataframe()
    args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
    self.dag = DAG('test_dag_id', default_args=args)
    self.database = 'airflow'
    self.table = 'hive_server_hook'

    self.hql = """
        CREATE DATABASE IF NOT EXISTS {{ params.database }};
        USE {{ params.database }};
        DROP TABLE IF EXISTS {{ params.table }};
        CREATE TABLE IF NOT EXISTS {{ params.table }} (
            a int,
            b int)
        ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ',';
        LOAD DATA LOCAL INPATH '{{ params.csv_path }}'
        OVERWRITE INTO TABLE {{ params.table }};
        """
    self.columns = ['{}.a'.format(self.table), '{}.b'.format(self.table)]

    with patch('airflow.hooks.hive_hooks.HiveMetastoreHook.get_metastore_client') \
            as get_metastore_mock:
        get_metastore_mock.return_value = MagicMock()

        self.hook = HiveMetastoreHook()
def test_execute_bad_type(self, mock_hook):
    operator = BigQueryOperator(
        task_id=TASK_ID,
        sql=1,
        destination_dataset_table=None,
        write_disposition='WRITE_EMPTY',
        allow_large_results=False,
        flatten_results=None,
        bigquery_conn_id='google_cloud_default',
        udf_config=None,
        use_legacy_sql=True,
        maximum_billing_tier=None,
        maximum_bytes_billed=None,
        create_disposition='CREATE_IF_NEEDED',
        schema_update_options=(),
        query_params=None,
        labels=None,
        priority='INTERACTIVE',
        time_partitioning=None,
        api_resource_configs=None,
        cluster_fields=None,
    )

    with self.assertRaises(AirflowException):
        operator.execute(MagicMock())
def test_bigquery_operator_defaults(self, mock_hook):
    operator = BigQueryOperator(
        task_id=TASK_ID,
        sql='Select * from test_table',
        dag=self.dag,
        default_args=self.args,
    )

    operator.execute(MagicMock())
    mock_hook.return_value \
        .get_conn.return_value \
        .cursor.return_value \
        .run_query \
        .assert_called_once_with(
            sql='Select * from test_table',
            destination_dataset_table=None,
            write_disposition='WRITE_EMPTY',
            allow_large_results=False,
            flatten_results=None,
            udf_config=None,
            maximum_billing_tier=None,
            maximum_bytes_billed=None,
            create_disposition='CREATE_IF_NEEDED',
            schema_update_options=(),
            query_params=None,
            labels=None,
            priority='INTERACTIVE',
            time_partitioning=None,
            api_resource_configs=None,
            cluster_fields=None,
        )

    self.assertTrue(isinstance(operator.sql, six.string_types))

    ti = TaskInstance(task=operator, execution_date=DEFAULT_DATE)
    ti.render_templates()
    self.assertTrue(isinstance(ti.task.sql, six.string_types))
def test_get_conn(self):
    with patch('airflow.hooks.hive_hooks.HiveMetastoreHook._find_valid_server') \
            as find_valid_server:
        find_valid_server.return_value = MagicMock(return_value={})
        metastore_hook = HiveMetastoreHook()

        self.assertIsInstance(metastore_hook.get_conn(), HMSClient)
def setUpClass(cls):
    from tests.compat import MagicMock
    from airflow.jobs import SchedulerJob

    cls.dag = DAG(
        'test_dag',
        default_args={
            'owner': 'airflow',
            'start_date': DEFAULT_DATE},
        schedule_interval=INTERVAL)
    cls.dag.create_dagrun(
        run_id="manual__1",
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )
    cls.dag.create_dagrun(
        run_id="manual__2",
        execution_date=timezone.datetime(2016, 1, 1, 12),
        state=State.RUNNING,
    )
    cls.dag.create_dagrun(
        run_id="manual__3",
        execution_date=END_DATE,
        state=State.RUNNING,
    )

    cls.dag_file_processor = SchedulerJob(dag_ids=[], log=MagicMock())
def test_pod_mutation_hook(self):
    """
    Tests that pods are mutated by the pod_mutation_hook function
    in airflow_local_settings.
    """
    with SettingsContext(SETTINGS_FILE_POD_MUTATION_HOOK, "airflow_local_settings"):
        from airflow import settings
        settings.import_local_settings()

        pod = MagicMock()
        pod.volumes = []
        settings.pod_mutation_hook(pod)

        assert pod.namespace == 'airflow-tests'
        self.assertEqual(pod.volumes[0].name, "bar")
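# A hedged sketch (contents assumed; the real file lives at
# SETTINGS_FILE_POD_MUTATION_HOOK and is not shown here) of the
# airflow_local_settings module this test loads: pod_mutation_hook receives
# the pod object and mutates it in place, which is what the two assertions
# above observe.
from types import SimpleNamespace

def pod_mutation_hook(pod):
    pod.namespace = 'airflow-tests'
    pod.volumes.append(SimpleNamespace(name='bar'))

# Exercising the sketch the same way the test does:
pod = SimpleNamespace(namespace=None, volumes=[])
pod_mutation_hook(pod)
assert pod.namespace == 'airflow-tests' and pod.volumes[0].name == 'bar'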
def test_fake_socket_passes_through_gettimeout():
    import socket
    HTTPretty.enable()

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.truesock = MagicMock()

    expect(s.gettimeout).called_with().should_not.throw(AttributeError)

    s.truesock.gettimeout.assert_called_with()
def test_dataflow_job_init_without_job_id(self):
    mock_jobs = MagicMock()
    self.mock_dataflow.projects.return_value.locations.return_value. \
        jobs.return_value = mock_jobs
    _DataflowJob(self.mock_dataflow, TEST_PROJECT, TEST_JOB_NAME,
                 TEST_LOCATION, 10)
    mock_jobs.list.assert_called_with(projectId=TEST_PROJECT,
                                      location=TEST_LOCATION)
def test_execute():
    oracle_destination_conn_id = 'oracle_destination_conn_id'
    destination_table = 'destination_table'
    oracle_source_conn_id = 'oracle_source_conn_id'
    source_sql = "select sysdate from dual where trunc(sysdate) = :p_data"
    source_sql_params = {':p_data': "2018-01-01"}
    rows_chunk = 5000
    cursor_description = [
        ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
        ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
    ]
    cursor_rows = [[1, 'description 1'], [2, 'description 2']]

    mock_dest_hook = MagicMock()
    mock_src_hook = MagicMock()
    mock_src_conn = mock_src_hook.get_conn.return_value.__enter__.return_value
    mock_cursor = mock_src_conn.cursor.return_value
    mock_cursor.description.__iter__.return_value = cursor_description
    mock_cursor.fetchmany.side_effect = [cursor_rows, []]

    op = OracleToOracleTransfer(
        task_id='copy_data',
        oracle_destination_conn_id=oracle_destination_conn_id,
        destination_table=destination_table,
        oracle_source_conn_id=oracle_source_conn_id,
        source_sql=source_sql,
        source_sql_params=source_sql_params,
        rows_chunk=rows_chunk)

    op._execute(mock_src_hook, mock_dest_hook, None)

    assert mock_src_hook.get_conn.called
    assert mock_src_conn.cursor.called
    mock_cursor.execute.assert_called_once_with(source_sql, source_sql_params)

    calls = [
        mock.call(rows_chunk),
        mock.call(rows_chunk),
    ]
    mock_cursor.fetchmany.assert_has_calls(calls)
    mock_dest_hook.bulk_insert_rows.assert_called_once_with(
        destination_table,
        cursor_rows,
        commit_every=rows_chunk,
        target_fields=['id', 'description'])
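# A hypothetical sketch (not the operator's actual source) of the chunked
# read loop the fetchmany mock above is priming: fetchmany(rows_chunk) is
# polled until it returns an empty list, hence side_effect=[cursor_rows, []]
# and the two expected calls.
from unittest.mock import MagicMock

def copy_rows_in_chunks(cursor, dest_hook, table, fields, rows_chunk):
    rows = cursor.fetchmany(rows_chunk)
    while rows:
        dest_hook.bulk_insert_rows(table, rows, commit_every=rows_chunk,
                                   target_fields=fields)
        rows = cursor.fetchmany(rows_chunk)

# Exercising the sketch with the same priming style as the test:
cur, dest = MagicMock(), MagicMock()
cur.fetchmany.side_effect = [[(1, 'a')], []]
copy_rows_in_chunks(cur, dest, 't', ['id', 'v'], rows_chunk=5000)
assert dest.bulk_insert_rows.call_count == 1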
def test_fake_socket_passes_through_shutdown():
    import socket
    HTTPretty.enable()

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.truesock = MagicMock()

    expect(s.shutdown).called_with(socket.SHUT_RD).should_not.throw(AttributeError)

    s.truesock.shutdown.assert_called_with(socket.SHUT_RD)
def test_fake_socket_passes_through_bind():
    import socket
    HTTPretty.enable()

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.truesock = MagicMock()

    expect(s.bind).called_with(('127.0.0.1', 1000)).should_not.throw(AttributeError)

    s.truesock.bind.assert_called_with(('127.0.0.1', 1000))
def test_get_tables(self):
    # static_babynames_partitioned
    self.hook.metastore.__enter__().get_tables = MagicMock(
        return_value=['static_babynames_partitioned'])

    self.hook.get_tables(db=self.database, pattern=self.table + "*")

    self.hook.metastore.__enter__().get_tables.assert_called_with(
        db_name='airflow', pattern='static_babynames_partitioned*')
    self.hook.metastore.__enter__().get_table_objects_by_name.assert_called_with(
        'airflow', ['static_babynames_partitioned'])
def test_dataflow_job_init_with_job_id(self):
    mock_jobs = MagicMock()
    self.mock_dataflow.projects.return_value.locations.return_value. \
        jobs.return_value = mock_jobs
    _DataflowJobsController(self.mock_dataflow, TEST_PROJECT, TEST_JOB_NAME,
                            TEST_LOCATION, 10, TEST_JOB_ID)
    mock_jobs.get.assert_called_once_with(projectId=TEST_PROJECT,
                                          location=TEST_LOCATION,
                                          jobId=TEST_JOB_ID)
def test_end(self):
    ti = MagicMock(key="ti_key")
    executor = DebugExecutor()
    executor.tasks_to_run = [ti]
    executor.running = {ti.key: mock.MagicMock()}

    executor.end()

    ti.set_state.assert_called_once_with(State.UPSTREAM_FAILED)
    assert not executor.running
def test_dataflow_wait_for_done_logging(self, mock_select, mock_popen, mock_logging):
    mock_logging.info = MagicMock()
    mock_logging.warning = MagicMock()
    mock_proc = MagicMock()
    mock_proc.stderr = MagicMock()
    mock_proc.stderr.readlines = MagicMock(return_value=['test\n', 'error\n'])
    mock_stderr_fd = MagicMock()
    mock_proc.stderr.fileno = MagicMock(return_value=mock_stderr_fd)
    mock_proc_poll = MagicMock()
    mock_select.return_value = [[mock_stderr_fd]]

    def poll_resp_error():
        mock_proc.return_code = 1
        return True

    # With side_effect given as a *list*, the second poll() returns the
    # function object itself, which is truthy and ends the wait loop.
    mock_proc_poll.side_effect = [None, poll_resp_error]
    mock_proc.poll = mock_proc_poll
    mock_popen.return_value = mock_proc

    dataflow = _Dataflow(['test', 'cmd'])
    mock_logging.info.assert_called_with('Running command: %s', 'test cmd')
    self.assertRaises(Exception, dataflow.wait_for_done)
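# Sketch (illustrative only): in a side_effect *list*, a callable element is
# returned as a value, not invoked, so poll_resp_error above never actually
# runs and merely serves as a truthy sentinel.
from unittest.mock import MagicMock

def sentinel():
    return True

poll = MagicMock(side_effect=[None, sentinel])
assert poll() is None
assert poll() is sentinel        # returned, not called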
def test_execute(self, mock_data_lake_hook, mock_oracle_hook):
    task_id = "some_test_id"
    sql = "some_sql"
    sql_params = {':p_data': "2018-01-01"}
    oracle_conn_id = "oracle_conn_id"
    filename = "some_filename"
    azure_data_lake_conn_id = 'azure_data_lake_conn_id'
    azure_data_lake_path = 'azure_data_lake_path'
    delimiter = '|'
    encoding = 'latin-1'
    cursor_description = [
        ('id', "<class 'cx_Oracle.NUMBER'>", 39, None, 38, 0, 0),
        ('description', "<class 'cx_Oracle.STRING'>", 60, 240, None, None, 1),
    ]
    cursor_rows = [[1, 'description 1'], [2, 'description 2']]

    cursor_mock = MagicMock()
    cursor_mock.description.return_value = cursor_description
    cursor_mock.__iter__.return_value = cursor_rows
    mock_oracle_conn = MagicMock()
    mock_oracle_conn.cursor().return_value = cursor_mock
    mock_oracle_hook.get_conn().return_value = mock_oracle_conn

    op = OracleToAzureDataLakeTransfer(
        task_id=task_id,
        filename=filename,
        oracle_conn_id=oracle_conn_id,
        sql=sql,
        sql_params=sql_params,
        azure_data_lake_conn_id=azure_data_lake_conn_id,
        azure_data_lake_path=azure_data_lake_path,
        delimiter=delimiter,
        encoding=encoding)

    op.execute(None)

    mock_oracle_hook.assert_called_once_with(oracle_conn_id=oracle_conn_id)
    mock_data_lake_hook.assert_called_once_with(
        azure_data_lake_conn_id=azure_data_lake_conn_id)