def test_console_extra_link_serialized_field(self):
    with self.dag:
        training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)

    serialized_dag = SerializedDAG.to_dict(self.dag)
    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']]

    # Check Serialized version of operator link
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}]
    )

    # Check DeSerialized version of operator link
    self.assertIsInstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)

    job_id = self.TRAINING_DEFAULT_ARGS['job_id']
    project_id = self.TRAINING_DEFAULT_ARGS['project_id']
    gcp_metadata = {
        "job_id": job_id,
        "project_id": project_id,
    }

    ti = TaskInstance(
        task=training_op,
        execution_date=DEFAULT_DATE,
    )
    ti.xcom_push(key='gcp_metadata', value=gcp_metadata)

    self.assertEqual(
        f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
        simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
    )

    self.assertEqual(
        '',
        simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
    )
def test_set_machine_type_with_templates(self, _):
    dag_id = 'test_dag_id'
    args = {'start_date': DEFAULT_DATE}
    self.dag = DAG(dag_id, default_args=args)  # pylint: disable=attribute-defined-outside-init
    op = ComputeEngineSetMachineTypeOperator(
        project_id='{{ dag.dag_id }}',
        zone='{{ dag.dag_id }}',
        resource_id='{{ dag.dag_id }}',
        body={},
        gcp_conn_id='{{ dag.dag_id }}',
        api_version='{{ dag.dag_id }}',
        task_id='id',
        dag=self.dag
    )
    ti = TaskInstance(op, DEFAULT_DATE)
    ti.render_templates()
    self.assertEqual(dag_id, getattr(op, 'project_id'))
    self.assertEqual(dag_id, getattr(op, 'zone'))
    self.assertEqual(dag_id, getattr(op, 'resource_id'))
    self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
    self.assertEqual(dag_id, getattr(op, 'api_version'))
def setUp(self):
    super().setUp()
    self.wasb_log_folder = 'wasb://container/remote/log/location'
    self.remote_log_location = 'remote/log/location/1.log'
    self.local_log_location = 'local/log/location'
    self.container_name = "wasb-container"
    self.filename_template = '{try_number}.log'
    self.wasb_task_handler = WasbTaskHandler(
        base_log_folder=self.local_log_location,
        wasb_log_folder=self.wasb_log_folder,
        wasb_container=self.container_name,
        filename_template=self.filename_template,
        delete_local_copy=True,
    )

    date = datetime(2020, 8, 10)
    self.dag = DAG('dag_for_testing_file_task_handler', start_date=date)
    task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=self.dag)
    self.ti = TaskInstance(task=task, execution_date=date)
    self.ti.try_number = 1
    self.ti.state = State.RUNNING
    self.addCleanup(self.dag.clear)
def test_poke_context(self, mock_session_send):
    response = requests.Response()
    response.status_code = 200
    mock_session_send.return_value = response

    def resp_check(_, execution_date):
        if execution_date == DEFAULT_DATE:
            return True
        raise AirflowException('AirflowException raised here!')

    task = HttpSensor(
        task_id='http_sensor_poke_exception',
        http_conn_id='http_default',
        endpoint='',
        request_params={},
        response_check=resp_check,
        timeout=5,
        poke_interval=1,
        dag=self.dag,
    )

    task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    task.execute(task_instance.get_template_context())
def test_error_sending_task(self):
    def fake_execute_command():
        pass

    with _prepare_app(execute=fake_execute_command):
        # fake_execute_command takes no arguments while execute_command takes 1,
        # which will cause TypeError when calling task.apply_async()
        executor = celery_executor.CeleryExecutor()
        task = BashOperator(
            task_id="test",
            bash_command="true",
            dag=DAG(dag_id='id'),
            start_date=datetime.datetime.now()
        )
        when = datetime.datetime.now()
        value_tuple = 'command', 1, None, \
            SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.datetime.now()))
        key = ('fail', 'fake_simple_ti', when, 0)
        executor.queued_tasks[key] = value_tuple
        executor.heartbeat()
    self.assertEqual(0, len(executor.queued_tasks), "Task should no longer be queued")
    self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0], State.FAILED)
def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session):
    bigquery_task = BigQueryExecuteQueryOperator(
        task_id=TASK_ID,
        sql='SELECT * FROM test_table',
        dag=self.dag,
    )
    self.dag.clear()
    session.query(XCom).delete()

    ti = TaskInstance(
        task=bigquery_task,
        execution_date=DEFAULT_DATE,
    )

    job_id = '12345'
    ti.xcom_push(key='job_id', value=job_id)

    assert f'https://console.cloud.google.com/bigquery?j={job_id}' == bigquery_task.get_extra_links(
        DEFAULT_DATE, BigQueryConsoleLink.name
    )

    assert '' == bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name)
def test_run_airflow_dag(scaffold_dag):
    '''This test runs the sample Airflow dag using the TaskInstance API, directly from Python'''
    _n, _p, _d, static_path, editable_path = scaffold_dag
    execution_date = datetime.datetime.utcnow()

    import_module_from_path('demo_pipeline_static__scaffold', static_path)
    demo_pipeline = import_module_from_path('demo_pipeline', editable_path)

    _dag, tasks = demo_pipeline.make_dag(
        dag_id=demo_pipeline.DAG_ID,
        dag_description=demo_pipeline.DAG_DESCRIPTION,
        dag_kwargs=dict(default_args=demo_pipeline.DEFAULT_ARGS, **demo_pipeline.DAG_KWARGS),
        s3_conn_id=demo_pipeline.S3_CONN_ID,
        modified_docker_operator_kwargs=demo_pipeline.MODIFIED_DOCKER_OPERATOR_KWARGS,
        host_tmp_dir=demo_pipeline.HOST_TMP_DIR,
    )

    # These are in topo order already
    for task in tasks:
        ti = TaskInstance(task=task, execution_date=execution_date)
        context = ti.get_template_context()
        task.execute(context)
def test_parse_bucket_key_from_jinja(self, mock_hook):
    mock_hook.return_value.check_for_key.return_value = False

    Variable.set("test_bucket_key", "s3://bucket/key")

    execution_date = datetime(2020, 1, 1)

    dag = DAG("test_s3_key", start_date=execution_date)
    op = S3KeySensor(
        task_id='s3_key_sensor',
        bucket_key='{{ var.value.test_bucket_key }}',
        bucket_name=None,
        dag=dag,
    )

    ti = TaskInstance(task=op, execution_date=execution_date)
    context = ti.get_template_context()
    ti.render_templates(context)

    op.poke(None)

    self.assertEqual(op.bucket_key, "key")
    self.assertEqual(op.bucket_name, "bucket")
def test_parent_not_executed():
    """
    A simple DAG with a BranchPythonOperator that does not follow op2. Parent task is not
    yet executed (no xcom data). NotPreviouslySkippedDep is met (no decision).
    """
    start_date = pendulum.datetime(2020, 1, 1)
    dag = DAG("test_parent_not_executed_dag", schedule_interval=None, start_date=start_date)
    op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag)
    op2 = DummyOperator(task_id="op2", dag=dag)
    op3 = DummyOperator(task_id="op3", dag=dag)
    op1 >> [op2, op3]

    ti2 = TaskInstance(op2, start_date)

    with create_session() as session:
        dep = NotPreviouslySkippedDep()
        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0
        assert dep.is_met(ti2, session)
        assert ti2.state == State.NONE
def task_failed_deps(args):
    """
    Returns the unmet dependencies for a task instance from the perspective of the
    scheduler (i.e. why a task instance doesn't get scheduled and then queued by the
    scheduler, and then run by an executor).

    >>> airflow task_failed_deps tutorial sleep 2015-01-01
    Task instance dependencies not met:
    Dagrun Running: Task instance's dagrun did not exist: Unknown reason
    Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks
    to have succeeded, but found 1 non-success(es).
    """
    dag = get_dag(args)
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)

    dep_context = DepContext(deps=SCHEDULER_DEPS)
    failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
    if failed_deps:
        print("Task instance dependencies not met:")
        for dep in failed_deps:
            print("{}: {}".format(dep.dep_name, dep.reason))
    else:
        print("Task instance dependencies are all met.")
def test_instance_start_with_templates(self, _):
    dag_id = 'test_dag_id'
    configuration.load_test_config()
    args = {'start_date': DEFAULT_DATE}
    self.dag = DAG(dag_id, default_args=args)
    op = GceInstanceStartOperator(
        project_id='{{ dag.dag_id }}',
        zone='{{ dag.dag_id }}',
        resource_id='{{ dag.dag_id }}',
        gcp_conn_id='{{ dag.dag_id }}',
        api_version='{{ dag.dag_id }}',
        task_id='id',
        dag=self.dag
    )
    ti = TaskInstance(op, DEFAULT_DATE)
    ti.render_templates()
    self.assertEqual(dag_id, getattr(op, 'project_id'))
    self.assertEqual(dag_id, getattr(op, 'zone'))
    self.assertEqual(dag_id, getattr(op, 'resource_id'))
    self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
    self.assertEqual(dag_id, getattr(op, 'api_version'))
def test_templates(self, _):
    dag_id = 'test_dag_id'
    configuration.load_test_config()
    args = {'start_date': DEFAULT_DATE}
    self.dag = DAG(dag_id, default_args=args)
    op = GoogleCloudStorageToGoogleCloudStorageTransferOperator(
        source_bucket='{{ dag.dag_id }}',
        destination_bucket='{{ dag.dag_id }}',
        description='{{ dag.dag_id }}',
        object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']},
        gcp_conn_id='{{ dag.dag_id }}',
        task_id=TASK_ID,
        dag=self.dag,
    )
    ti = TaskInstance(op, DEFAULT_DATE)
    ti.render_templates()
    self.assertEqual(dag_id, getattr(op, 'source_bucket'))
    self.assertEqual(dag_id, getattr(op, 'destination_bucket'))
    self.assertEqual(dag_id, getattr(op, 'description'))
    self.assertEqual(dag_id, getattr(op, 'object_conditions')['exclude_prefixes'][0])
    self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
def test_bigquery_operator_defaults(self, mock_hook):
    operator = BigQueryOperator(
        task_id=TASK_ID,
        sql='Select * from test_table',
        dag=self.dag,
        default_args=self.args,
        schema_update_options=None
    )

    operator.execute(MagicMock())
    mock_hook.return_value \
        .get_conn.return_value \
        .cursor.return_value \
        .run_query \
        .assert_called_once_with(
            sql='Select * from test_table',
            destination_dataset_table=None,
            write_disposition='WRITE_EMPTY',
            allow_large_results=False,
            flatten_results=None,
            udf_config=None,
            maximum_billing_tier=None,
            maximum_bytes_billed=None,
            create_disposition='CREATE_IF_NEEDED',
            schema_update_options=None,
            query_params=None,
            labels=None,
            priority='INTERACTIVE',
            time_partitioning=None,
            api_resource_configs=None,
            cluster_fields=None,
            encryption_configuration=None
        )

    self.assertTrue(isinstance(operator.sql, str))

    ti = TaskInstance(task=operator, execution_date=DEFAULT_DATE)
    ti.render_templates()
    self.assertTrue(isinstance(ti.task.sql, str))
def backfill(args):
    logging.basicConfig(
        level=settings.LOGGING_LEVEL,
        format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, using same as start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(
            task_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
            ignore_dependencies=args.ignore_dependencies,
            pool=args.pool)
def _solid(context):  # pylint: disable=unused-argument
    if AIRFLOW_EXECUTION_DATE_STR not in context.pipeline_run.tags:
        raise DagsterInvariantViolationError(
            'Could not find "{AIRFLOW_EXECUTION_DATE_STR}" in pipeline tags "{tags}". Please '
            'add "{AIRFLOW_EXECUTION_DATE_STR}" to pipeline tags before executing'.format(
                AIRFLOW_EXECUTION_DATE_STR=AIRFLOW_EXECUTION_DATE_STR,
                tags=context.pipeline_run.tags,
            )
        )

    execution_date_str = context.pipeline_run.tags.get(AIRFLOW_EXECUTION_DATE_STR)

    check.str_param(execution_date_str, 'execution_date_str')

    try:
        execution_date = dateutil.parser.parse(execution_date_str)
    except ValueError:
        raise DagsterInvariantViolationError(
            'Could not parse execution_date "{execution_date_str}". Please use datetime format '
            'compatible with dateutil.parser.parse.'.format(
                execution_date_str=execution_date_str,
            )
        )
    except OverflowError:
        raise DagsterInvariantViolationError(
            'Date "{execution_date_str}" exceeds the largest valid C integer on the system.'.format(
                execution_date_str=execution_date_str,
            )
        )

    check.inst_param(execution_date, 'execution_date', datetime.datetime)

    with replace_airflow_logger_handlers():
        task_instance = TaskInstance(task=task, execution_date=execution_date)

        ti_context = task_instance.get_template_context()
        task.render_template_fields(ti_context)

        task.execute(ti_context)

        return None
def _get_ti(
    task: BaseOperator,
    exec_date_or_run_id: str,
    map_index: int,
    *,
    create_if_necessary: CreateIfNecessary = False,
    session: Session = NEW_SESSION,
) -> Tuple[TaskInstance, bool]:
    """Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
    if task.is_mapped:
        if map_index < 0:
            raise RuntimeError("No map_index passed to mapped task")
    elif map_index >= 0:
        raise RuntimeError("map_index passed to non-mapped task")
    dag_run, dr_created = _get_dag_run(
        dag=task.dag,
        exec_date_or_run_id=exec_date_or_run_id,
        create_if_necessary=create_if_necessary,
        session=session,
    )

    ti_or_none = dag_run.get_task_instance(task.task_id, map_index=map_index, session=session)
    if ti_or_none is None:
        if not create_if_necessary:
            raise TaskInstanceNotFound(
                f"TaskInstance for {task.dag.dag_id}, {task.task_id}, map={map_index} with "
                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
            )
        # TODO: Validate map_index is in range?
        ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
        ti.dag_run = dag_run
    else:
        ti = ti_or_none
    ti.refresh_from_task(task)
    return ti, dr_created
def test_rerun_failed_subdag(self):
    """
    When there is an existing DagRun with failed state, reset the DagRun and the
    corresponding TaskInstances
    """
    dag = DAG('parent', default_args=default_args)
    subdag = DAG('parent.test', default_args=default_args)
    subdag_task = SubDagOperator(task_id='test', subdag=subdag, dag=dag, poke_interval=1)
    dummy_task = DummyOperator(task_id='dummy', dag=subdag)

    with create_session() as session:
        dummy_task_instance = TaskInstance(
            task=dummy_task,
            execution_date=DEFAULT_DATE,
            state=State.FAILED,
        )
        session.add(dummy_task_instance)
        session.commit()

    sub_dagrun = subdag.create_dagrun(
        run_type=DagRunType.SCHEDULED,
        execution_date=DEFAULT_DATE,
        state=State.FAILED,
        external_trigger=True,
    )

    subdag_task._reset_dag_run_and_task_instances(sub_dagrun, execution_date=DEFAULT_DATE)

    dummy_task_instance.refresh_from_db()
    assert dummy_task_instance.state == State.NONE

    sub_dagrun.refresh_from_db()
    assert sub_dagrun.state == State.RUNNING
def test_render_template(self):
    json_str = '''
        {
            "type": "{{ params.index_type }}",
            "datasource": "{{ params.datasource }}",
            "spec": {
                "dataSchema": {
                    "granularitySpec": {
                        "intervals": ["{{ ds }}/{{ macros.ds_add(ds, 1) }}"]
                    }
                }
            }
        }
    '''
    operator = DruidOperator(
        task_id='spark_submit_job',
        json_index_file=json_str,
        params={
            'index_type': 'index_hadoop',
            'datasource': 'datasource_prd'
        },
        dag=self.dag
    )
    ti = TaskInstance(operator, DEFAULT_DATE)
    ti.render_templates()

    expected = '''
        {
            "type": "index_hadoop",
            "datasource": "datasource_prd",
            "spec": {
                "dataSchema": {
                    "granularitySpec": {
                        "intervals": ["2017-01-01/2017-01-02"]
                    }
                }
            }
        }
    '''
    self.assertEqual(expected, getattr(operator, 'json_index_file'))
def test_task_states_for_dag_run(self):
    dag2 = DagBag().dags['example_python_operator']
    task2 = dag2.get_task(task_id='print_the_context')
    default_date2 = timezone.make_aware(datetime(2016, 1, 9))
    dag2.clear()
    ti2 = TaskInstance(task2, default_date2)

    ti2.set_state(State.SUCCESS)
    ti_start = ti2.start_date
    ti_end = ti2.end_date

    with redirect_stdout(io.StringIO()) as stdout:
        task_command.task_states_for_dag_run(self.parser.parse_args([
            'tasks',
            'states-for-dag-run',
            'example_python_operator',
            default_date2.isoformat()
        ]))
    actual_out = stdout.getvalue()

    formatted_rows = [(
        'example_python_operator',
        '2016-01-09 00:00:00+00:00',
        'print_the_context',
        'success',
        ti_start,
        ti_end,
    )]

    expected = tabulate(
        formatted_rows,
        ['dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'],
        tablefmt="plain")

    # Check that prints, and log messages, are shown
    self.assertIn(expected.replace("\n", ""), actual_out.replace("\n", ""))
def task_test(args, dag=None):
    """Tests task for a given dag_id"""
    # We want log output from operators etc to show up here. Normally
    # airflow.task would redirect to a file, but here we want it to propagate
    # up to the normal airflow handler.
    handlers = logging.getLogger('airflow.task').handlers
    already_has_stream_handler = False
    for handler in handlers:
        already_has_stream_handler = isinstance(handler, logging.StreamHandler)
        if already_has_stream_handler:
            break
    if not already_has_stream_handler:
        logging.getLogger('airflow.task').propagate = True
    dag = dag or get_dag(args)

    task = dag.get_task(task_id=args.task_id)
    # Add CLI provided task_params to task.params
    if args.task_params:
        passed_in_params = json.loads(args.task_params)
        task.params.update(passed_in_params)
    ti = TaskInstance(task, args.execution_date)

    try:
        if args.dry_run:
            ti.dry_run()
        else:
            ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
    except Exception:  # pylint: disable=broad-except
        if args.post_mortem:
            try:
                debugger = importlib.import_module("ipdb")
            except ImportError:
                debugger = importlib.import_module("pdb")
            debugger.post_mortem()
        else:
            raise
def test_task_states_for_dag_run(self):
    dag2 = DagBag().dags['example_python_operator']
    task2 = dag2.get_task(task_id='print_the_context')
    default_date2 = timezone.make_aware(datetime(2016, 1, 9))
    dag2.clear()
    ti2 = TaskInstance(task2, default_date2)

    ti2.set_state(State.SUCCESS)
    ti_start = ti2.start_date
    ti_end = ti2.end_date

    with redirect_stdout(io.StringIO()) as stdout:
        task_command.task_states_for_dag_run(self.parser.parse_args([
            'tasks',
            'states-for-dag-run',
            'example_python_operator',
            default_date2.isoformat(),
            '--output',
            "json",
        ]))
    actual_out = json.loads(stdout.getvalue())

    self.assertEqual(len(actual_out), 1)
    self.assertDictEqual(
        actual_out[0],
        {
            'dag_id': 'example_python_operator',
            'execution_date': '2016-01-09T00:00:00+00:00',
            'task_id': 'print_the_context',
            'state': 'success',
            'start_date': ti_start.isoformat(),
            'end_date': ti_end.isoformat(),
        },
    )
def setUp(self):
    super().setUp()
    self.local_log_location = 'local/log/location'
    self.filename_template = '{try_number}.log'
    self.log_id_template = '{dag_id}-{task_id}-{execution_date}-{try_number}'
    self.end_of_log_mark = 'end_of_log\n'
    self.write_stdout = False
    self.json_format = False
    self.json_fields = 'asctime,filename,lineno,levelname,message'
    self.es_task_handler = ElasticsearchTaskHandler(
        self.local_log_location,
        self.filename_template,
        self.log_id_template,
        self.end_of_log_mark,
        self.write_stdout,
        self.json_format,
        self.json_fields
    )

    self.es = elasticsearch.Elasticsearch(  # pylint: disable=invalid-name
        hosts=[{'host': 'localhost', 'port': 9200}]
    )
    self.index_name = 'test_index'
    self.doc_type = 'log'
    self.test_message = 'some random stuff'
    self.body = {'message': self.test_message, 'log_id': self.LOG_ID, 'offset': 1}
    self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=1)

    self.dag = DAG(self.DAG_ID, start_date=self.EXECUTION_DATE)
    task = DummyOperator(task_id=self.TASK_ID, dag=self.dag)
    self.ti = TaskInstance(task=task, execution_date=self.EXECUTION_DATE)
    self.ti.try_number = 1
    self.ti.state = State.RUNNING
    self.addCleanup(self.dag.clear)
def test_log_file_template_with_run_task(self):
    """Verify that the taskinstance has the right context for log_filename_template"""
    with mock.patch.object(task_command, "_run_task_by_selected_method"):
        with conf_vars({('core', 'dags_folder'): self.dag_path}):
            # increment the try_number of the task to be run
            dag = DagBag().get_dag(self.dag_id)
            task = dag.get_task(self.task_id)
            with create_session() as session:
                dag.create_dagrun(
                    execution_date=self.execution_date,
                    start_date=timezone.utcnow(),
                    state=State.RUNNING,
                    run_type=DagRunType.MANUAL,
                    session=session,
                )
                ti = TaskInstance(task, self.execution_date)
                ti.refresh_from_db(session=session, lock_for_update=True)
                ti.try_number = 1  # not running, so starts at 0
                session.merge(ti)

            log_file_path = os.path.join(os.path.dirname(self.ti_log_file_path), "2.log")

            try:
                task_command.task_run(self.parser.parse_args([
                    'tasks',
                    'run',
                    self.dag_id,
                    self.task_id,
                    '--local',
                    self.execution_date_str
                ]))

                assert os.path.exists(log_file_path)
            finally:
                try:
                    os.remove(log_file_path)
                except OSError:
                    pass
def test_file_task_handler_running(self):
    def task_callable(ti, **kwargs):
        ti.log.info("test")
    dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
    task = PythonOperator(
        task_id='task_for_testing_file_log_handler',
        dag=dag,
        python_callable=task_callable,
        provide_context=True
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.try_number = 2
    ti.state = State.RUNNING

    logger = ti.log
    ti.log.disabled = False

    file_handler = next((handler for handler in logger.handlers
                         if handler.name == FILE_TASK_HANDLER), None)
    self.assertIsNotNone(file_handler)

    set_context(logger, ti)
    self.assertIsNotNone(file_handler.handler)
    # We expect set_context generates a file locally.
    log_filename = file_handler.handler.baseFilename
    self.assertTrue(os.path.isfile(log_filename))
    self.assertTrue(log_filename.endswith("2.log"), log_filename)

    logger.info("Test")

    # Return value of read must be a list.
    logs = file_handler.read(ti)
    self.assertTrue(isinstance(logs, list))
    # Logs for running tasks should show up too.
    self.assertEqual(len(logs), 2)

    # Remove the generated tmp log file.
    os.remove(log_filename)
def test_console_extra_link(self, mock_hook):
    training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)

    ti = TaskInstance(
        task=training_op,
        execution_date=DEFAULT_DATE,
    )

    job_id = self.TRAINING_DEFAULT_ARGS['job_id']
    project_id = self.TRAINING_DEFAULT_ARGS['project_id']
    gcp_metadata = {
        "job_id": job_id,
        "project_id": project_id,
    }
    ti.xcom_push(key='gcp_metadata', value=gcp_metadata)

    assert (
        f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
        == training_op.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name)
    )

    assert '' == training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name)
def test_templates(self, _):
    dag_id = 'test_dag_id'
    args = {'start_date': DEFAULT_DATE}
    self.dag = DAG(dag_id, default_args=args)  # pylint:disable=attribute-defined-outside-init
    op = GoogleCloudStorageToGoogleCloudStorageTransferOperator(
        source_bucket='{{ dag.dag_id }}',
        destination_bucket='{{ dag.dag_id }}',
        description='{{ dag.dag_id }}',
        object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']},
        gcp_conn_id='{{ dag.dag_id }}',
        task_id=TASK_ID,
        dag=self.dag,
    )
    ti = TaskInstance(op, DEFAULT_DATE)
    ti.render_templates()
    self.assertEqual(dag_id, getattr(op, 'source_bucket'))
    self.assertEqual(dag_id, getattr(op, 'destination_bucket'))
    self.assertEqual(dag_id, getattr(op, 'description'))
    # pylint:disable=unsubscriptable-object
    self.assertEqual(dag_id, getattr(op, 'object_conditions')['exclude_prefixes'][0])
    # pylint:enable=unsubscriptable-object
    self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
def test_bigquery_operator_extra_serialized_field_when_single_query(self):
    with self.dag:
        BigQueryExecuteQueryOperator(
            task_id=TASK_ID,
            sql='SELECT * FROM test_table',
        )
    serialized_dag = SerializedDAG.to_dict(self.dag)
    self.assertIn("sql", serialized_dag["dag"]["tasks"][0])

    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict[TASK_ID]
    self.assertEqual(getattr(simple_task, "sql"), 'SELECT * FROM test_table')

    #########################################################
    # Verify Operator Links work with Serialized Operator
    #########################################################

    # Check Serialized version of operator link
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]
    )

    # Check DeSerialized version of operator link
    self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)

    ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
    ti.xcom_push('job_id', 12345)

    # check for positive case
    url = simple_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
    self.assertEqual(url, 'https://console.cloud.google.com/bigquery?j=12345')

    # check for negative case
    url2 = simple_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name)
    self.assertEqual(url2, '')
def test_render_template_from_file(self):
    self.operator.job_flow_overrides = 'job.j2.json'
    self.operator.params = {'releaseLabel': '5.11.0'}

    ti = TaskInstance(self.operator, DEFAULT_DATE)
    ti.render_templates()

    self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN

    emr_session_mock = MagicMock()
    emr_session_mock.client.return_value = self.emr_client_mock
    boto3_session_mock = MagicMock(return_value=emr_session_mock)

    with patch('boto3.session.Session', boto3_session_mock):
        self.operator.execute(None)

    expected_args = {
        'Name': 'test_job_flow',
        'ReleaseLabel': '5.11.0',
        'Steps': [{
            'Name': 'test_step',
            'ActionOnFailure': 'CONTINUE',
            'HadoopJarStep': {
                'Jar': 'command-runner.jar',
                'Args': [
                    '/usr/lib/spark/bin/run-example',
                    '2016-12-31',
                    '2017-01-01',
                ]
            }
        }]
    }

    self.assertDictEqual(self.operator.job_flow_overrides, expected_args)
def execute_tasks_in_dag(dag, tasks, run_id, execution_date):
    assert isinstance(dag, DAG)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    root = logging.getLogger("airflow.task.operators")
    root.setLevel(logging.DEBUG)
    root.addHandler(handler)

    dag_run = dag.create_dagrun(run_id=run_id, state="success", execution_date=execution_date)

    results = {}
    for task in tasks:
        ti = TaskInstance(task=task, execution_date=execution_date)
        context = ti.get_template_context()
        context["dag_run"] = dag_run

        try:
            results[ti] = task.execute(context)
        except AirflowSkipException as exc:
            results[ti] = exc

    return results
def setUp(self):
    super(TestLogView, self).setUp()

    # Make sure that the configure_logging is not cached
    self.old_modules = dict(sys.modules)

    # Create a custom logging configuration
    logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
        os.path.join(current_dir, 'test_logs'))
    logging_config['handlers']['task']['filename_template'] = \
        '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

    # Write the custom logging configuration to a file
    self.settings_folder = tempfile.mkdtemp()
    settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
    new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
    with open(settings_file, 'w') as handle:
        handle.writelines(new_logging_file)
    sys.path.append(self.settings_folder)
    conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

    app = application.create_app(testing=True)
    self.app = app.test_client()
    self.session = Session()
    from airflow.www.views import dagbag

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    task = DummyOperator(task_id=self.TASK_ID, dag=dag)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
    ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
    ti.try_number = 1
    self.session.merge(ti)
    self.session.commit()