def backfill(args):
    logging.basicConfig(
        level=settings.LOGGING_LEVEL,
        format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException("dag_id could not be found")
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, use the same date for start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(
            task_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or
                          conf.getboolean("core", "donot_pickle")),
            ignore_dependencies=args.ignore_dependencies,
        )
def test_file_task_handler(self):
    dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
    task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=dag)
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

    logger = logging.getLogger(TASK_LOGGER)
    file_handler = next((handler for handler in logger.handlers
                         if handler.name == FILE_TASK_HANDLER), None)
    self.assertIsNotNone(file_handler)

    file_handler.set_context(ti)
    self.assertIsNotNone(file_handler.handler)
    # We expect set_context generates a file locally.
    log_filename = file_handler.handler.baseFilename
    self.assertTrue(os.path.isfile(log_filename))

    logger.info("test")
    ti.run()

    self.assertTrue(hasattr(file_handler, 'read'))
    # Return value of read must be a list.
    logs = file_handler.read(ti)
    self.assertTrue(isinstance(logs, list))
    self.assertEqual(len(logs), 1)

    # Remove the generated tmp log file.
    os.remove(log_filename)
def setUp(self):
    super(TestLogView, self).setUp()

    # Create a custom logging configuration
    configuration.load_test_config()
    logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
        os.path.join(current_dir, 'test_logs'))
    logging_config['handlers']['task']['filename_template'] = \
        '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

    # Write the custom logging configuration to a file
    self.settings_folder = tempfile.mkdtemp()
    settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
    new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
    with open(settings_file, 'w') as handle:
        handle.writelines(new_logging_file)
    sys.path.append(self.settings_folder)
    conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

    app = application.create_app(testing=True)
    self.app = app.test_client()
    self.session = Session()
    from airflow.www.views import dagbag

    dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
    task = DummyOperator(task_id=self.TASK_ID, dag=dag)
    dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
    ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
    ti.try_number = 1
    self.session.merge(ti)
    self.session.commit()
def test_check_task_dependencies(self, trigger_rule, successes, skipped,
                                 failed, upstream_failed, done,
                                 flag_upstream_failed,
                                 expect_state, expect_completed):
    start_date = datetime.datetime(2016, 2, 1, 0, 0, 0)
    dag = models.DAG("test-dag", start_date=start_date)
    downstream = DummyOperator(task_id="downstream",
                               dag=dag, owner="airflow",
                               trigger_rule=trigger_rule)
    for i in range(5):
        task = DummyOperator(task_id="runme_{}".format(i),
                             dag=dag, owner="airflow")
        task.set_downstream(downstream)
    run_date = task.start_date + datetime.timedelta(days=5)

    ti = TI(downstream, run_date)
    completed = ti.evaluate_trigger_rule(
        successes=successes, skipped=skipped, failed=failed,
        upstream_failed=upstream_failed, done=done,
        flag_upstream_failed=flag_upstream_failed)

    self.assertEqual(completed, expect_completed)
    self.assertEqual(ti.state, expect_state)
def test_file_transfer_no_intermediate_dir_error_put(self):
    configuration.conf.set("core", "enable_xcom_pickling", "True")
    test_local_file_content = \
        b"This is local file content \n which is multiline " \
        b"continuing....with other character\nanother line here \n this is last line"
    # create a test file locally
    with open(self.test_local_filepath, 'wb') as f:
        f.write(test_local_file_content)

    # Try to put test file to remote
    # This should raise an error with "No such file" as the directory
    # does not exist
    with self.assertRaises(Exception) as error:
        put_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath_int_dir,
            operation=SFTPOperation.PUT,
            create_intermediate_dirs=False,
            dag=self.dag
        )
        self.assertIsNotNone(put_test_task)
        ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
        ti2.run()
    self.assertIn('No such file', str(error.exception))
def test_backfill_pooled_tasks(self):
    """
    Test that queued tasks are executed by BackfillJob

    Test for https://github.com/airbnb/airflow/pull/1225
    """
    session = settings.Session()
    pool = Pool(pool='test_backfill_pooled_task_pool', slots=1)
    session.add(pool)
    session.commit()

    dag = self.dagbag.get_dag('test_backfill_pooled_task_dag')
    dag.clear()

    job = BackfillJob(
        dag=dag,
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE)

    # run with timeout because this creates an infinite loop if not
    # caught
    with timeout(seconds=30):
        job.run()

    ti = TI(
        task=dag.get_task('test_backfill_pooled_task'),
        execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    self.assertEqual(ti.state, State.SUCCESS)
def test_bigquery_operator_extra_link(self, mock_hook):
    bigquery_task = BigQueryOperator(
        task_id=TASK_ID,
        sql='SELECT * FROM test_table',
        dag=self.dag,
    )
    self.dag.clear()

    ti = TaskInstance(
        task=bigquery_task,
        execution_date=DEFAULT_DATE,
    )

    job_id = '12345'
    ti.xcom_push(key='job_id', value=job_id)

    self.assertEquals(
        'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=job_id),
        bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name),
    )

    self.assertEquals(
        '',
        bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name),
    )
def test_cli_backfill_depends_on_past(self):
    """
    Test that CLI respects -I argument
    """
    dag_id = 'test_dagrun_states_deadlock'
    run_date = DEFAULT_DATE + datetime.timedelta(days=1)
    args = [
        'backfill',
        dag_id,
        '-l',
        '-s',
        run_date.isoformat(),
    ]
    dag = self.dagbag.get_dag(dag_id)
    dag.clear()

    self.assertRaisesRegexp(
        AirflowException,
        'BackfillJob is deadlocked',
        cli.backfill,
        self.parser.parse_args(args))

    cli.backfill(self.parser.parse_args(args + ['-I']))
    ti = TI(dag.get_task('test_depends_on_past'), run_date)
    ti.refresh_from_db()
    # task ran
    self.assertEqual(ti.state, State.SUCCESS)
    dag.clear()
def test_xcom_push_flag(self):
    """
    Tests the option for Operators to push XComs
    """
    value = 'hello'
    task_id = 'test_no_xcom_push'
    dag = models.DAG(dag_id='test_xcom')

    # nothing saved to XCom
    task = PythonOperator(
        task_id=task_id,
        dag=dag,
        python_callable=lambda: value,
        do_xcom_push=False,
        owner='airflow',
        start_date=datetime.datetime(2017, 1, 1)
    )
    ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
    ti.run()
    self.assertEqual(
        ti.xcom_pull(
            task_ids=task_id, key=models.XCOM_RETURN_KEY
        ),
        None
    )
def test_email_alert_with_config(self, mock_send_email):
    dag = models.DAG(dag_id='test_failure_email')
    task = BashOperator(
        task_id='test_email_alert_with_config',
        dag=dag,
        bash_command='exit 1',
        start_date=DEFAULT_DATE,
        email='to')

    ti = TI(task=task, execution_date=datetime.datetime.now())

    configuration.set('email', 'SUBJECT_TEMPLATE', '/subject/path')
    configuration.set('email', 'HTML_CONTENT_TEMPLATE', '/html_content/path')

    opener = mock_open(read_data='template: {{ti.task_id}}')
    with patch('airflow.models.taskinstance.open', opener, create=True):
        try:
            ti.run()
        except AirflowException:
            pass

    (email, title, body), _ = mock_send_email.call_args
    self.assertEqual(email, 'to')
    self.assertEqual('template: test_email_alert_with_config', title)
    self.assertEqual('template: test_email_alert_with_config', body)
def backfill(args, dag=None):
    logging.basicConfig(
        level=settings.LOGGING_LEVEL,
        format=settings.SIMPLE_LOG_FORMAT)

    dag = dag or get_dag(args)

    if not args.start_date and not args.end_date:
        raise AirflowException("Provide a start_date and/or end_date")

    # If only one date is passed, use the same date for start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(
            task_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or
                          conf.getboolean("core", "donot_pickle")),
            ignore_dependencies=args.ignore_dependencies,
            ignore_first_depends_on_past=args.ignore_first_depends_on_past,
            pool=args.pool,
        )
def test_post_execute_hook(self):
    """
    Test that post_execute hook is called with the Operator's result.
    The result ('error') will cause an error to be raised and trapped.
    """

    class TestError(Exception):
        pass

    class TestOperator(PythonOperator):
        def post_execute(self, context, result):
            if result == 'error':
                raise TestError('expected error.')

    dag = models.DAG(dag_id='test_post_execute_dag')
    task = TestOperator(
        task_id='test_operator',
        dag=dag,
        python_callable=lambda: 'error',
        owner='airflow',
        start_date=datetime.datetime(2017, 2, 1))
    ti = TI(task=task, execution_date=datetime.datetime.now())

    with self.assertRaises(TestError):
        ti.run()
def test_clear_task_instances_without_task(self):
    dag = DAG('test_clear_task_instances_without_task',
              start_date=DEFAULT_DATE,
              end_date=DEFAULT_DATE + datetime.timedelta(days=10))
    task0 = DummyOperator(task_id='task0', owner='test', dag=dag)
    task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2)
    ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
    ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
    ti0.run()
    ti1.run()

    # Remove the task from dag.
    dag.task_dict = {}
    self.assertFalse(dag.has_task(task0.task_id))
    self.assertFalse(dag.has_task(task1.task_id))

    session = settings.Session()
    qry = session.query(TI).filter(
        TI.dag_id == dag.dag_id).all()
    clear_task_instances(qry, session)
    session.commit()

    # When dag is None, max_tries will be maximum of original max_tries or try_number.
    ti0.refresh_from_db()
    ti1.refresh_from_db()
    # Next try to run will be try 2
    self.assertEqual(ti0.try_number, 2)
    self.assertEqual(ti0.max_tries, 1)
    self.assertEqual(ti1.try_number, 2)
    self.assertEqual(ti1.max_tries, 2)
def test_scheduler_pooled_tasks(self):
    """
    Test that the scheduler handles queued tasks correctly
    See issue #1299
    """
    session = settings.Session()
    if not (
            session.query(Pool)
            .filter(Pool.pool == 'test_queued_pool')
            .first()):
        pool = Pool(pool='test_queued_pool', slots=5)
        session.merge(pool)
        session.commit()
    session.close()

    dag_id = 'test_scheduled_queued_tasks'
    dag = self.dagbag.get_dag(dag_id)
    dag.clear()

    scheduler = SchedulerJob(dag_id, num_runs=10)
    scheduler.run()

    task_1 = dag.tasks[0]
    ti = TI(task_1, dag.start_date)
    ti.refresh_from_db()
    self.assertEqual(ti.state, State.FAILED)

    dag.clear()
def test_check_and_change_state_before_execution_dep_not_met(self):
    dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
    task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
    task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
    task >> task2
    ti = TI(
        task=task2, execution_date=timezone.utcnow())
    self.assertFalse(ti._check_and_change_state_before_execution())
def test_overwrite_params_with_dag_run_none(self):
    task = DummyOperator(task_id='op')
    ti = TI(task=task, execution_date=datetime.datetime.now())
    params = {"override": False}

    ti.overwrite_params_with_dag_run_conf(params, None)

    self.assertEqual(False, params["override"])
def test_s3_to_sftp_operation(self): # Setting configuration.conf.set("core", "enable_xcom_pickling", "True") test_remote_file_content = \ "This is remote file content \n which is also multiline " \ "another line here \n this is last line. EOF" # Test for creation of s3 bucket conn = boto3.client('s3') conn.create_bucket(Bucket=self.s3_bucket) self.assertTrue((self.s3_hook.check_for_bucket(self.s3_bucket))) with open(LOCAL_FILE_PATH, 'w') as f: f.write(test_remote_file_content) self.s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET) # Check if object was created in s3 objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket, Prefix=self.s3_key) # there should be object found, and there should only be one object found self.assertEqual(len(objects_in_dest_bucket['Contents']), 1) # the object found should be consistent with dest_key specified earlier self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key) # get remote file to local run_task = S3ToSFTPOperator( s3_bucket=BUCKET, s3_key=S3_KEY, sftp_path=SFTP_PATH, sftp_conn_id=SFTP_CONN_ID, s3_conn_id=S3_CONN_ID, task_id=TASK_ID, dag=self.dag ) self.assertIsNotNone(run_task) run_task.execute(None) # Check that the file is created remotely check_file_task = SSHOperator( task_id="test_check_file", ssh_hook=self.hook, command="cat {0}".format(self.sftp_path), do_xcom_push=True, dag=self.dag ) self.assertIsNotNone(check_file_task) ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) ti3.run() self.assertEqual( ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), test_remote_file_content.encode('utf-8')) # Clean up after finishing with test conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key) conn.delete_bucket(Bucket=self.s3_bucket) self.assertFalse((self.s3_hook.check_for_bucket(self.s3_bucket)))
def test_overwrite_params_with_dag_run_conf(self):
    task = DummyOperator(task_id='op')
    ti = TI(task=task, execution_date=datetime.datetime.now())
    dag_run = DagRun()
    dag_run.conf = {"override": True}
    params = {"override": False}

    ti.overwrite_params_with_dag_run_conf(params, dag_run)

    self.assertEqual(True, params["override"])
def test(args):
    log_to_stdout()
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.run(force=True, ignore_dependencies=True, test_mode=True)
def test_requeue_over_concurrency(self, mock_concurrency_reached):
    mock_concurrency_reached.return_value = True

    dag = DAG(dag_id='test_requeue_over_concurrency', start_date=DEFAULT_DATE,
              max_active_runs=1, concurrency=2)
    task = DummyOperator(task_id='test_requeue_over_concurrency_op', dag=dag)

    ti = TI(task=task, execution_date=datetime.datetime.now())
    ti.run()
    self.assertEqual(ti.state, models.State.NONE)
def test_check_and_change_state_before_execution(self):
    dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
    task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
    ti = TI(
        task=task, execution_date=timezone.utcnow())
    self.assertEqual(ti._try_number, 0)
    self.assertTrue(ti._check_and_change_state_before_execution())
    # State should be running, and try_number column should be incremented
    self.assertEqual(ti.state, State.RUNNING)
    self.assertEqual(ti._try_number, 1)
def test_set_duration(self):
    task = DummyOperator(task_id='op', email='*****@*****.**')
    ti = TI(
        task=task,
        execution_date=datetime.datetime.now(),
    )
    ti.start_date = datetime.datetime(2018, 10, 1, 1)
    ti.end_date = datetime.datetime(2018, 10, 1, 2)
    ti.set_duration()
    self.assertEqual(ti.duration, 3600)
def test_render_template(self):
    ti = TaskInstance(self.mock_operator, DEFAULT_DATE)
    ti.render_templates()

    expected_rendered_template = {'$lt': '2017-01-01T00:00:00+00:00Z'}

    self.assertDictEqual(
        expected_rendered_template,
        getattr(self.mock_operator, 'mongo_query')
    )
def test_retry_handling(self, mock_pool_full):
    """
    Test that task retries are handled properly
    """
    # Mock the pool with a pool with slots open since the pool doesn't actually exist
    mock_pool_full.return_value = False

    dag = models.DAG(dag_id='test_retry_handling')
    task = BashOperator(
        task_id='test_retry_handling_op',
        bash_command='exit 1',
        retries=1,
        retry_delay=datetime.timedelta(seconds=0),
        dag=dag,
        owner='airflow',
        start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

    def run_with_error(ti):
        try:
            ti.run()
        except AirflowException:
            pass

    ti = TI(
        task=task, execution_date=timezone.utcnow())
    self.assertEqual(ti.try_number, 1)

    # first run -- up for retry
    run_with_error(ti)
    self.assertEqual(ti.state, State.UP_FOR_RETRY)
    self.assertEqual(ti._try_number, 1)
    self.assertEqual(ti.try_number, 2)

    # second run -- fail
    run_with_error(ti)
    self.assertEqual(ti.state, State.FAILED)
    self.assertEqual(ti._try_number, 2)
    self.assertEqual(ti.try_number, 3)

    # Clear the TI state since you can't run a task with a FAILED state without
    # clearing it first
    dag.clear()

    # third run -- up for retry
    run_with_error(ti)
    self.assertEqual(ti.state, State.UP_FOR_RETRY)
    self.assertEqual(ti._try_number, 3)
    self.assertEqual(ti.try_number, 4)

    # fourth run -- fail
    run_with_error(ti)
    ti.refresh_from_db()
    self.assertEqual(ti.state, State.FAILED)
    self.assertEqual(ti._try_number, 4)
    self.assertEqual(ti.try_number, 5)
def task_state(args):
    """
    Returns the state of a TaskInstance at the command line.

    >>> airflow task_state tutorial sleep 2015-01-01
    success
    """
    dag = get_dag(args)
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    print(ti.current_state())
def test_operation_get_with_templates(self, _):
    dag_id = 'test_dag_id'
    configuration.load_test_config()
    args = {'start_date': DEFAULT_DATE}
    self.dag = DAG(dag_id, default_args=args)
    op = GcpTransferServiceOperationGetOperator(
        operation_name='{{ dag.dag_id }}',
        task_id='task-id',
        dag=self.dag
    )
    ti = TaskInstance(op, DEFAULT_DATE)
    ti.render_templates()
    self.assertEqual(dag_id, getattr(op, 'operation_name'))
def render(args):
    dag = get_dag(args)
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.render_templates()
    for attr in task.__class__.template_fields:
        print(textwrap.dedent("""\
        # ----------------------------------------------------------
        # property: {}
        # ----------------------------------------------------------
        {}
        """.format(attr, getattr(task, attr))))
def delete_remote_resource(self):
    # remove the remote file created for the test
    remove_file_task = SSHOperator(
        task_id="test_check_file",
        ssh_hook=self.hook,
        command="rm {0}".format(self.test_remote_filepath),
        do_xcom_push=True,
        dag=self.dag
    )
    self.assertIsNotNone(remove_file_task)
    ti3 = TaskInstance(task=remove_file_task, execution_date=datetime.now())
    ti3.run()
def execute(self, context): # If the DAG Run is externally triggered, then return without # skipping downstream tasks if context['dag_run'] and context['dag_run'].external_trigger: logging.info("""Externally triggered DAG_Run: allowing execution to proceed.""") return now = datetime.datetime.now() left_window = context['dag'].following_schedule( context['execution_date']) right_window = context['dag'].following_schedule(left_window) logging.info( 'Checking latest only with left_window: %s right_window: %s ' 'now: %s', left_window, right_window, now) if not left_window < now <= right_window: logging.info('Not latest execution, skipping downstream.') session = settings.Session() TI = TaskInstance tis = session.query(TI).filter( TI.execution_date == context['ti'].execution_date, TI.task_id.in_(context['task'].downstream_task_ids) ).with_for_update().all() for ti in tis: logging.info('Skipping task: %s', ti.task_id) ti.state = State.SKIPPED ti.start_date = now ti.end_date = now session.merge(ti) # this is defensive against dag runs that are not complete for task in context['task'].downstream_list: if task.task_id in tis: continue logging.warning("Task {} was not part of a dag run. " "This should not happen." .format(task)) now = datetime.datetime.now() ti = TaskInstance(task, execution_date=context['ti'].execution_date) ti.state = State.SKIPPED ti.start_date = now ti.end_date = now session.merge(ti) session.commit() session.close() logging.info('Done.') else: logging.info('Latest, allowing execution to proceed.')
def test_file_task_handler(self): def task_callable(ti, **kwargs): ti.log.info("test") dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE) task = PythonOperator( task_id='task_for_testing_file_log_handler', dag=dag, python_callable=task_callable, provide_context=True ) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) logger = ti.log ti.log.disabled = False file_handler = next((handler for handler in logger.handlers if handler.name == FILE_TASK_HANDLER), None) self.assertIsNotNone(file_handler) set_context(logger, ti) self.assertIsNotNone(file_handler.handler) # We expect set_context generates a file locally. log_filename = file_handler.handler.baseFilename self.assertTrue(os.path.isfile(log_filename)) self.assertTrue(log_filename.endswith("1.log"), log_filename) ti.run(ignore_ti_state=True) file_handler.flush() file_handler.close() self.assertTrue(hasattr(file_handler, 'read')) # Return value of read must be a tuple of list and list. logs, metadatas = file_handler.read(ti) self.assertTrue(isinstance(logs, list)) self.assertTrue(isinstance(metadatas, list)) self.assertEqual(len(logs), 1) self.assertEqual(len(logs), len(metadatas)) self.assertTrue(isinstance(metadatas[0], dict)) target_re = r'\n\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\n' # We should expect our log line from the callable above to appear in # the logs we read back six.assertRegex( self, logs[0], target_re, "Logs were " + str(logs) ) # Remove the generated tmp log file. os.remove(log_filename)
def test_xcom_pull_different_execution_date(self):
    """
    tests xcom fetch behavior with different execution dates, using
    both xcom_pull with "include_prior_dates" and without
    """
    key = 'xcom_key'
    value = 'xcom_value'

    dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
    task = DummyOperator(
        task_id='test_xcom',
        dag=dag,
        pool='test_xcom',
        owner='airflow',
        start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
    exec_date = datetime.datetime.now()
    ti = TI(task=task, execution_date=exec_date)
    ti.run(mark_success=True)
    ti.xcom_push(key=key, value=value)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
    ti.run()
    exec_date += datetime.timedelta(days=1)
    ti = TI(task=task, execution_date=exec_date)
    ti.run()

    # We have set a new execution date (and did not pass in
    # 'include_prior_dates'), which means this task should now have a cleared
    # xcom value
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)

    # We *should* get a value using 'include_prior_dates'
    self.assertEqual(
        ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True),
        value)
# make sure there isn't already a file where this test will be writing to
if os.path.exists(dir + "/" + filename):
    os.remove(dir + "/" + filename)

# init the dag for the tests
dag = DAG(dag_id='anydag', start_date=datetime.now())

# init the Operator and Task Instance
task = DownloadFileOperator(
    task_id="get_data",
    file_url=file_url,
    dir=dir,
    filename=filename,
    dag=dag)
ti = TaskInstance(task=task, execution_date=datetime.now())


def test_get_data_from_http_request():
    # checks if the operator makes a simple http get correctly
    assert task.get_data_from_http_request(ti).text == file_url_content


def test_execute():
    # checks all keys are present in operator output
    output = task.execute(ti.get_template_context())
    keys = output.keys()
    assert "path" in keys
    assert "last_modified" in keys
    assert "checksum" in keys
def get_xcom_value(task_instance: TaskInstance):
    return task_instance.xcom_pull(task_ids=task_instance.task_id)
def create_task_instance(
    task: BaseOperator,
    execution_date: pendulum.datetime = DEFAULT_EXECUTION_DATE,
) -> TaskInstance:
    return TaskInstance(task=task, execution_date=execution_date)
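# A short sketch of how the two helpers above might be combined in a test.
# The DummyOperator, the dag argument, and the pushed value are illustrative
# assumptions, not taken from the original code; get_xcom_value pulls the XCom
# stored under the task's own task_id (default key 'return_value').
def example_pull_own_xcom(dag):
    task = DummyOperator(task_id='example_task', dag=dag)
    ti = create_task_instance(task)  # defaults to DEFAULT_EXECUTION_DATE
    ti.xcom_push(key='return_value', value='hello')
    assert get_xcom_value(ti) == 'hello'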
def dag_backfill(args, dag=None): """Creates backfill job or dry run for a DAG""" logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) signal.signal(signal.SIGTERM, sigint_handler) import warnings warnings.warn( '--ignore-first-depends-on-past is deprecated as the value is always set to True', category=PendingDeprecationWarning) if args.ignore_first_depends_on_past is False: args.ignore_first_depends_on_past = True dag = dag or get_dag(args.subdir, args.dag_id) if not args.start_date and not args.end_date: raise AirflowException("Provide a start_date and/or end_date") # If only one date is passed, using same as start and end args.end_date = args.end_date or args.start_date args.start_date = args.start_date or args.end_date if args.task_regex: dag = dag.sub_dag(task_regex=args.task_regex, include_upstream=not args.ignore_dependencies) run_conf = None if args.conf: run_conf = json.loads(args.conf) if args.dry_run: print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date)) for task in dag.tasks: print("Task {0}".format(task.task_id)) ti = TaskInstance(task, args.start_date) ti.dry_run() else: if args.reset_dagruns: DAG.clear_dags( [dag], start_date=args.start_date, end_date=args.end_date, confirm_prompt=not args.yes, include_subdags=True, ) dag.run(start_date=args.start_date, end_date=args.end_date, mark_success=args.mark_success, local=args.local, donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')), ignore_first_depends_on_past=args.ignore_first_depends_on_past, ignore_task_deps=args.ignore_dependencies, pool=args.pool, delay_on_limit_secs=args.delay_on_limit, verbose=args.verbose, conf=run_conf, rerun_failed_tasks=args.rerun_failed_tasks, run_backwards=args.run_backwards)
def _get_task_instance(self, state):
    dag = DAG('test_dag')
    task = Mock(dag=dag)
    ti = TaskInstance(task=task, state=state, execution_date=None)
    return ti
def test_update_counters(self): dag = DAG(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow') job = BackfillJob(dag=dag) session = settings.Session() dr = dag.create_dagrun(run_id=DagRunType.SCHEDULED.value, state=State.RUNNING, execution_date=DEFAULT_DATE, start_date=DEFAULT_DATE, session=session) ti = TI(task1, dr.execution_date) ti.refresh_from_db() ti_status = BackfillJob._DagRunTaskStatus() # test for success ti.set_state(State.SUCCESS, session) ti_status.running[ti.key] = ti job._update_counters(ti_status=ti_status) self.assertTrue(len(ti_status.running) == 0) self.assertTrue(len(ti_status.succeeded) == 1) self.assertTrue(len(ti_status.skipped) == 0) self.assertTrue(len(ti_status.failed) == 0) self.assertTrue(len(ti_status.to_run) == 0) ti_status.succeeded.clear() # test for skipped ti.set_state(State.SKIPPED, session) ti_status.running[ti.key] = ti job._update_counters(ti_status=ti_status) self.assertTrue(len(ti_status.running) == 0) self.assertTrue(len(ti_status.succeeded) == 0) self.assertTrue(len(ti_status.skipped) == 1) self.assertTrue(len(ti_status.failed) == 0) self.assertTrue(len(ti_status.to_run) == 0) ti_status.skipped.clear() # test for failed ti.set_state(State.FAILED, session) ti_status.running[ti.key] = ti job._update_counters(ti_status=ti_status) self.assertTrue(len(ti_status.running) == 0) self.assertTrue(len(ti_status.succeeded) == 0) self.assertTrue(len(ti_status.skipped) == 0) self.assertTrue(len(ti_status.failed) == 1) self.assertTrue(len(ti_status.to_run) == 0) ti_status.failed.clear() # test for retry ti.set_state(State.UP_FOR_RETRY, session) ti_status.running[ti.key] = ti job._update_counters(ti_status=ti_status) self.assertTrue(len(ti_status.running) == 0) self.assertTrue(len(ti_status.succeeded) == 0) self.assertTrue(len(ti_status.skipped) == 0) self.assertTrue(len(ti_status.failed) == 0) self.assertTrue(len(ti_status.to_run) == 1) ti_status.to_run.clear() # test for reschedule ti.set_state(State.UP_FOR_RESCHEDULE, session) ti_status.running[ti.key] = ti job._update_counters(ti_status=ti_status) self.assertTrue(len(ti_status.running) == 0) self.assertTrue(len(ti_status.succeeded) == 0) self.assertTrue(len(ti_status.skipped) == 0) self.assertTrue(len(ti_status.failed) == 0) self.assertTrue(len(ti_status.to_run) == 1) ti_status.to_run.clear() # test for none ti.set_state(State.NONE, session) ti_status.running[ti.key] = ti job._update_counters(ti_status=ti_status) self.assertTrue(len(ti_status.running) == 0) self.assertTrue(len(ti_status.succeeded) == 0) self.assertTrue(len(ti_status.skipped) == 0) self.assertTrue(len(ti_status.failed) == 0) self.assertTrue(len(ti_status.to_run) == 1) ti_status.to_run.clear() session.close()
def test_create_interval_metrics():
    ti = TaskInstance(task=create_interval_metrics, execution_date=datetime.now())
    result = create_interval_metrics.execute(ti.get_template_context())
    print("---")
    assert result == "succeeded"
def run(args): utils.pessimistic_connection_handling() # Setting up logging log = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) directory = log + "/{args.dag_id}/{args.task_id}".format(args=args) if not os.path.exists(directory): os.makedirs(directory) args.execution_date = dateutil.parser.parse(args.execution_date) iso = args.execution_date.isoformat() filename = "{directory}/{iso}".format(**locals()) # store old log (to help with S3 appends) if os.path.exists(filename): with open(filename, 'r') as logfile: old_log = logfile.read() else: old_log = None subdir = None if args.subdir: subdir = args.subdir.replace("DAGS_FOLDER", conf.get("core", "DAGS_FOLDER")) subdir = os.path.expanduser(subdir) logging.basicConfig(filename=filename, level=settings.LOGGING_LEVEL, format=settings.LOG_FORMAT) if not args.pickle: dagbag = DagBag(subdir) if args.dag_id not in dagbag.dags: msg = 'DAG [{0}] could not be found'.format(args.dag_id) logging.error(msg) raise AirflowException(msg) dag = dagbag.dags[args.dag_id] task = dag.get_task(task_id=args.task_id) else: session = settings.Session() logging.info('Loading pickle id {args.pickle}'.format(**locals())) dag_pickle = session.query(DagPickle).filter( DagPickle.id == args.pickle).first() if not dag_pickle: raise AirflowException("Who hid the pickle!? [missing pickle]") dag = dag_pickle.pickle task = dag.get_task(task_id=args.task_id) task_start_date = None if args.task_start_date: task_start_date = dateutil.parser.parse(args.task_start_date) task.start_date = task_start_date ti = TaskInstance(task, args.execution_date) if args.local: print("Logging into: " + filename) run_job = jobs.LocalTaskJob( task_instance=ti, mark_success=args.mark_success, force=args.force, pickle_id=args.pickle, task_start_date=task_start_date, ignore_dependencies=args.ignore_dependencies) run_job.run() elif args.raw: ti.run( mark_success=args.mark_success, force=args.force, ignore_dependencies=args.ignore_dependencies, job_id=args.job_id, ) else: pickle_id = None if args.ship_dag: try: # Running remotely, so pickling the DAG session = settings.Session() pickle = DagPickle(dag) session.add(pickle) session.commit() pickle_id = pickle.id print(('Pickled dag {dag} ' 'as pickle_id:{pickle_id}').format(**locals())) except Exception as e: print('Could not pickle the DAG') print(e) raise e executor = DEFAULT_EXECUTOR executor.start() print("Sending to executor.") executor.queue_task_instance( ti, mark_success=args.mark_success, pickle_id=pickle_id, ignore_dependencies=args.ignore_dependencies, force=args.force) executor.heartbeat() executor.end() if conf.get('core', 'S3_LOG_FOLDER').startswith('s3:'): import boto s3_log = filename.replace(log, conf.get('core', 'S3_LOG_FOLDER')) bucket, key = s3_log.lstrip('s3:/').split('/', 1) if os.path.exists(filename): # get logs with open(filename, 'r') as logfile: new_log = logfile.read() # remove old logs (since they are already in S3) if old_log: new_log.replace(old_log, '') try: s3 = boto.connect_s3() s3_key = boto.s3.key.Key(s3.get_bucket(bucket), key) # append new logs to old S3 logs, if available if s3_key.exists(): old_s3_log = s3_key.get_contents_as_string().decode() new_log = old_s3_log + '\n' + new_log # send log to S3 s3_key.set_contents_from_string(new_log) except: print('Could not send logs to S3.')
def setUp(self):
    dag = DAG('dag_for_testing_filename_rendering', start_date=DEFAULT_DATE)
    task = DummyOperator(task_id='task_for_testing_filename_rendering', dag=dag)
    self.ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
def _get_task_instance(self, execution_date, dag_end_date=None, task_end_date=None):
    dag = Mock(end_date=dag_end_date)
    task = Mock(dag=dag, end_date=task_end_date)
    return TaskInstance(task=task, execution_date=execution_date)
def test_next_retry_datetime(self):
    delay = datetime.timedelta(seconds=30)
    max_delay = datetime.timedelta(minutes=60)

    dag = models.DAG(dag_id='fail_dag')
    task = BashOperator(
        task_id='task_with_exp_backoff_and_max_delay',
        bash_command='exit 1',
        retries=3,
        retry_delay=delay,
        retry_exponential_backoff=True,
        max_retry_delay=max_delay,
        dag=dag,
        owner='airflow',
        start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
    ti = TI(
        task=task, execution_date=DEFAULT_DATE)
    ti.end_date = pendulum.instance(timezone.utcnow())

    dt = ti.next_retry_datetime()
    # between 30 * 2^0.5 and 30 * 2^1 (15 and 30)
    period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
    self.assertTrue(dt in period)

    ti.try_number = 3
    dt = ti.next_retry_datetime()
    # between 30 * 2^2 and 30 * 2^3 (120 and 240)
    period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
    self.assertTrue(dt in period)

    ti.try_number = 5
    dt = ti.next_retry_datetime()
    # between 30 * 2^4 and 30 * 2^5 (480 and 960)
    period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
    self.assertTrue(dt in period)

    ti.try_number = 9
    dt = ti.next_retry_datetime()
    self.assertEqual(dt, ti.end_date + max_delay)

    ti.try_number = 50
    dt = ti.next_retry_datetime()
    self.assertEqual(dt, ti.end_date + max_delay)
def test_set_duration_empty_dates(self):
    task = DummyOperator(task_id='op', email='*****@*****.**')
    ti = TI(task=task, execution_date=datetime.datetime.now())
    ti.set_duration()
    self.assertIsNone(ti.duration)
def test_run_ignores_all_dependencies(self):
    """
    Test that run respects ignore_all_dependencies
    """
    dag_id = 'test_run_ignores_all_dependencies'

    dag = self.dagbag.get_dag('test_run_ignores_all_dependencies')
    dag.clear()

    task0_id = 'test_run_dependent_task'
    args0 = ['tasks',
             'run',
             '--ignore-all-dependencies',
             dag_id,
             task0_id,
             DEFAULT_DATE.isoformat()]
    task_command.task_run(self.parser.parse_args(args0))
    ti_dependent0 = TaskInstance(task=dag.get_task(task0_id),
                                 execution_date=DEFAULT_DATE)

    ti_dependent0.refresh_from_db()
    assert ti_dependent0.state == State.FAILED

    task1_id = 'test_run_dependency_task'
    args1 = ['tasks',
             'run',
             '--ignore-all-dependencies',
             dag_id,
             task1_id,
             (DEFAULT_DATE + timedelta(days=1)).isoformat()]
    task_command.task_run(self.parser.parse_args(args1))

    ti_dependency = TaskInstance(task=dag.get_task(task1_id),
                                 execution_date=DEFAULT_DATE + timedelta(days=1))
    ti_dependency.refresh_from_db()
    assert ti_dependency.state == State.FAILED

    task2_id = 'test_run_dependent_task'
    args2 = ['tasks',
             'run',
             '--ignore-all-dependencies',
             dag_id,
             task2_id,
             (DEFAULT_DATE + timedelta(days=1)).isoformat()]
    task_command.task_run(self.parser.parse_args(args2))

    ti_dependent = TaskInstance(task=dag.get_task(task2_id),
                                execution_date=DEFAULT_DATE + timedelta(days=1))
    ti_dependent.refresh_from_db()
    assert ti_dependent.state == State.SUCCESS
def test_zombies_are_correctly_passed_to_dag_file_processor(self): """ Check that the same set of zombies are passed to the dag file processors until the next zombie detection logic is invoked. """ with conf_vars({ ('scheduler', 'max_threads'): '1', ('core', 'load_examples'): 'False' }): dagbag = DagBag( os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')) with create_session() as session: session.query(LJ).delete() dag = dagbag.get_dag('test_example_bash_operator') task = dag.get_task(task_id='run_this_last') ti = TI(task, DEFAULT_DATE, State.RUNNING) lj = LJ(ti) lj.state = State.SHUTDOWN lj.id = 1 ti.job_id = lj.id session.add(lj) session.add(ti) session.commit() fake_zombies = [SimpleTaskInstance(ti)] class FakeDagFIleProcessor(DagFileProcessor): # This fake processor will return the zombies it received in constructor # as its processing result w/o actually parsing anything. def __init__(self, file_path, pickle_dags, dag_id_white_list, zombies): super(FakeDagFIleProcessor, self).__init__(file_path, pickle_dags, dag_id_white_list, zombies) self._result = zombies, 0 def start(self): pass @property def start_time(self): return DEFAULT_DATE @property def pid(self): return 1234 @property def done(self): return True @property def result(self): return self._result def processor_factory(file_path, zombies): return FakeDagFIleProcessor(file_path, False, [], zombies) test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py') async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') processor_agent = DagFileProcessorAgent(test_dag_path, [], 1, processor_factory, timedelta.max, async_mode) processor_agent.start() parsing_result = [] if not async_mode: processor_agent.heartbeat() while not processor_agent.done: if not async_mode: processor_agent.wait_until_finished() parsing_result.extend(processor_agent.harvest_simple_dags()) self.assertEqual(len(fake_zombies), len(parsing_result)) self.assertEqual(set([zombie.key for zombie in fake_zombies]), set([result.key for result in parsing_result]))
def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor( self): """ Check that the same set of failure callback with zombies are passed to the dag file processors until the next zombie detection logic is invoked. """ test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py') with conf_vars({ ('scheduler', 'parsing_processes'): '1', ('core', 'load_examples'): 'False' }): dagbag = DagBag(test_dag_path, read_dags_from_db=False) with create_session() as session: session.query(LJ).delete() dag = dagbag.get_dag('test_example_bash_operator') dag.sync_to_db() task = dag.get_task(task_id='run_this_last') ti = TI(task, DEFAULT_DATE, State.RUNNING) local_job = LJ(ti) local_job.state = State.SHUTDOWN session.add(local_job) session.commit() # TODO: If there was an actual Relationship between TI and Job # we wouldn't need this extra commit session.add(ti) ti.job_id = local_job.id session.commit() expected_failure_callback_requests = [ TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), msg="Message", ) ] test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py') child_pipe, parent_pipe = multiprocessing.Pipe() async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn') fake_processors = [] def fake_processor_factory(*args, **kwargs): nonlocal fake_processors processor = FakeDagFileProcessorRunner._fake_dag_processor_factory( *args, **kwargs) fake_processors.append(processor) return processor manager = DagFileProcessorManager( dag_directory=test_dag_path, max_runs=1, processor_factory=fake_processor_factory, processor_timeout=timedelta.max, signal_conn=child_pipe, dag_ids=[], pickle_dags=False, async_mode=async_mode, ) self.run_processor_manager_one_loop(manager, parent_pipe) if async_mode: # Once for initial parse, and then again for the add_callback_to_queue assert len(fake_processors) == 2 assert fake_processors[0]._file_path == test_dag_path assert fake_processors[0]._callback_requests == [] else: assert len(fake_processors) == 1 assert fake_processors[-1]._file_path == test_dag_path callback_requests = fake_processors[-1]._callback_requests assert { zombie.simple_task_instance.key for zombie in expected_failure_callback_requests } == { result.simple_task_instance.key for result in callback_requests } child_pipe.close() parent_pipe.close()
def run(args): utils.pessimistic_connection_handling() # Setting up logging log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args) if not os.path.exists(directory): os.makedirs(directory) args.execution_date = dateutil.parser.parse(args.execution_date) iso = args.execution_date.isoformat() filename = "{directory}/{iso}".format(**locals()) subdir = process_subdir(args.subdir) logging.root.handlers = [] logging.basicConfig(filename=filename, level=settings.LOGGING_LEVEL, format=settings.LOG_FORMAT) if not args.pickle: dagbag = DagBag(subdir) if args.dag_id not in dagbag.dags: msg = 'DAG [{0}] could not be found in {1}'.format( args.dag_id, subdir) logging.error(msg) raise AirflowException(msg) dag = dagbag.dags[args.dag_id] task = dag.get_task(task_id=args.task_id) else: session = settings.Session() logging.info('Loading pickle id {args.pickle}'.format(**locals())) dag_pickle = session.query(DagPickle).filter( DagPickle.id == args.pickle).first() if not dag_pickle: raise AirflowException("Who hid the pickle!? [missing pickle]") dag = dag_pickle.pickle task = dag.get_task(task_id=args.task_id) task_start_date = None if args.task_start_date: task_start_date = dateutil.parser.parse(args.task_start_date) task.start_date = task_start_date ti = TaskInstance(task, args.execution_date) if args.local: print("Logging into: " + filename) run_job = jobs.LocalTaskJob( task_instance=ti, mark_success=args.mark_success, force=args.force, pickle_id=args.pickle, task_start_date=task_start_date, ignore_dependencies=args.ignore_dependencies, pool=args.pool) run_job.run() elif args.raw: ti.run( mark_success=args.mark_success, force=args.force, ignore_dependencies=args.ignore_dependencies, job_id=args.job_id, pool=args.pool, ) else: pickle_id = None if args.ship_dag: try: # Running remotely, so pickling the DAG session = settings.Session() pickle = DagPickle(dag) session.add(pickle) session.commit() pickle_id = pickle.id print(('Pickled dag {dag} ' 'as pickle_id:{pickle_id}').format(**locals())) except Exception as e: print('Could not pickle the DAG') print(e) raise e executor = DEFAULT_EXECUTOR executor.start() print("Sending to executor.") executor.queue_task_instance( ti, mark_success=args.mark_success, pickle_id=pickle_id, ignore_dependencies=args.ignore_dependencies, force=args.force, pool=args.pool) executor.heartbeat() executor.end() # store logs remotely remote_base = conf.get('core', 'REMOTE_BASE_LOG_FOLDER') # deprecated as of March 2016 if not remote_base and conf.get('core', 'S3_LOG_FOLDER'): warnings.warn( 'The S3_LOG_FOLDER conf key has been replaced by ' 'REMOTE_BASE_LOG_FOLDER. Your conf still works but please ' 'update airflow.cfg to ensure future compatibility.', DeprecationWarning) remote_base = conf.get('core', 'S3_LOG_FOLDER') if os.path.exists(filename): # read log and remove old logs to get just the latest additions with open(filename, 'r') as logfile: log = logfile.read() remote_log_location = filename.replace(log_base, remote_base) # S3 if remote_base.startswith('s3:/'): utils.S3Log().write(log, remote_log_location) # GCS elif remote_base.startswith('gs:/'): utils.GCSLog().write(log, remote_log_location, append=True) # Other elif remote_base: logging.error( 'Unsupported remote log location: {}'.format(remote_base))
def test_get_states_count_upstream_ti(self): """ this test tests the helper function '_get_states_count_upstream_ti' as a unit and inside update_state """ from airflow.ti_deps.dep_context import DepContext get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti session = settings.Session() now = timezone.utcnow() dag = DAG( 'test_dagrun_with_pre_tis', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: op1 = DummyOperator(task_id='A') op2 = DummyOperator(task_id='B') op3 = DummyOperator(task_id='C') op4 = DummyOperator(task_id='D') op5 = DummyOperator(task_id='E', trigger_rule=TriggerRule.ONE_FAILED) op1.set_downstream([op2, op3]) # op1 >> op2, op3 op4.set_upstream([op3, op2]) # op3, op2 >> op4 op5.set_upstream([op2, op3, op4]) # (op2, op3, op4) >> op5 clear_db_runs() dag.clear() dr = dag.create_dagrun(run_id='test_dagrun_with_pre_tis', state=State.RUNNING, execution_date=now, start_date=now) ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date) ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date) ti_op3 = TaskInstance(task=dag.get_task(op3.task_id), execution_date=dr.execution_date) ti_op4 = TaskInstance(task=dag.get_task(op4.task_id), execution_date=dr.execution_date) ti_op5 = TaskInstance(task=dag.get_task(op5.task_id), execution_date=dr.execution_date) ti_op1.set_state(state=State.SUCCESS, session=session) ti_op2.set_state(state=State.FAILED, session=session) ti_op3.set_state(state=State.SUCCESS, session=session) ti_op4.set_state(state=State.SUCCESS, session=session) ti_op5.set_state(state=State.SUCCESS, session=session) session.commit() # check handling with cases that tasks are triggered from backfill with no finished tasks finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session) self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2), (1, 0, 0, 0, 1)) finished_tasks = dr.get_task_instances(state=State.finished() + [State.UPSTREAM_FAILED], session=session) self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4), (1, 0, 1, 0, 2)) self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5), (2, 0, 1, 0, 3)) dr.update_state() self.assertEqual(State.SUCCESS, dr.state)
def run_task(task_instance: TaskInstance) -> State:
    task_instance._run_raw_task(test_mode=True)
    return task_instance.state
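# A minimal sketch of exercising run_task above. The PythonOperator, the dag
# argument and DEFAULT_EXECUTION_DATE are assumptions for illustration;
# _run_raw_task is invoked in test mode and the helper returns the resulting
# TaskInstance state.
def example_run_single_task(dag):
    task = PythonOperator(task_id='say_hello', python_callable=lambda: 'hello', dag=dag)
    ti = TaskInstance(task=task, execution_date=DEFAULT_EXECUTION_DATE)
    assert run_task(ti) == State.SUCCESS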
def run(args, dag=None):
    db_utils.pessimistic_connection_handling()
    if dag:
        args.dag_id = dag.dag_id

    # Setting up logging
    log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
    directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
    if not os.path.exists(directory):
        os.makedirs(directory)
    iso = args.execution_date.isoformat()
    filename = "{directory}/{iso}".format(**locals())

    logging.root.handlers = []
    logging.basicConfig(
        filename=filename,
        level=settings.LOGGING_LEVEL,
        format=settings.LOG_FORMAT)

    if not args.pickle and not dag:
        dag = get_dag(args)
    elif not dag:
        session = settings.Session()
        logging.info('Loading pickle id {args.pickle}'.format(**locals()))
        dag_pickle = session.query(
            DagPickle).filter(DagPickle.id == args.pickle).first()
        if not dag_pickle:
            raise AirflowException("Who hid the pickle!? [missing pickle]")
        dag = dag_pickle.pickle
    task = dag.get_task(task_id=args.task_id)

    ti = TaskInstance(task, args.execution_date)

    if args.local:
        print("Logging into: " + filename)
        run_job = jobs.LocalTaskJob(
            task_instance=ti,
            mark_success=args.mark_success,
            pickle_id=args.pickle,
            ignore_all_deps=args.ignore_all_dependencies,
            ignore_depends_on_past=args.ignore_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            ignore_ti_state=args.force,
            pool=args.pool)
        run_job.run()
    elif args.raw:
        ti.run(
            mark_success=args.mark_success,
            ignore_all_deps=args.ignore_all_dependencies,
            ignore_depends_on_past=args.ignore_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            ignore_ti_state=args.force,
            job_id=args.job_id,
            pool=args.pool,
        )
    else:
        pickle_id = None
        if args.ship_dag:
            try:
                # Running remotely, so pickling the DAG
                session = settings.Session()
                pickle = DagPickle(dag)
                session.add(pickle)
                session.commit()
                pickle_id = pickle.id
                print((
                    'Pickled dag {dag} '
                    'as pickle_id:{pickle_id}').format(**locals()))
            except Exception as e:
                print('Could not pickle the DAG')
                print(e)
                raise e

        executor = DEFAULT_EXECUTOR
        executor.start()
        print("Sending to executor.")
        executor.queue_task_instance(
            ti,
            mark_success=args.mark_success,
            pickle_id=pickle_id,
            ignore_all_deps=args.ignore_all_dependencies,
            ignore_depends_on_past=args.ignore_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            ignore_ti_state=args.force,
            pool=args.pool)
        executor.heartbeat()
        executor.end()

    # Force the log to flush, and set the handler to go back to normal so we
    # don't continue logging to the task's log file. The flush is important
    # because we subsequently read from the log to insert into S3 or Google
    # cloud storage.
    logging.root.handlers[0].flush()
    logging.root.handlers = []

    # store logs remotely
    remote_base = conf.get('core', 'REMOTE_BASE_LOG_FOLDER')

    # deprecated as of March 2016
    if not remote_base and conf.get('core', 'S3_LOG_FOLDER'):
        warnings.warn(
            'The S3_LOG_FOLDER conf key has been replaced by '
            'REMOTE_BASE_LOG_FOLDER. Your conf still works but please '
            'update airflow.cfg to ensure future compatibility.',
            DeprecationWarning)
        remote_base = conf.get('core', 'S3_LOG_FOLDER')

    if os.path.exists(filename):
        # read log and remove old logs to get just the latest additions
        with open(filename, 'r') as logfile:
            log = logfile.read()

        remote_log_location = filename.replace(log_base, remote_base)
        # S3
        if remote_base.startswith('s3:/'):
            logging_utils.S3Log().write(log, remote_log_location)
        # GCS
        elif remote_base.startswith('gs:/'):
            logging_utils.GCSLog().write(
                log, remote_log_location, append=True)
        # Other
        elif remote_base and remote_base != 'None':
            logging.error(
                'Unsupported remote log location: {}'.format(remote_base))
def run_dag_tasks(dag):
    for task in dag.tasks:
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)

        assert ti.state == State.SUCCESS
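# A sketch of invoking run_dag_tasks above on a small DAG; the dag id and task
# ids are illustrative assumptions. Each task is executed through its own
# TaskInstance with ignore_ti_state=True, and the helper asserts that every
# task finishes in the SUCCESS state.
def example_run_whole_dag():
    dag = DAG('example_dag', start_date=DEFAULT_DATE)
    DummyOperator(task_id='first', dag=dag)
    DummyOperator(task_id='second', dag=dag)
    run_dag_tasks(dag)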
def test_xcom_pull(self):
    """
    Test xcom_pull, using different filtering methods.
    """
    dag = models.DAG(
        dag_id='test_xcom', schedule_interval='@monthly',
        start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))

    exec_date = timezone.utcnow()

    # Push a value
    task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow')
    ti1 = TI(task=task1, execution_date=exec_date)
    ti1.xcom_push(key='foo', value='bar')

    # Push another value with the same key (but by a different task)
    task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow')
    ti2 = TI(task=task2, execution_date=exec_date)
    ti2.xcom_push(key='foo', value='baz')

    # Pull with no arguments
    result = ti1.xcom_pull()
    self.assertEqual(result, None)
    # Pull the value pushed most recently by any task.
    result = ti1.xcom_pull(key='foo')
    self.assertIn(result, 'baz')
    # Pull the value pushed by the first task
    result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
    self.assertEqual(result, 'bar')
    # Pull the value pushed by the second task
    result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
    self.assertEqual(result, 'baz')
    # Pull the values pushed by both tasks
    result = ti1.xcom_pull(
        task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
    self.assertEqual(result, ('bar', 'baz'))
def test_dag_clear(self):
    dag = DAG('test_dag_clear', start_date=DEFAULT_DATE,
              end_date=DEFAULT_DATE + datetime.timedelta(days=10))
    task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
    ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
    dag.create_dagrun(
        execution_date=ti0.execution_date,
        state=State.RUNNING,
        run_type=DagRunType.SCHEDULED,
    )

    # Next try to run will be try 1
    self.assertEqual(ti0.try_number, 1)
    ti0.run()
    self.assertEqual(ti0.try_number, 2)
    dag.clear()
    ti0.refresh_from_db()
    self.assertEqual(ti0.try_number, 2)
    self.assertEqual(ti0.state, State.NONE)
    self.assertEqual(ti0.max_tries, 1)

    task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test',
                          dag=dag, retries=2)
    ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
    self.assertEqual(ti1.max_tries, 2)
    ti1.try_number = 1
    # Next try will be 2
    ti1.run()
    self.assertEqual(ti1.try_number, 3)
    self.assertEqual(ti1.max_tries, 2)

    dag.clear()
    ti0.refresh_from_db()
    ti1.refresh_from_db()
    # after clearing the dag, ti1 should show attempt 3 of 5
    self.assertEqual(ti1.max_tries, 4)
    self.assertEqual(ti1.try_number, 3)
    # after clearing the dag, ti0 should show attempt 2 of 2
    self.assertEqual(ti0.try_number, 2)
    self.assertEqual(ti0.max_tries, 1)
def test_reschedule_handling(self, mock_pool_full): """ Test that task reschedules are handled properly """ # Return values of the python sensor callable, modified during tests done = False fail = False def callable(): if fail: raise AirflowException() return done dag = models.DAG(dag_id='test_reschedule_handling') task = PythonSensor( task_id='test_reschedule_handling_sensor', poke_interval=0, mode='reschedule', python_callable=callable, retries=1, retry_delay=datetime.timedelta(seconds=0), dag=dag, owner='airflow', pool='test_pool', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertEqual(ti.try_number, 1) def run_ti_and_assert(run_date, expected_start_date, expected_end_date, expected_duration, expected_state, expected_try_number, expected_task_reschedule_count): with freeze_time(run_date): try: ti.run() except AirflowException: if not fail: raise ti.refresh_from_db() self.assertEqual(ti.state, expected_state) self.assertEqual(ti._try_number, expected_try_number) self.assertEqual(ti.try_number, expected_try_number + 1) self.assertEqual(ti.start_date, expected_start_date) self.assertEqual(ti.end_date, expected_end_date) self.assertEqual(ti.duration, expected_duration) trs = TaskReschedule.find_for_task_instance(ti) self.assertEqual(len(trs), expected_task_reschedule_count) date1 = timezone.utcnow() date2 = date1 + datetime.timedelta(minutes=1) date3 = date2 + datetime.timedelta(minutes=1) date4 = date3 + datetime.timedelta(minutes=1) # Run with multiple reschedules. # During reschedule the try number remains the same, but each reschedule is recorded. # The start date is expected to remain the initial date, hence the duration increases. # When finished the try number is incremented and there is no reschedule expected # for this try. done, fail = False, False run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1) done, fail = False, False run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2) done, fail = False, False run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3) done, fail = True, False run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0) # Clear the task instance. dag.clear() ti.refresh_from_db() self.assertEqual(ti.state, State.NONE) self.assertEqual(ti._try_number, 1) # Run again after clearing with reschedules and a retry. # The retry increments the try number, and for that try no reschedule is expected. # After the retry the start date is reset, hence the duration is also reset. done, fail = False, False run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1) done, fail = False, True run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0) done, fail = False, False run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1) done, fail = True, False run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
    for (word, count) in output:
        print("%s: %i" % (word, count))
    return get_spark_session().createDataFrame(counts)


class WordCountPySparkTask(PySparkTask):
    text = parameter.data
    counters = parameter.output
    python_script = relative_path(__file__, "spark_scripts/word_count.py")

    def application_args(self):
        return [self.text, self.counters]


with DAG(dag_id="dbnd_dag_with_spark", default_args=default_args) as dag_spark:
    # noinspection PyTypeChecker
    spark_task = WordCountPySparkTask(text="s3://dbnd/README.md")
    spark_op = spark_task.op
    # spark_result = word_count_inline("/tmp/sample.txt")
    # spark_op = spark_result.op

if __name__ == "__main__":
    ti = TaskInstance(spark_op, days_ago(0))
    ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
    # dag_spark.clear()
    # dag_spark.run(start_date=days_ago(0), end_date=days_ago(0))
def test_xcom_pull_after_success(self):
    """
    tests xcom set/clear relative to a task in a 'success' rerun scenario
    """
    key = 'xcom_key'
    value = 'xcom_value'

    dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
    task = DummyOperator(
        task_id='test_xcom',
        dag=dag,
        pool='test_xcom',
        owner='airflow',
        start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
    exec_date = datetime.datetime.now()
    ti = TI(task=task, execution_date=exec_date)
    ti.run(mark_success=True)
    ti.xcom_push(key=key, value=value)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
    ti.run()
    # The second run and assert is to handle AIRFLOW-131 (don't clear on
    # prior success)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

    # Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't
    # execute, even if dependencies are ignored
    ti.run(ignore_all_deps=True, mark_success=True)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

    # Xcom IS finally cleared once task has executed
    ti.run(ignore_all_deps=True)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
def test_extra_serialized_field_and_multiple_operator_links(self): """ Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links. This tests also depends on GoogleLink() registered as a plugin in tests/plugins/test_plugin.py The function tests that if extra operator links are registered in plugin in ``operator_extra_links`` and the same is also defined in the Operator in ``BaseOperator.operator_extra_links``, it has the correct extra link. """ test_date = datetime(2019, 8, 1) dag = DAG(dag_id='simple_dag', start_date=test_date) CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"]) serialized_dag = SerializedDAG.to_dict(dag) self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0]) dag = SerializedDAG.from_dict(serialized_dag) simple_task = dag.task_dict["simple_task"] self.assertEqual(getattr(simple_task, "bash_command"), ["echo", "true"]) ######################################################### # Verify Operator Links work with Serialized Operator ######################################################### # Check Serialized version of operator link only contains the inbuilt Op Link self.assertEqual( serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], [ { 'airflow.utils.tests.CustomBaseIndexOpLink': { 'index': 0 } }, { 'airflow.utils.tests.CustomBaseIndexOpLink': { 'index': 1 } }, ]) # Test all the extra_links are set six.assertCountEqual(self, simple_task.extra_links, [ 'BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google' ]) ti = TaskInstance(task=simple_task, execution_date=test_date) ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"]) # Test Deserialized inbuilt link #1 custom_inbuilt_link = simple_task.get_extra_links( test_date, "BigQuery Console #1") self.assertEqual( 'https://console.cloud.google.com/bigquery?j=dummy_value_1', custom_inbuilt_link) # Test Deserialized inbuilt link #2 custom_inbuilt_link = simple_task.get_extra_links( test_date, "BigQuery Console #2") self.assertEqual( 'https://console.cloud.google.com/bigquery?j=dummy_value_2', custom_inbuilt_link) # Test Deserialized link registered via Airflow Plugin google_link_from_plugin = simple_task.get_extra_links( test_date, GoogleLink.name) self.assertEqual("https://www.google.com", google_link_from_plugin)
def test_subdag_clear_parentdag_downstream_clear(self):
    dag = self.dagbag.get_dag('clear_subdag_test_dag')
    subdag_op_task = dag.get_task('daily_job')

    subdag = subdag_op_task.subdag

    executor = MockExecutor()
    job = BackfillJob(dag=dag,
                      start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      executor=executor,
                      donot_pickle=True)

    with timeout(seconds=30):
        job.run()

    ti_subdag = TI(
        task=dag.get_task('daily_job'),
        execution_date=DEFAULT_DATE)
    ti_subdag.refresh_from_db()
    self.assertEqual(ti_subdag.state, State.SUCCESS)

    ti_irrelevant = TI(
        task=dag.get_task('daily_job_irrelevant'),
        execution_date=DEFAULT_DATE)
    ti_irrelevant.refresh_from_db()
    self.assertEqual(ti_irrelevant.state, State.SUCCESS)

    ti_downstream = TI(
        task=dag.get_task('daily_job_downstream'),
        execution_date=DEFAULT_DATE)
    ti_downstream.refresh_from_db()
    self.assertEqual(ti_downstream.state, State.SUCCESS)

    sdag = subdag.sub_dag(
        task_regex='daily_job_subdag_task',
        include_downstream=True,
        include_upstream=False)

    sdag.clear(
        start_date=DEFAULT_DATE,
        end_date=DEFAULT_DATE,
        include_parentdag=True)

    ti_subdag.refresh_from_db()
    self.assertEqual(State.NONE, ti_subdag.state)

    ti_irrelevant.refresh_from_db()
    self.assertEqual(State.SUCCESS, ti_irrelevant.state)

    ti_downstream.refresh_from_db()
    self.assertEqual(State.NONE, ti_downstream.state)

    subdag.clear()
    dag.clear()
def get_link(self, operator, dttm):
    ti = TaskInstance(task=operator, execution_date=dttm)
    search_query = ti.xcom_pull(task_ids=operator.task_id, key='search_query')
    return 'http://google.com/custom_base_link?search={}'.format(search_query)
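# A sketch of how a get_link like the one above is typically wired up. The
# class and operator names are illustrative, and the BaseOperatorLink import
# path may differ between Airflow versions: the link class declares a name and
# a get_link(operator, dttm) method, and is attached to an operator through
# operator_extra_links so the web UI can render a button for it.
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink


class SearchQueryLink(BaseOperatorLink):
    name = 'Search Query'  # label shown on the task instance popup

    def get_link(self, operator, dttm):
        ti = TaskInstance(task=operator, execution_date=dttm)
        search_query = ti.xcom_pull(task_ids=operator.task_id, key='search_query')
        return 'http://google.com/custom_base_link?search={}'.format(search_query)


class CustomSearchOperator(BaseOperator):
    # tasks built from this operator expose the 'Search Query' link in the UI
    operator_extra_links = (SearchQueryLink(),)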
def test_infer_predictions():
    ti = TaskInstance(task=infer_predictions, execution_date=datetime.now())
    result = infer_predictions.execute(ti.get_template_context())
    assert result == "succeeded"