Example #1
def backfill(args):
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException("dag_id could not be found")
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, use it for both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(task_regex=args.task_regex, include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or conf.getboolean("core", "donot_pickle")),
            ignore_dependencies=args.ignore_dependencies,
        )
Example #2
    def test_file_task_handler(self):
        dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
        task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=dag)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

        logger = logging.getLogger(TASK_LOGGER)
        file_handler = next((handler for handler in logger.handlers
                             if handler.name == FILE_TASK_HANDLER), None)
        self.assertIsNotNone(file_handler)

        file_handler.set_context(ti)
        self.assertIsNotNone(file_handler.handler)
        # We expect set_context to generate a file locally.
        log_filename = file_handler.handler.baseFilename
        self.assertTrue(os.path.isfile(log_filename))

        logger.info("test")
        ti.run()

        self.assertTrue(hasattr(file_handler, 'read'))
        # Return value of read must be a list.
        logs = file_handler.read(ti)
        self.assertTrue(isinstance(logs, list))
        self.assertEqual(len(logs), 1)

        # Remove the generated tmp log file.
        os.remove(log_filename)
Example #3
    def setUp(self):
        super(TestLogView, self).setUp()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.writelines(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
Example #4
File: models.py Project: mtagle/airflow
    def test_check_task_dependencies(
        self,
        trigger_rule,
        successes,
        skipped,
        failed,
        upstream_failed,
        done,
        flag_upstream_failed,
        expect_state,
        expect_completed,
    ):
        start_date = datetime.datetime(2016, 2, 1, 0, 0, 0)
        dag = models.DAG("test-dag", start_date=start_date)
        downstream = DummyOperator(task_id="downstream", dag=dag, owner="airflow", trigger_rule=trigger_rule)
        for i in range(5):
            task = DummyOperator(task_id="runme_{}".format(i), dag=dag, owner="airflow")
            task.set_downstream(downstream)
        run_date = task.start_date + datetime.timedelta(days=5)

        ti = TI(downstream, run_date)
        completed = ti.evaluate_trigger_rule(
            successes=successes,
            skipped=skipped,
            failed=failed,
            upstream_failed=upstream_failed,
            done=done,
            flag_upstream_failed=flag_upstream_failed,
        )

        self.assertEqual(completed, expect_completed)
        self.assertEqual(ti.state, expect_state)
Example #5
    def test_file_transfer_no_intermediate_dir_error_put(self):
        configuration.conf.set("core", "enable_xcom_pickling", "True")
        test_local_file_content = \
            b"This is local file content \n which is multiline " \
            b"continuing....with other character\nanother line here \n this is last line"
        # create a test file locally
        with open(self.test_local_filepath, 'wb') as f:
            f.write(test_local_file_content)

        # Try to put test file to remote
        # This should raise an error with "No such file" as the directory
        # does not exist
        with self.assertRaises(Exception) as error:
            put_test_task = SFTPOperator(
                task_id="test_sftp",
                ssh_hook=self.hook,
                local_filepath=self.test_local_filepath,
                remote_filepath=self.test_remote_filepath_int_dir,
                operation=SFTPOperation.PUT,
                create_intermediate_dirs=False,
                dag=self.dag
            )
            self.assertIsNotNone(put_test_task)
            ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
            ti2.run()
        self.assertIn('No such file', str(error.exception))
Example #6
    def test_backfill_pooled_tasks(self):
        """
        Test that queued tasks are executed by BackfillJob

        Test for https://github.com/airbnb/airflow/pull/1225
        """
        session = settings.Session()
        pool = Pool(pool='test_backfill_pooled_task_pool', slots=1)
        session.add(pool)
        session.commit()

        dag = self.dagbag.get_dag('test_backfill_pooled_task_dag')
        dag.clear()

        job = BackfillJob(
            dag=dag,
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE)

        # run with timeout because this creates an infinite loop if not
        # caught
        with timeout(seconds=30):
            job.run()

        ti = TI(
            task=dag.get_task('test_backfill_pooled_task'),
            execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.SUCCESS)
Example #7
    def test_bigquery_operator_extra_link(self, mock_hook):
        bigquery_task = BigQueryOperator(
            task_id=TASK_ID,
            sql='SELECT * FROM test_table',
            dag=self.dag,
        )
        self.dag.clear()

        ti = TaskInstance(
            task=bigquery_task,
            execution_date=DEFAULT_DATE,
        )

        job_id = '12345'
        ti.xcom_push(key='job_id', value=job_id)

        self.assertEqual(
            'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=job_id),
            bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name),
        )

        self.assertEqual(
            '',
            bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name),
        )
Example #8
    def test_cli_backfill_depends_on_past(self):
        """
        Test that CLI respects -I argument
        """
        dag_id = 'test_dagrun_states_deadlock'
        run_date = DEFAULT_DATE + datetime.timedelta(days=1)
        args = [
            'backfill',
            dag_id,
            '-l',
            '-s',
            run_date.isoformat(),
        ]
        dag = self.dagbag.get_dag(dag_id)
        dag.clear()

        self.assertRaisesRegexp(
            AirflowException,
            'BackfillJob is deadlocked',
            cli.backfill,
            self.parser.parse_args(args))

        cli.backfill(self.parser.parse_args(args + ['-I']))
        ti = TI(dag.get_task('test_depends_on_past'), run_date)
        ti.refresh_from_db()
        # task ran
        self.assertEqual(ti.state, State.SUCCESS)
        dag.clear()
Example #9
    def test_xcom_push_flag(self):
        """
        Tests the option for Operators to push XComs
        """
        value = 'hello'
        task_id = 'test_no_xcom_push'
        dag = models.DAG(dag_id='test_xcom')

        # nothing saved to XCom
        task = PythonOperator(
            task_id=task_id,
            dag=dag,
            python_callable=lambda: value,
            do_xcom_push=False,
            owner='airflow',
            start_date=datetime.datetime(2017, 1, 1)
        )
        ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
        ti.run()
        self.assertEqual(
            ti.xcom_pull(
                task_ids=task_id, key=models.XCOM_RETURN_KEY
            ),
            None
        )
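For contrast, here is a hedged sketch (not part of the source; it reuses the names from the test above) of the default behavior, where the callable's return value is pushed to XCom under models.XCOM_RETURN_KEY:

# Hypothetical counterpart to the test above: with do_xcom_push left at its
# default of True, the callable's return value is saved to XCom.
task = PythonOperator(
    task_id='test_xcom_push',
    dag=dag,
    python_callable=lambda: value,
    owner='airflow',
    start_date=datetime.datetime(2017, 1, 1)
)
ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
ti.run()
assert ti.xcom_pull(task_ids='test_xcom_push', key=models.XCOM_RETURN_KEY) == value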
Example #10
    def test_email_alert_with_config(self, mock_send_email):
        dag = models.DAG(dag_id='test_failure_email')
        task = BashOperator(
            task_id='test_email_alert_with_config',
            dag=dag,
            bash_command='exit 1',
            start_date=DEFAULT_DATE,
            email='to')

        ti = TI(
            task=task, execution_date=datetime.datetime.now())

        configuration.set('email', 'SUBJECT_TEMPLATE', '/subject/path')
        configuration.set('email', 'HTML_CONTENT_TEMPLATE', '/html_content/path')

        opener = mock_open(read_data='template: {{ti.task_id}}')
        with patch('airflow.models.taskinstance.open', opener, create=True):
            try:
                ti.run()
            except AirflowException:
                pass

        (email, title, body), _ = mock_send_email.call_args
        self.assertEqual(email, 'to')
        self.assertEqual('template: test_email_alert_with_config', title)
        self.assertEqual('template: test_email_alert_with_config', body)
Example #11
File: cli.py Project: seancron/airflow
def backfill(args, dag=None):
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)

    dag = dag or get_dag(args)

    if not args.start_date and not args.end_date:
        raise AirflowException("Provide a start_date and/or end_date")

    # If only one date is passed, use it for both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(task_regex=args.task_regex, include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or conf.getboolean("core", "donot_pickle")),
            ignore_dependencies=args.ignore_dependencies,
            ignore_first_depends_on_past=args.ignore_first_depends_on_past,
            pool=args.pool,
        )
Example #12
File: models.py Project: ludovicc/airflow
    def test_post_execute_hook(self):
        """
        Test that post_execute hook is called with the Operator's result.
        The result ('error') will cause an error to be raised and trapped.
        """

        class TestError(Exception):
            pass

        class TestOperator(PythonOperator):
            def post_execute(self, context, result):
                if result == 'error':
                    raise TestError('expected error.')

        dag = models.DAG(dag_id='test_post_execute_dag')
        task = TestOperator(
            task_id='test_operator',
            dag=dag,
            python_callable=lambda: 'error',
            owner='airflow',
            start_date=datetime.datetime(2017, 2, 1))
        ti = TI(task=task, execution_date=datetime.datetime.now())

        with self.assertRaises(TestError):
            ti.run()
Example #13
    def test_clear_task_instances_without_task(self):
        dag = DAG('test_clear_task_instances_without_task', start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
        task0 = DummyOperator(task_id='task0', owner='test', dag=dag)
        task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2)
        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
        ti0.run()
        ti1.run()

        # Remove the task from dag.
        dag.task_dict = {}
        self.assertFalse(dag.has_task(task0.task_id))
        self.assertFalse(dag.has_task(task1.task_id))

        session = settings.Session()
        qry = session.query(TI).filter(
            TI.dag_id == dag.dag_id).all()
        clear_task_instances(qry, session)
        session.commit()
        # When dag is None, max_tries will be the maximum of the original
        # max_tries and the current try_number.
        ti0.refresh_from_db()
        ti1.refresh_from_db()
        # Next try to run will be try 2
        self.assertEqual(ti0.try_number, 2)
        self.assertEqual(ti0.max_tries, 1)
        self.assertEqual(ti1.try_number, 2)
        self.assertEqual(ti1.max_tries, 2)
Example #14
File: jobs.py Project: slvwolf/airflow
    def test_scheduler_pooled_tasks(self):
        """
        Test that the scheduler handles queued tasks correctly
        See issue #1299
        """
        session = settings.Session()
        if not (
                session.query(Pool)
                .filter(Pool.pool == 'test_queued_pool')
                .first()):
            pool = Pool(pool='test_queued_pool', slots=5)
            session.merge(pool)
            session.commit()
        session.close()

        dag_id = 'test_scheduled_queued_tasks'
        dag = self.dagbag.get_dag(dag_id)
        dag.clear()

        scheduler = SchedulerJob(dag_id, num_runs=10)
        scheduler.run()

        task_1 = dag.tasks[0]
        ti = TI(task_1, dag.start_date)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.FAILED)

        dag.clear()
Example #15
 def test_check_and_change_state_before_execution_dep_not_met(self):
     dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
     task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
     task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
     task >> task2
     ti = TI(
         task=task2, execution_date=timezone.utcnow())
     self.assertFalse(ti._check_and_change_state_before_execution())
Example #16
    def test_overwrite_params_with_dag_run_none(self):
        task = DummyOperator(task_id='op')
        ti = TI(task=task, execution_date=datetime.datetime.now())
        params = {"override": False}

        ti.overwrite_params_with_dag_run_conf(params, None)

        self.assertEqual(False, params["override"])
Example #17
    def test_s3_to_sftp_operation(self):
        # Setting
        configuration.conf.set("core", "enable_xcom_pickling", "True")
        test_remote_file_content = \
            "This is remote file content \n which is also multiline " \
            "another line here \n this is last line. EOF"

        # Test for creation of s3 bucket
        conn = boto3.client('s3')
        conn.create_bucket(Bucket=self.s3_bucket)
        self.assertTrue(self.s3_hook.check_for_bucket(self.s3_bucket))

        with open(LOCAL_FILE_PATH, 'w') as f:
            f.write(test_remote_file_content)
        self.s3_hook.load_file(LOCAL_FILE_PATH, self.s3_key, bucket_name=BUCKET)

        # Check if object was created in s3
        objects_in_dest_bucket = conn.list_objects(Bucket=self.s3_bucket,
                                                   Prefix=self.s3_key)
        # there should be exactly one object found
        self.assertEqual(len(objects_in_dest_bucket['Contents']), 1)

        # the object found should be consistent with dest_key specified earlier
        self.assertEqual(objects_in_dest_bucket['Contents'][0]['Key'], self.s3_key)

        # get remote file to local
        run_task = S3ToSFTPOperator(
            s3_bucket=BUCKET,
            s3_key=S3_KEY,
            sftp_path=SFTP_PATH,
            sftp_conn_id=SFTP_CONN_ID,
            s3_conn_id=S3_CONN_ID,
            task_id=TASK_ID,
            dag=self.dag
        )
        self.assertIsNotNone(run_task)

        run_task.execute(None)

        # Check that the file is created remotely
        check_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="cat {0}".format(self.sftp_path),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(check_file_task)
        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
        ti3.run()
        self.assertEqual(
            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
            test_remote_file_content.encode('utf-8'))

        # Clean up after finishing with test
        conn.delete_object(Bucket=self.s3_bucket, Key=self.s3_key)
        conn.delete_bucket(Bucket=self.s3_bucket)
        self.assertFalse(self.s3_hook.check_for_bucket(self.s3_bucket))
Example #18
    def test_overwrite_params_with_dag_run_conf(self):
        task = DummyOperator(task_id='op')
        ti = TI(task=task, execution_date=datetime.datetime.now())
        dag_run = DagRun()
        dag_run.conf = {"override": True}
        params = {"override": False}

        ti.overwrite_params_with_dag_run_conf(params, dag_run)

        self.assertEqual(True, params["override"])
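Read together with Example #16 above, the two tests pin down the semantics of the override; a minimal standalone sketch (the function below is hypothetical, not Airflow's implementation):

def overwrite_params_with_dag_run_conf_sketch(params, dag_run):
    # dag_run.conf entries win over the default params; a missing dag_run
    # (or an empty conf) leaves params untouched.
    if dag_run and dag_run.conf:
        params.update(dag_run.conf)
    return params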
Example #19
File: cli.py Project: johnw424/airflow
def test(args):
    log_to_stdout()
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.run(force=True, ignore_dependencies=True, test_mode=True)
Example #20
File: models.py Project: ludovicc/airflow
    def test_requeue_over_concurrency(self, mock_concurrency_reached):
        mock_concurrency_reached.return_value = True

        dag = DAG(dag_id='test_requeue_over_concurrency', start_date=DEFAULT_DATE,
                  max_active_runs=1, concurrency=2)
        task = DummyOperator(task_id='test_requeue_over_concurrency_op', dag=dag)

        ti = TI(task=task, execution_date=datetime.datetime.now())
        ti.run()
        self.assertEqual(ti.state, models.State.NONE)
Example #21
 def test_check_and_change_state_before_execution(self):
     dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
     task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
     ti = TI(
         task=task, execution_date=timezone.utcnow())
     self.assertEqual(ti._try_number, 0)
     self.assertTrue(ti._check_and_change_state_before_execution())
     # State should be running, and try_number column should be incremented
     self.assertEqual(ti.state, State.RUNNING)
     self.assertEqual(ti._try_number, 1)
Example #22
 def test_set_duration(self):
     task = DummyOperator(task_id='op', email='*****@*****.**')
     ti = TI(
         task=task,
         execution_date=datetime.datetime.now(),
     )
     ti.start_date = datetime.datetime(2018, 10, 1, 1)
     ti.end_date = datetime.datetime(2018, 10, 1, 2)
     ti.set_duration()
     self.assertEqual(ti.duration, 3600)
Example #23
    def test_render_template(self):
        ti = TaskInstance(self.mock_operator, DEFAULT_DATE)
        ti.render_templates()

        expected_rendered_template = {'$lt': '2017-01-01T00:00:00+00:00Z'}

        self.assertDictEqual(
            expected_rendered_template,
            getattr(self.mock_operator, 'mongo_query')
        )
Example #24
    def test_retry_handling(self, mock_pool_full):
        """
        Test that task retries are handled properly
        """
        # Mock a pool with open slots, since the pool doesn't actually exist
        mock_pool_full.return_value = False

        dag = models.DAG(dag_id='test_retry_handling')
        task = BashOperator(
            task_id='test_retry_handling_op',
            bash_command='exit 1',
            retries=1,
            retry_delay=datetime.timedelta(seconds=0),
            dag=dag,
            owner='airflow',
            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

        def run_with_error(ti):
            try:
                ti.run()
            except AirflowException:
                pass

        ti = TI(
            task=task, execution_date=timezone.utcnow())
        self.assertEqual(ti.try_number, 1)

        # first run -- up for retry
        run_with_error(ti)
        self.assertEqual(ti.state, State.UP_FOR_RETRY)
        self.assertEqual(ti._try_number, 1)
        self.assertEqual(ti.try_number, 2)

        # second run -- fail
        run_with_error(ti)
        self.assertEqual(ti.state, State.FAILED)
        self.assertEqual(ti._try_number, 2)
        self.assertEqual(ti.try_number, 3)

        # Clear the TI state since you can't run a task with a FAILED state without
        # clearing it first
        dag.clear()

        # third run -- up for retry
        run_with_error(ti)
        self.assertEqual(ti.state, State.UP_FOR_RETRY)
        self.assertEqual(ti._try_number, 3)
        self.assertEqual(ti.try_number, 4)

        # fourth run -- fail
        run_with_error(ti)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.FAILED)
        self.assertEqual(ti._try_number, 4)
        self.assertEqual(ti.try_number, 5)
Example #25
File: cli.py Project: TuneOSS/airflow
def task_state(args):
    """
    Returns the state of a TaskInstance at the command line.

    >>> airflow task_state tutorial sleep 2015-01-01
    success
    """
    dag = get_dag(args)
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    print(ti.current_state())
Example #26
 def test_operation_get_with_templates(self, _):
     dag_id = 'test_dag_id'
     configuration.load_test_config()
     args = {'start_date': DEFAULT_DATE}
     self.dag = DAG(dag_id, default_args=args)
     op = GcpTransferServiceOperationGetOperator(
         operation_name='{{ dag.dag_id }}', task_id='task-id', dag=self.dag
     )
     ti = TaskInstance(op, DEFAULT_DATE)
     ti.render_templates()
     self.assertEqual(dag_id, getattr(op, 'operation_name'))
Example #27
File: cli.py Project: TuneOSS/airflow
def render(args):
    dag = get_dag(args)
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.render_templates()
    for attr in task.__class__.template_fields:
        print(textwrap.dedent("""\
        # ----------------------------------------------------------
        # property: {}
        # ----------------------------------------------------------
        {}
        """.format(attr, getattr(task, attr))))
Example #28
 def delete_remote_resource(self):
     # check the remote file content
     remove_file_task = SSHOperator(
             task_id="test_check_file",
             ssh_hook=self.hook,
             command="rm {0}".format(self.test_remote_filepath),
             do_xcom_push=True,
             dag=self.dag
     )
     self.assertIsNotNone(remove_file_task)
     ti3 = TaskInstance(task=remove_file_task, execution_date=datetime.now())
     ti3.run()
Example #29
    def execute(self, context):
        # If the DAG Run is externally triggered, then return without
        # skipping downstream tasks
        if context['dag_run'] and context['dag_run'].external_trigger:
            logging.info("Externally triggered DAG_Run: "
                         "allowing execution to proceed.")
            return

        now = datetime.datetime.now()
        left_window = context['dag'].following_schedule(
            context['execution_date'])
        right_window = context['dag'].following_schedule(left_window)
        logging.info(
            'Checking latest only with left_window: %s right_window: %s '
            'now: %s', left_window, right_window, now)

        if not left_window < now <= right_window:
            logging.info('Not latest execution, skipping downstream.')
            session = settings.Session()

            TI = TaskInstance
            tis = session.query(TI).filter(
                TI.execution_date == context['ti'].execution_date,
                TI.task_id.in_(context['task'].downstream_task_ids)
            ).with_for_update().all()

            for ti in tis:
                logging.info('Skipping task: %s', ti.task_id)
                ti.state = State.SKIPPED
                ti.start_date = now
                ti.end_date = now
                session.merge(ti)

            # this is defensive against dag runs that are not complete
            for task in context['task'].downstream_list:
                if task.task_id in {ti.task_id for ti in tis}:
                    continue

                logging.warning("Task {} was not part of a dag run. "
                                "This should not happen."
                                .format(task))
                now = datetime.datetime.now()
                ti = TaskInstance(task, execution_date=context['ti'].execution_date)
                ti.state = State.SKIPPED
                ti.start_date = now
                ti.end_date = now
                session.merge(ti)

            session.commit()
            session.close()
            logging.info('Done.')
        else:
            logging.info('Latest, allowing execution to proceed.')
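To make the window arithmetic concrete, here is a small worked sketch (assuming a hypothetical @daily schedule) of how left_window and right_window bracket "now" only for the latest run:

import datetime

# Hypothetical @daily schedule: following_schedule() advances one day.
execution_date = datetime.datetime(2016, 1, 1)
left_window = execution_date + datetime.timedelta(days=1)   # following_schedule(execution_date)
right_window = left_window + datetime.timedelta(days=1)     # following_schedule(left_window)

# Only the run whose (left_window, right_window] interval contains "now"
# is the latest one; every earlier run skips its downstream tasks.
now = datetime.datetime(2016, 1, 2, 12, 0)
assert left_window < now <= right_window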
Example #30
    def test_file_task_handler(self):
        def task_callable(ti, **kwargs):
            ti.log.info("test")
        dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
        task = PythonOperator(
            task_id='task_for_testing_file_log_handler',
            dag=dag,
            python_callable=task_callable,
            provide_context=True
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

        logger = ti.log
        ti.log.disabled = False

        file_handler = next((handler for handler in logger.handlers
                             if handler.name == FILE_TASK_HANDLER), None)
        self.assertIsNotNone(file_handler)

        set_context(logger, ti)
        self.assertIsNotNone(file_handler.handler)
        # We expect set_context to generate a file locally.
        log_filename = file_handler.handler.baseFilename
        self.assertTrue(os.path.isfile(log_filename))
        self.assertTrue(log_filename.endswith("1.log"), log_filename)

        ti.run(ignore_ti_state=True)

        file_handler.flush()
        file_handler.close()

        self.assertTrue(hasattr(file_handler, 'read'))
        # Return value of read must be a tuple of list and list.
        logs, metadatas = file_handler.read(ti)
        self.assertTrue(isinstance(logs, list))
        self.assertTrue(isinstance(metadatas, list))
        self.assertEqual(len(logs), 1)
        self.assertEqual(len(logs), len(metadatas))
        self.assertTrue(isinstance(metadatas[0], dict))
        target_re = r'\n\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\n'

        # We should expect our log line from the callable above to appear in
        # the logs we read back
        six.assertRegex(
            self,
            logs[0],
            target_re,
            "Logs were " + str(logs)
        )

        # Remove the generated tmp log file.
        os.remove(log_filename)
Example #31
    def test_xcom_pull_different_execution_date(self):
        """
        tests xcom fetch behavior with different execution dates, using
        both xcom_pull with "include_prior_dates" and without
        """
        key = 'xcom_key'
        value = 'xcom_value'

        dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
        task = DummyOperator(task_id='test_xcom',
                             dag=dag,
                             pool='test_xcom',
                             owner='airflow',
                             start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
        exec_date = datetime.datetime.now()
        ti = TI(task=task, execution_date=exec_date)
        ti.run(mark_success=True)
        ti.xcom_push(key=key, value=value)
        self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
        ti.run()
        exec_date += datetime.timedelta(days=1)
        ti = TI(task=task, execution_date=exec_date)
        ti.run()
        # We have set a new execution date (and did not pass in
        # 'include_prior_dates'), which means this task should now have a
        # cleared xcom value
        self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
        # We *should* get a value using 'include_prior_dates'
        self.assertEqual(
            ti.xcom_pull(task_ids='test_xcom',
                         key=key,
                         include_prior_dates=True), value)
Example #32
# make sure there isn't already a file where this test will be writing
if os.path.exists(dir + "/" + filename):
    os.remove(dir + "/" + filename)

# init the dag for the tests
dag = DAG(dag_id='anydag', start_date=datetime.now())

# init the Operator and Task Instance
task = DownloadFileOperator(task_id="get_data",
                            file_url=file_url,
                            dir=dir,
                            filename=filename,
                            dag=dag)

ti = TaskInstance(task=task, execution_date=datetime.now())


def test_get_data_from_http_request():
    # checks if the operator makes a simple http get correctly
    assert task.get_data_from_http_request(ti).text == file_url_content


def test_execute():
    # checks all keys are present in operator output
    output = task.execute(ti.get_template_context())
    keys = output.keys()
    assert "path" in keys
    assert "last_modified" in keys
    assert "checksum" in keys
Example #33
def get_xcom_value(task_instance: TaskInstance):
    return task_instance.xcom_pull(task_ids=task_instance.task_id)
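A hedged usage sketch (the callable and its context wiring are assumptions, not part of the source): inside a running task, the helper pulls whatever was last pushed under that task's own task_id:

def print_own_xcom(**context):
    # context['ti'] is the live TaskInstance for the current try; the helper
    # pulls the value previously pushed under this task's own task_id.
    print(get_xcom_value(context['ti']))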
Example #34
def create_task_instance(
    task: BaseOperator,
    execution_date: pendulum.DateTime = DEFAULT_EXECUTION_DATE,
) -> TaskInstance:
    return TaskInstance(task=task, execution_date=execution_date)
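A minimal usage sketch (DummyOperator and the render call are assumptions about how this factory would be exercised):

task = DummyOperator(task_id='example_task')

# Build a TaskInstance at the default execution date and render its
# templated fields in place.
ti = create_task_instance(task)
ti.render_templates()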
Example #35
def dag_backfill(args, dag=None):
    """Creates backfill job or dry run for a DAG"""
    logging.basicConfig(level=settings.LOGGING_LEVEL,
                        format=settings.SIMPLE_LOG_FORMAT)

    signal.signal(signal.SIGTERM, sigint_handler)

    import warnings
    warnings.warn(
        '--ignore-first-depends-on-past is deprecated as the value is always set to True',
        category=PendingDeprecationWarning)

    if args.ignore_first_depends_on_past is False:
        args.ignore_first_depends_on_past = True

    dag = dag or get_dag(args.subdir, args.dag_id)

    if not args.start_date and not args.end_date:
        raise AirflowException("Provide a start_date and/or end_date")

    # If only one date is passed, use it for both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(task_regex=args.task_regex,
                          include_upstream=not args.ignore_dependencies)

    run_conf = None
    if args.conf:
        run_conf = json.loads(args.conf)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        if args.reset_dagruns:
            DAG.clear_dags(
                [dag],
                start_date=args.start_date,
                end_date=args.end_date,
                confirm_prompt=not args.yes,
                include_subdags=True,
            )

        dag.run(start_date=args.start_date,
                end_date=args.end_date,
                mark_success=args.mark_success,
                local=args.local,
                donot_pickle=(args.donot_pickle
                              or conf.getboolean('core', 'donot_pickle')),
                ignore_first_depends_on_past=args.ignore_first_depends_on_past,
                ignore_task_deps=args.ignore_dependencies,
                pool=args.pool,
                delay_on_limit_secs=args.delay_on_limit,
                verbose=args.verbose,
                conf=run_conf,
                rerun_failed_tasks=args.rerun_failed_tasks,
                run_backwards=args.run_backwards)
Example #36
 def _get_task_instance(self, state):
     dag = DAG('test_dag')
     task = Mock(dag=dag)
     ti = TaskInstance(task=task, state=state, execution_date=None)
     return ti
Example #37
    def test_update_counters(self):
        dag = DAG(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE)

        task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow')

        job = BackfillJob(dag=dag)

        session = settings.Session()
        dr = dag.create_dagrun(run_id=DagRunType.SCHEDULED.value,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE,
                               session=session)
        ti = TI(task1, dr.execution_date)
        ti.refresh_from_db()

        ti_status = BackfillJob._DagRunTaskStatus()

        # test for success
        ti.set_state(State.SUCCESS, session)
        ti_status.running[ti.key] = ti
        job._update_counters(ti_status=ti_status)
        self.assertTrue(len(ti_status.running) == 0)
        self.assertTrue(len(ti_status.succeeded) == 1)
        self.assertTrue(len(ti_status.skipped) == 0)
        self.assertTrue(len(ti_status.failed) == 0)
        self.assertTrue(len(ti_status.to_run) == 0)

        ti_status.succeeded.clear()

        # test for skipped
        ti.set_state(State.SKIPPED, session)
        ti_status.running[ti.key] = ti
        job._update_counters(ti_status=ti_status)
        self.assertTrue(len(ti_status.running) == 0)
        self.assertTrue(len(ti_status.succeeded) == 0)
        self.assertTrue(len(ti_status.skipped) == 1)
        self.assertTrue(len(ti_status.failed) == 0)
        self.assertTrue(len(ti_status.to_run) == 0)

        ti_status.skipped.clear()

        # test for failed
        ti.set_state(State.FAILED, session)
        ti_status.running[ti.key] = ti
        job._update_counters(ti_status=ti_status)
        self.assertTrue(len(ti_status.running) == 0)
        self.assertTrue(len(ti_status.succeeded) == 0)
        self.assertTrue(len(ti_status.skipped) == 0)
        self.assertTrue(len(ti_status.failed) == 1)
        self.assertTrue(len(ti_status.to_run) == 0)

        ti_status.failed.clear()

        # test for retry
        ti.set_state(State.UP_FOR_RETRY, session)
        ti_status.running[ti.key] = ti
        job._update_counters(ti_status=ti_status)
        self.assertTrue(len(ti_status.running) == 0)
        self.assertTrue(len(ti_status.succeeded) == 0)
        self.assertTrue(len(ti_status.skipped) == 0)
        self.assertTrue(len(ti_status.failed) == 0)
        self.assertTrue(len(ti_status.to_run) == 1)

        ti_status.to_run.clear()

        # test for reschedule
        ti.set_state(State.UP_FOR_RESCHEDULE, session)
        ti_status.running[ti.key] = ti
        job._update_counters(ti_status=ti_status)
        self.assertTrue(len(ti_status.running) == 0)
        self.assertTrue(len(ti_status.succeeded) == 0)
        self.assertTrue(len(ti_status.skipped) == 0)
        self.assertTrue(len(ti_status.failed) == 0)
        self.assertTrue(len(ti_status.to_run) == 1)

        ti_status.to_run.clear()

        # test for none
        ti.set_state(State.NONE, session)
        ti_status.running[ti.key] = ti
        job._update_counters(ti_status=ti_status)
        self.assertTrue(len(ti_status.running) == 0)
        self.assertTrue(len(ti_status.succeeded) == 0)
        self.assertTrue(len(ti_status.skipped) == 0)
        self.assertTrue(len(ti_status.failed) == 0)
        self.assertTrue(len(ti_status.to_run) == 1)

        ti_status.to_run.clear()

        session.close()
Example #38
def test_create_interval_metrics():
    ti = TaskInstance(task=create_interval_metrics, execution_date=datetime.now())
    result = create_interval_metrics.execute(ti.get_template_context())
    print("---")
    assert result == "succeeded"
Example #39
File: cli.py Project: zhjchen/airflow
def run(args):

    utils.pessimistic_connection_handling()
    # Setting up logging
    log = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
    directory = log + "/{args.dag_id}/{args.task_id}".format(args=args)
    if not os.path.exists(directory):
        os.makedirs(directory)
    args.execution_date = dateutil.parser.parse(args.execution_date)
    iso = args.execution_date.isoformat()
    filename = "{directory}/{iso}".format(**locals())

    # store old log (to help with S3 appends)
    if os.path.exists(filename):
        with open(filename, 'r') as logfile:
            old_log = logfile.read()
    else:
        old_log = None

    subdir = None
    if args.subdir:
        subdir = args.subdir.replace("DAGS_FOLDER",
                                     conf.get("core", "DAGS_FOLDER"))
        subdir = os.path.expanduser(subdir)
    logging.basicConfig(filename=filename,
                        level=settings.LOGGING_LEVEL,
                        format=settings.LOG_FORMAT)
    if not args.pickle:
        dagbag = DagBag(subdir)
        if args.dag_id not in dagbag.dags:
            msg = 'DAG [{0}] could not be found'.format(args.dag_id)
            logging.error(msg)
            raise AirflowException(msg)
        dag = dagbag.dags[args.dag_id]
        task = dag.get_task(task_id=args.task_id)
    else:
        session = settings.Session()
        logging.info('Loading pickle id {args.pickle}'.format(**locals()))
        dag_pickle = session.query(DagPickle).filter(
            DagPickle.id == args.pickle).first()
        if not dag_pickle:
            raise AirflowException("Who hid the pickle!? [missing pickle]")
        dag = dag_pickle.pickle
        task = dag.get_task(task_id=args.task_id)

    task_start_date = None
    if args.task_start_date:
        task_start_date = dateutil.parser.parse(args.task_start_date)
        task.start_date = task_start_date
    ti = TaskInstance(task, args.execution_date)

    if args.local:
        print("Logging into: " + filename)
        run_job = jobs.LocalTaskJob(
            task_instance=ti,
            mark_success=args.mark_success,
            force=args.force,
            pickle_id=args.pickle,
            task_start_date=task_start_date,
            ignore_dependencies=args.ignore_dependencies)
        run_job.run()
    elif args.raw:
        ti.run(
            mark_success=args.mark_success,
            force=args.force,
            ignore_dependencies=args.ignore_dependencies,
            job_id=args.job_id,
        )
    else:
        pickle_id = None
        if args.ship_dag:
            try:
                # Running remotely, so pickling the DAG
                session = settings.Session()
                pickle = DagPickle(dag)
                session.add(pickle)
                session.commit()
                pickle_id = pickle.id
                print(('Pickled dag {dag} '
                       'as pickle_id:{pickle_id}').format(**locals()))
            except Exception as e:
                print('Could not pickle the DAG')
                print(e)
                raise e

        executor = DEFAULT_EXECUTOR
        executor.start()
        print("Sending to executor.")
        executor.queue_task_instance(
            ti,
            mark_success=args.mark_success,
            pickle_id=pickle_id,
            ignore_dependencies=args.ignore_dependencies,
            force=args.force)
        executor.heartbeat()
        executor.end()

    if conf.get('core', 'S3_LOG_FOLDER').startswith('s3:'):
        import boto
        s3_log = filename.replace(log, conf.get('core', 'S3_LOG_FOLDER'))
        bucket, key = s3_log.lstrip('s3:/').split('/', 1)
        if os.path.exists(filename):

            # get logs
            with open(filename, 'r') as logfile:
                new_log = logfile.read()

            # remove old logs (since they are already in S3)
            if old_log:
                new_log = new_log.replace(old_log, '')

            try:
                s3 = boto.connect_s3()
                s3_key = boto.s3.key.Key(s3.get_bucket(bucket), key)

                # append new logs to old S3 logs, if available
                if s3_key.exists():
                    old_s3_log = s3_key.get_contents_as_string().decode()
                    new_log = old_s3_log + '\n' + new_log

                # send log to S3
                s3_key.set_contents_from_string(new_log)
            except Exception:
                print('Could not send logs to S3.')
Example #40
 def setUp(self):
     dag = DAG('dag_for_testing_filename_rendering',
               start_date=DEFAULT_DATE)
     task = DummyOperator(task_id='task_for_testing_filename_rendering',
                          dag=dag)
     self.ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
Example #41
 def _get_task_instance(self, execution_date, dag_end_date=None, task_end_date=None):
     dag = Mock(end_date=dag_end_date)
     task = Mock(dag=dag, end_date=task_end_date)
     return TaskInstance(task=task, execution_date=execution_date)
Example #42
    def test_next_retry_datetime(self):
        delay = datetime.timedelta(seconds=30)
        max_delay = datetime.timedelta(minutes=60)

        dag = models.DAG(dag_id='fail_dag')
        task = BashOperator(
            task_id='task_with_exp_backoff_and_max_delay',
            bash_command='exit 1',
            retries=3,
            retry_delay=delay,
            retry_exponential_backoff=True,
            max_retry_delay=max_delay,
            dag=dag,
            owner='airflow',
            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
        ti = TI(
            task=task, execution_date=DEFAULT_DATE)
        ti.end_date = pendulum.instance(timezone.utcnow())

        dt = ti.next_retry_datetime()
        # between 30 * 2^-1 and 30 * 2^0 (15 and 30)
        period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
        self.assertTrue(dt in period)

        ti.try_number = 3
        dt = ti.next_retry_datetime()
        # between 30 * 2^2 and 30 * 2^3 (120 and 240)
        period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
        self.assertTrue(dt in period)

        ti.try_number = 5
        dt = ti.next_retry_datetime()
        # between 30 * 2^4 and 30 * 2^5 (480 and 960)
        period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
        self.assertTrue(dt in period)

        ti.try_number = 9
        dt = ti.next_retry_datetime()
        self.assertEqual(dt, ti.end_date + max_delay)

        ti.try_number = 50
        dt = ti.next_retry_datetime()
        self.assertEqual(dt, ti.end_date + max_delay)
Example #43
 def test_set_duration_empty_dates(self):
     task = DummyOperator(task_id='op', email='*****@*****.**')
     ti = TI(task=task, execution_date=datetime.datetime.now())
     ti.set_duration()
     self.assertIsNone(ti.duration)
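Taken together with Example #22 above, the two tests imply the following behavior (an inference from the assertions, not Airflow's source):

# duration is the wall-clock difference in seconds, or None when either
# boundary date is missing.
if ti.start_date and ti.end_date:
    ti.duration = (ti.end_date - ti.start_date).total_seconds()
else:
    ti.duration = None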
Example #44
    def test_run_ignores_all_dependencies(self):
        """
        Test that run respects ignore_all_dependencies
        """
        dag_id = 'test_run_ignores_all_dependencies'

        dag = self.dagbag.get_dag('test_run_ignores_all_dependencies')
        dag.clear()

        task0_id = 'test_run_dependent_task'
        args0 = [
            'tasks', 'run', '--ignore-all-dependencies', dag_id, task0_id,
            DEFAULT_DATE.isoformat()
        ]
        task_command.task_run(self.parser.parse_args(args0))
        ti_dependent0 = TaskInstance(task=dag.get_task(task0_id),
                                     execution_date=DEFAULT_DATE)

        ti_dependent0.refresh_from_db()
        assert ti_dependent0.state == State.FAILED

        task1_id = 'test_run_dependency_task'
        args1 = [
            'tasks',
            'run',
            '--ignore-all-dependencies',
            dag_id,
            task1_id,
            (DEFAULT_DATE + timedelta(days=1)).isoformat(),
        ]
        task_command.task_run(self.parser.parse_args(args1))

        ti_dependency = TaskInstance(task=dag.get_task(task1_id),
                                     execution_date=DEFAULT_DATE +
                                     timedelta(days=1))
        ti_dependency.refresh_from_db()
        assert ti_dependency.state == State.FAILED

        task2_id = 'test_run_dependent_task'
        args2 = [
            'tasks',
            'run',
            '--ignore-all-dependencies',
            dag_id,
            task2_id,
            (DEFAULT_DATE + timedelta(days=1)).isoformat(),
        ]
        task_command.task_run(self.parser.parse_args(args2))

        ti_dependent = TaskInstance(task=dag.get_task(task2_id),
                                    execution_date=DEFAULT_DATE +
                                    timedelta(days=1))
        ti_dependent.refresh_from_db()
        assert ti_dependent.state == State.SUCCESS
Example #45
    def test_zombies_are_correctly_passed_to_dag_file_processor(self):
        """
        Check that the same set of zombies is passed to the dag
        file processors until the next zombie detection logic is invoked.
        """
        with conf_vars({
            ('scheduler', 'max_threads'): '1',
            ('core', 'load_examples'): 'False'
        }):
            dagbag = DagBag(
                os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py'))
            with create_session() as session:
                session.query(LJ).delete()
                dag = dagbag.get_dag('test_example_bash_operator')
                task = dag.get_task(task_id='run_this_last')

                ti = TI(task, DEFAULT_DATE, State.RUNNING)
                lj = LJ(ti)
                lj.state = State.SHUTDOWN
                lj.id = 1
                ti.job_id = lj.id

                session.add(lj)
                session.add(ti)
                session.commit()
                fake_zombies = [SimpleTaskInstance(ti)]

            class FakeDagFIleProcessor(DagFileProcessor):
                # This fake processor will return the zombies it received in its
                # constructor as its processing result, without actually parsing anything.
                def __init__(self, file_path, pickle_dags, dag_id_white_list,
                             zombies):
                    super(FakeDagFIleProcessor,
                          self).__init__(file_path, pickle_dags,
                                         dag_id_white_list, zombies)

                    self._result = zombies, 0

                def start(self):
                    pass

                @property
                def start_time(self):
                    return DEFAULT_DATE

                @property
                def pid(self):
                    return 1234

                @property
                def done(self):
                    return True

                @property
                def result(self):
                    return self._result

            def processor_factory(file_path, zombies):
                return FakeDagFIleProcessor(file_path, False, [], zombies)

            test_dag_path = os.path.join(TEST_DAG_FOLDER,
                                         'test_example_bash_operator.py')
            async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
            processor_agent = DagFileProcessorAgent(test_dag_path, [], 1,
                                                    processor_factory,
                                                    timedelta.max, async_mode)
            processor_agent.start()
            parsing_result = []
            if not async_mode:
                processor_agent.heartbeat()
            while not processor_agent.done:
                if not async_mode:
                    processor_agent.wait_until_finished()
                parsing_result.extend(processor_agent.harvest_simple_dags())

            self.assertEqual(len(fake_zombies), len(parsing_result))
            self.assertEqual(set([zombie.key for zombie in fake_zombies]),
                             set([result.key for result in parsing_result]))
Example #46
    def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor(
            self):
        """
        Check that the same set of failure callback requests for zombies is
        passed to the dag file processors until the next zombie detection
        logic is invoked.
        """
        test_dag_path = os.path.join(TEST_DAG_FOLDER,
                                     'test_example_bash_operator.py')
        with conf_vars({
            ('scheduler', 'parsing_processes'): '1',
            ('core', 'load_examples'): 'False'
        }):
            dagbag = DagBag(test_dag_path, read_dags_from_db=False)
            with create_session() as session:
                session.query(LJ).delete()
                dag = dagbag.get_dag('test_example_bash_operator')
                dag.sync_to_db()
                task = dag.get_task(task_id='run_this_last')

                ti = TI(task, DEFAULT_DATE, State.RUNNING)
                local_job = LJ(ti)
                local_job.state = State.SHUTDOWN
                session.add(local_job)
                session.commit()

                # TODO: If there was an actual Relationship between TI and Job
                # we wouldn't need this extra commit
                session.add(ti)
                ti.job_id = local_job.id
                session.commit()

                expected_failure_callback_requests = [
                    TaskCallbackRequest(
                        full_filepath=dag.full_filepath,
                        simple_task_instance=SimpleTaskInstance(ti),
                        msg="Message",
                    )
                ]

            test_dag_path = os.path.join(TEST_DAG_FOLDER,
                                         'test_example_bash_operator.py')

            child_pipe, parent_pipe = multiprocessing.Pipe()
            async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')

            fake_processors = []

            def fake_processor_factory(*args, **kwargs):
                nonlocal fake_processors
                processor = FakeDagFileProcessorRunner._fake_dag_processor_factory(
                    *args, **kwargs)
                fake_processors.append(processor)
                return processor

            manager = DagFileProcessorManager(
                dag_directory=test_dag_path,
                max_runs=1,
                processor_factory=fake_processor_factory,
                processor_timeout=timedelta.max,
                signal_conn=child_pipe,
                dag_ids=[],
                pickle_dags=False,
                async_mode=async_mode,
            )

            self.run_processor_manager_one_loop(manager, parent_pipe)

            if async_mode:
                # Once for initial parse, and then again for the add_callback_to_queue
                assert len(fake_processors) == 2
                assert fake_processors[0]._file_path == test_dag_path
                assert fake_processors[0]._callback_requests == []
            else:
                assert len(fake_processors) == 1

            assert fake_processors[-1]._file_path == test_dag_path
            callback_requests = fake_processors[-1]._callback_requests
            assert {
                zombie.simple_task_instance.key
                for zombie in expected_failure_callback_requests
            } == {
                result.simple_task_instance.key
                for result in callback_requests
            }

            child_pipe.close()
            parent_pipe.close()
Example #47
File: cli.py Project: rijuk/airflow
def run(args):

    utils.pessimistic_connection_handling()

    # Setting up logging
    log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
    directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
    if not os.path.exists(directory):
        os.makedirs(directory)
    args.execution_date = dateutil.parser.parse(args.execution_date)
    iso = args.execution_date.isoformat()
    filename = "{directory}/{iso}".format(**locals())

    subdir = process_subdir(args.subdir)
    logging.root.handlers = []
    logging.basicConfig(filename=filename,
                        level=settings.LOGGING_LEVEL,
                        format=settings.LOG_FORMAT)

    if not args.pickle:
        dagbag = DagBag(subdir)
        if args.dag_id not in dagbag.dags:
            msg = 'DAG [{0}] could not be found in {1}'.format(
                args.dag_id, subdir)
            logging.error(msg)
            raise AirflowException(msg)
        dag = dagbag.dags[args.dag_id]
        task = dag.get_task(task_id=args.task_id)
    else:
        session = settings.Session()
        logging.info('Loading pickle id {args.pickle}'.format(**locals()))
        dag_pickle = session.query(DagPickle).filter(
            DagPickle.id == args.pickle).first()
        if not dag_pickle:
            raise AirflowException("Who hid the pickle!? [missing pickle]")
        dag = dag_pickle.pickle
        task = dag.get_task(task_id=args.task_id)

    task_start_date = None
    if args.task_start_date:
        task_start_date = dateutil.parser.parse(args.task_start_date)
        task.start_date = task_start_date
    ti = TaskInstance(task, args.execution_date)

    if args.local:
        print("Logging into: " + filename)
        run_job = jobs.LocalTaskJob(
            task_instance=ti,
            mark_success=args.mark_success,
            force=args.force,
            pickle_id=args.pickle,
            task_start_date=task_start_date,
            ignore_dependencies=args.ignore_dependencies,
            pool=args.pool)
        run_job.run()
    elif args.raw:
        ti.run(
            mark_success=args.mark_success,
            force=args.force,
            ignore_dependencies=args.ignore_dependencies,
            job_id=args.job_id,
            pool=args.pool,
        )
    else:
        pickle_id = None
        if args.ship_dag:
            try:
                # Running remotely, so pickling the DAG
                session = settings.Session()
                pickle = DagPickle(dag)
                session.add(pickle)
                session.commit()
                pickle_id = pickle.id
                print(('Pickled dag {dag} '
                       'as pickle_id:{pickle_id}').format(**locals()))
            except Exception as e:
                print('Could not pickle the DAG')
                print(e)
                raise e

        executor = DEFAULT_EXECUTOR
        executor.start()
        print("Sending to executor.")
        executor.queue_task_instance(
            ti,
            mark_success=args.mark_success,
            pickle_id=pickle_id,
            ignore_dependencies=args.ignore_dependencies,
            force=args.force,
            pool=args.pool)
        executor.heartbeat()
        executor.end()

    # store logs remotely
    remote_base = conf.get('core', 'REMOTE_BASE_LOG_FOLDER')

    # deprecated as of March 2016
    if not remote_base and conf.get('core', 'S3_LOG_FOLDER'):
        warnings.warn(
            'The S3_LOG_FOLDER conf key has been replaced by '
            'REMOTE_BASE_LOG_FOLDER. Your conf still works but please '
            'update airflow.cfg to ensure future compatibility.',
            DeprecationWarning)
        remote_base = conf.get('core', 'S3_LOG_FOLDER')

    if os.path.exists(filename):
        # read log and remove old logs to get just the latest additions

        with open(filename, 'r') as logfile:
            log = logfile.read()

        remote_log_location = filename.replace(log_base, remote_base)
        # S3

        if remote_base.startswith('s3:/'):
            utils.S3Log().write(log, remote_log_location)
        # GCS
        elif remote_base.startswith('gs:/'):
            utils.GCSLog().write(log, remote_log_location, append=True)
        # Other
        elif remote_base:
            logging.error(
                'Unsupported remote log location: {}'.format(remote_base))
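The local-to-remote mapping above is a plain string replacement of the log base prefix. A minimal standalone sketch of that mapping (all paths hypothetical):

import os

# Sketch of the remote_log_location computation used above: the local log
# base prefix is swapped for the remote base, keeping the per-task suffix.
log_base = os.path.expanduser("~/airflow/logs")      # hypothetical local base
remote_base = "s3://my-bucket/airflow/logs"          # hypothetical remote base
filename = log_base + "/my_dag/my_task/2016-01-01T00:00:00"

remote_log_location = filename.replace(log_base, remote_base)
assert remote_log_location == "s3://my-bucket/airflow/logs/my_dag/my_task/2016-01-01T00:00:00"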
Example #48
0
    def test_get_states_count_upstream_ti(self):
        """
        Test the helper function '_get_states_count_upstream_ti' both as a unit and inside update_state.
        """
        from airflow.ti_deps.dep_context import DepContext

        get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti
        session = settings.Session()
        now = timezone.utcnow()
        dag = DAG(
            'test_dagrun_with_pre_tis',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op3 = DummyOperator(task_id='C')
            op4 = DummyOperator(task_id='D')
            op5 = DummyOperator(task_id='E', trigger_rule=TriggerRule.ONE_FAILED)

            op1.set_downstream([op2, op3])  # op1 >> op2, op3
            op4.set_upstream([op3, op2])  # op3, op2 >> op4
            op5.set_upstream([op2, op3, op4])  # (op2, op3, op4) >> op5

        clear_db_runs()
        dag.clear()
        dr = dag.create_dagrun(run_id='test_dagrun_with_pre_tis',
                               state=State.RUNNING,
                               execution_date=now,
                               start_date=now)

        ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date)
        ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date)
        ti_op3 = TaskInstance(task=dag.get_task(op3.task_id), execution_date=dr.execution_date)
        ti_op4 = TaskInstance(task=dag.get_task(op4.task_id), execution_date=dr.execution_date)
        ti_op5 = TaskInstance(task=dag.get_task(op5.task_id), execution_date=dr.execution_date)

        ti_op1.set_state(state=State.SUCCESS, session=session)
        ti_op2.set_state(state=State.FAILED, session=session)
        ti_op3.set_state(state=State.SUCCESS, session=session)
        ti_op4.set_state(state=State.SUCCESS, session=session)
        ti_op5.set_state(state=State.SUCCESS, session=session)

        session.commit()

        # Check the case where tasks are triggered from a backfill with no finished tasks
        finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
        self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2),
                         (1, 0, 0, 0, 1))
        finished_tasks = dr.get_task_instances(state=State.finished() + [State.UPSTREAM_FAILED],
                                               session=session)
        self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4),
                         (1, 0, 1, 0, 2))
        self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5),
                         (2, 0, 1, 0, 3))

        dr.update_state()
        self.assertEqual(State.SUCCESS, dr.state)
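For reference, a standalone sketch of the five-element count tuples asserted above, under the assumption (inferred from the assertions, not quoted from the implementation) that the layout is (successes, skipped, failed, upstream_failed, done) over the task's finished upstream instances:

# Reproduce the (1, 0, 1, 0, 2) tuple for ti_op4, whose upstream tasks
# op2 and op3 finished as FAILED and SUCCESS respectively.
upstream_states = ["failed", "success"]

counts = (
    upstream_states.count("success"),
    upstream_states.count("skipped"),
    upstream_states.count("failed"),
    upstream_states.count("upstream_failed"),
    len(upstream_states),  # 'done': all finished upstream instances
)
assert counts == (1, 0, 1, 0, 2)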
Example #49
0
def run_task(task_instance: TaskInstance) -> State:
    task_instance._run_raw_task(test_mode=True)
    return task_instance.state
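A hypothetical usage of the helper above; `dag`, `DEFAULT_DATE`, and the task id are stand-ins for whatever the surrounding test module defines:

from airflow.models import TaskInstance
from airflow.utils.state import State

# Run a single task of an already-constructed DAG in test mode and check
# that it reached SUCCESS (task id and date are placeholders).
ti = TaskInstance(task=dag.get_task("my_task"), execution_date=DEFAULT_DATE)
assert run_task(ti) == State.SUCCESS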
Example #50
0
def run(args, dag=None):
    db_utils.pessimistic_connection_handling()
    if dag:
        args.dag_id = dag.dag_id

    # Setting up logging
    log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
    directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
    if not os.path.exists(directory):
        os.makedirs(directory)
    iso = args.execution_date.isoformat()
    filename = "{directory}/{iso}".format(**locals())

    logging.root.handlers = []
    logging.basicConfig(
        filename=filename,
        level=settings.LOGGING_LEVEL,
        format=settings.LOG_FORMAT)

    if not args.pickle and not dag:
        dag = get_dag(args)
    elif not dag:
        session = settings.Session()
        logging.info('Loading pickle id {args.pickle}'.format(**locals()))
        dag_pickle = session.query(
            DagPickle).filter(DagPickle.id == args.pickle).first()
        if not dag_pickle:
            raise AirflowException("Who hid the pickle!? [missing pickle]")
        dag = dag_pickle.pickle
    task = dag.get_task(task_id=args.task_id)

    ti = TaskInstance(task, args.execution_date)

    if args.local:
        print("Logging into: " + filename)
        run_job = jobs.LocalTaskJob(
            task_instance=ti,
            mark_success=args.mark_success,
            pickle_id=args.pickle,
            ignore_all_deps=args.ignore_all_dependencies,
            ignore_depends_on_past=args.ignore_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            ignore_ti_state=args.force,
            pool=args.pool)
        run_job.run()
    elif args.raw:
        ti.run(
            mark_success=args.mark_success,
            ignore_all_deps=args.ignore_all_dependencies,
            ignore_depends_on_past=args.ignore_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            ignore_ti_state=args.force,
            job_id=args.job_id,
            pool=args.pool,
        )
    else:
        pickle_id = None
        if args.ship_dag:
            try:
                # Running remotely, so pickling the DAG
                session = settings.Session()
                pickle = DagPickle(dag)
                session.add(pickle)
                session.commit()
                pickle_id = pickle.id
                print((
                    'Pickled dag {dag} '
                    'as pickle_id:{pickle_id}').format(**locals()))
            except Exception as e:
                print('Could not pickle the DAG')
                print(e)
                raise e

        executor = DEFAULT_EXECUTOR
        executor.start()
        print("Sending to executor.")
        executor.queue_task_instance(
            ti,
            mark_success=args.mark_success,
            pickle_id=pickle_id,
            ignore_all_deps=args.ignore_all_dependencies,
            ignore_depends_on_past=args.ignore_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            ignore_ti_state=args.force,
            pool=args.pool)
        executor.heartbeat()
        executor.end()

    # Force the log to flush, and set the handler to go back to normal so we
    # don't continue logging to the task's log file. The flush is important
    # because we subsequently read from the log to insert into S3 or Google
    # cloud storage.
    logging.root.handlers[0].flush()
    logging.root.handlers = []

    # store logs remotely
    remote_base = conf.get('core', 'REMOTE_BASE_LOG_FOLDER')

    # deprecated as of March 2016
    if not remote_base and conf.get('core', 'S3_LOG_FOLDER'):
        warnings.warn(
            'The S3_LOG_FOLDER conf key has been replaced by '
            'REMOTE_BASE_LOG_FOLDER. Your conf still works but please '
            'update airflow.cfg to ensure future compatibility.',
            DeprecationWarning)
        remote_base = conf.get('core', 'S3_LOG_FOLDER')

    if os.path.exists(filename):
        # read log and remove old logs to get just the latest additions

        with open(filename, 'r') as logfile:
            log = logfile.read()

        remote_log_location = filename.replace(log_base, remote_base)
        # S3
        if remote_base.startswith('s3:/'):
            logging_utils.S3Log().write(log, remote_log_location)
        # GCS
        elif remote_base.startswith('gs:/'):
            logging_utils.GCSLog().write(
                log,
                remote_log_location,
                append=True)
        # Other
        elif remote_base and remote_base != 'None':
            logging.error(
                'Unsupported remote log location: {}'.format(remote_base))
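The scheme check above dispatches on the URI prefix of the configured remote base. A minimal standalone sketch of that dispatch (function name hypothetical):

def pick_remote_log_writer(remote_base):
    # Mirror the branching above: S3 first, then GCS, else unsupported.
    if remote_base.startswith("s3:/"):
        return "s3"
    if remote_base.startswith("gs:/"):
        return "gcs"
    return None

assert pick_remote_log_writer("s3://bucket/logs") == "s3"
assert pick_remote_log_writer("gs://bucket/logs") == "gcs"
assert pick_remote_log_writer("/local/only") is None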
Example #51
0
def run_dag_tasks(dag):
    for task in dag.tasks:
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)
        assert ti.state == State.SUCCESS
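Note that dag.tasks is not guaranteed to respect dependency order. A variant of the helper above that does (a sketch, assuming the 1.10-era DAG.topological_sort() API and the same TaskInstance/State/DEFAULT_DATE names from the surrounding module):

def run_dag_tasks_in_order(dag):
    # topological_sort() yields tasks with upstream tasks first, so each
    # run sees its dependencies already finished.
    for task in dag.topological_sort():
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.run(ignore_ti_state=True)
        assert ti.state == State.SUCCESS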
Example #52
0
    def test_xcom_pull(self):
        """
        Test xcom_pull, using different filtering methods.
        """
        dag = models.DAG(
            dag_id='test_xcom', schedule_interval='@monthly',
            start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))

        exec_date = timezone.utcnow()

        # Push a value
        task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow')
        ti1 = TI(task=task1, execution_date=exec_date)
        ti1.xcom_push(key='foo', value='bar')

        # Push another value with the same key (but by a different task)
        task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow')
        ti2 = TI(task=task2, execution_date=exec_date)
        ti2.xcom_push(key='foo', value='baz')

        # Pull with no arguments
        result = ti1.xcom_pull()
        self.assertEqual(result, None)
        # Pull the value pushed most recently by any task.
        result = ti1.xcom_pull(key='foo')
        self.assertEqual(result, 'baz')
        # Pull the value pushed by the first task
        result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
        self.assertEqual(result, 'bar')
        # Pull the value pushed by the second task
        result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
        self.assertEqual(result, 'baz')
        # Pull the values pushed by both tasks
        result = ti1.xcom_pull(
            task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
        self.assertEqual(result, ('bar', 'baz'))
Example #53
0
    def test_dag_clear(self):
        dag = DAG('test_dag_clear',
                  start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
        task0 = DummyOperator(task_id='test_dag_clear_task_0',
                              owner='test',
                              dag=dag)
        ti0 = TI(task=task0, execution_date=DEFAULT_DATE)

        dag.create_dagrun(
            execution_date=ti0.execution_date,
            state=State.RUNNING,
            run_type=DagRunType.SCHEDULED,
        )

        # The next run will be try 1
        self.assertEqual(ti0.try_number, 1)
        ti0.run()
        self.assertEqual(ti0.try_number, 2)
        dag.clear()
        ti0.refresh_from_db()
        self.assertEqual(ti0.try_number, 2)
        self.assertEqual(ti0.state, State.NONE)
        self.assertEqual(ti0.max_tries, 1)

        task1 = DummyOperator(task_id='test_dag_clear_task_1',
                              owner='test',
                              dag=dag,
                              retries=2)
        ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
        self.assertEqual(ti1.max_tries, 2)
        ti1.try_number = 1
        # Next try will be 2
        ti1.run()
        self.assertEqual(ti1.try_number, 3)
        self.assertEqual(ti1.max_tries, 2)

        dag.clear()
        ti0.refresh_from_db()
        ti1.refresh_from_db()
        # after clearing the dag, ti1 should show attempt 3 of 5
        self.assertEqual(ti1.max_tries, 4)
        self.assertEqual(ti1.try_number, 3)
        # after clearing the dag, ti0 should show attempt 2 of 2
        self.assertEqual(ti0.try_number, 2)
        self.assertEqual(ti0.max_tries, 1)
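The try_number / max_tries arithmetic asserted above is consistent with clear() resetting max_tries to try_number + retries - 1 (an inference from the assertions, not a quote of the implementation); total allowed attempts are max_tries + 1:

def max_tries_after_clear(try_number, retries):
    # Restore the full retry budget on top of the attempts already used.
    return try_number + retries - 1

assert max_tries_after_clear(3, 2) == 4  # ti1: attempt 3 of 5
assert max_tries_after_clear(2, 0) == 1  # ti0: attempt 2 of 2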
Example #54
0
    def test_reschedule_handling(self, mock_pool_full):
        """
        Test that task reschedules are handled properly
        """
        # Return values of the python sensor callable, modified during tests
        done = False
        fail = False

        def callable():
            if fail:
                raise AirflowException()
            return done

        dag = models.DAG(dag_id='test_reschedule_handling')
        task = PythonSensor(
            task_id='test_reschedule_handling_sensor',
            poke_interval=0,
            mode='reschedule',
            python_callable=callable,
            retries=1,
            retry_delay=datetime.timedelta(seconds=0),
            dag=dag,
            owner='airflow',
            pool='test_pool',
            start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

        ti = TI(task=task, execution_date=timezone.utcnow())
        self.assertEqual(ti._try_number, 0)
        self.assertEqual(ti.try_number, 1)

        def run_ti_and_assert(run_date, expected_start_date, expected_end_date,
                              expected_duration,
                              expected_state, expected_try_number,
                              expected_task_reschedule_count):
            with freeze_time(run_date):
                try:
                    ti.run()
                except AirflowException:
                    if not fail:
                        raise
            ti.refresh_from_db()
            self.assertEqual(ti.state, expected_state)
            self.assertEqual(ti._try_number, expected_try_number)
            self.assertEqual(ti.try_number, expected_try_number + 1)
            self.assertEqual(ti.start_date, expected_start_date)
            self.assertEqual(ti.end_date, expected_end_date)
            self.assertEqual(ti.duration, expected_duration)
            trs = TaskReschedule.find_for_task_instance(ti)
            self.assertEqual(len(trs), expected_task_reschedule_count)

        date1 = timezone.utcnow()
        date2 = date1 + datetime.timedelta(minutes=1)
        date3 = date2 + datetime.timedelta(minutes=1)
        date4 = date3 + datetime.timedelta(minutes=1)

        # Run with multiple reschedules.
        # During reschedule the try number remains the same, but each reschedule is recorded.
        # The start date is expected to remain the initial date, hence the duration increases.
        # When finished the try number is incremented and there is no reschedule expected
        # for this try.

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

        done, fail = False, False
        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)

        done, fail = False, False
        run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)

        done, fail = True, False
        run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)

        # Clear the task instance.
        dag.clear()
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.NONE)
        self.assertEqual(ti._try_number, 1)

        # Run again after clearing with reschedules and a retry.
        # The retry increments the try number, and for that try no reschedule is expected.
        # After the retry the start date is reset, hence the duration is also reset.

        done, fail = False, False
        run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)

        done, fail = False, True
        run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)

        done, fail = False, False
        run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)

        done, fail = True, False
        run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
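A small worked check of the duration bookkeeping asserted above: during reschedules the start date is pinned to the first poke, so the duration grows by the poke interval each time:

import datetime

date1 = datetime.datetime(2016, 2, 1)
date2 = date1 + datetime.timedelta(minutes=1)
date3 = date2 + datetime.timedelta(minutes=1)

# Matches the expected_duration arguments 0, 60, 120 passed above.
assert (date1 - date1).total_seconds() == 0
assert (date2 - date1).total_seconds() == 60
assert (date3 - date1).total_seconds() == 120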
Example #55
0
    for (word, count) in output:
        print("%s: %i" % (word, count))

    return get_spark_session().createDataFrame(counts)


class WordCountPySparkTask(PySparkTask):
    text = parameter.data
    counters = parameter.output

    python_script = relative_path(__file__, "spark_scripts/word_count.py")

    def application_args(self):
        return [self.text, self.counters]


with DAG(dag_id="dbnd_dag_with_spark", default_args=default_args) as dag_spark:
    # noinspection PyTypeChecker
    spark_task = WordCountPySparkTask(text="s3://dbnd/README.md")
    spark_op = spark_task.op
    # spark_result = word_count_inline("/tmp/sample.txt")
    # spark_op = spark_result.op

if __name__ == "__main__":
    ti = TaskInstance(spark_op, days_ago(0))
    ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
    # # #
    #
    # dag_spark.clear()
    # dag_spark.run(start_date=days_ago(0), end_date=days_ago(0))
Example #56
0
    def test_xcom_pull_after_success(self):
        """
        tests xcom set/clear relative to a task in a 'success' rerun scenario
        """
        key = 'xcom_key'
        value = 'xcom_value'

        dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
        task = DummyOperator(task_id='test_xcom',
                             dag=dag,
                             pool='test_xcom',
                             owner='airflow',
                             start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
        exec_date = datetime.datetime.now()
        ti = TI(task=task, execution_date=exec_date)
        ti.run(mark_success=True)
        ti.xcom_push(key=key, value=value)
        self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
        ti.run()
        # The second run and assert is to handle AIRFLOW-131 (don't clear on
        # prior success)
        self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

        # Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't
        # execute, even if dependencies are ignored
        ti.run(ignore_all_deps=True, mark_success=True)
        self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
        # Xcom IS finally cleared once task has executed
        ti.run(ignore_all_deps=True)
        self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
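A toy model of the clearing rule this test pins down (an assumption-level sketch, not Airflow internals): a prior XCom survives any run that skips execution, and is cleared only when the task actually executes:

def simulate_run(xcom, executes):
    # A prior XCom is cleared only when the task really runs.
    if executes:
        xcom.clear()

xcom = {"xcom_key": "xcom_value"}
simulate_run(xcom, executes=False)  # mark_success=True: no execution
assert xcom == {"xcom_key": "xcom_value"}
simulate_run(xcom, executes=True)   # real execution clears the old value
assert xcom == {}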
Example #57
0
    def test_extra_serialized_field_and_multiple_operator_links(self):
        """
        Assert that the extra serialized field exists and that operator links defined in plugins work alongside the operator's inbuilt links.

        This tests also depends on GoogleLink() registered as a plugin
        in tests/plugins/test_plugin.py

        The test verifies that when extra operator links are registered via a
        plugin's ``operator_extra_links`` and the same links are also defined
        on the operator in ``BaseOperator.operator_extra_links``, the correct
        extra links are produced.
        """
        test_date = datetime(2019, 8, 1)
        dag = DAG(dag_id='simple_dag', start_date=test_date)
        CustomOperator(task_id='simple_task',
                       dag=dag,
                       bash_command=["echo", "true"])

        serialized_dag = SerializedDAG.to_dict(dag)
        self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0])

        dag = SerializedDAG.from_dict(serialized_dag)
        simple_task = dag.task_dict["simple_task"]
        self.assertEqual(getattr(simple_task, "bash_command"),
                         ["echo", "true"])

        #########################################################
        # Verify Operator Links work with Serialized Operator
        #########################################################
        # Check Serialized version of operator link only contains the inbuilt Op Link
        self.assertEqual(
            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], [
                {
                    'airflow.utils.tests.CustomBaseIndexOpLink': {
                        'index': 0
                    }
                },
                {
                    'airflow.utils.tests.CustomBaseIndexOpLink': {
                        'index': 1
                    }
                },
            ])

        # Test all the extra_links are set
        six.assertCountEqual(self, simple_task.extra_links, [
            'BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github',
            'google'
        ])

        ti = TaskInstance(task=simple_task, execution_date=test_date)
        ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"])

        # Test Deserialized inbuilt link #1
        custom_inbuilt_link = simple_task.get_extra_links(
            test_date, "BigQuery Console #1")
        self.assertEqual(
            'https://console.cloud.google.com/bigquery?j=dummy_value_1',
            custom_inbuilt_link)

        # Test Deserialized inbuilt link #2
        custom_inbuilt_link = simple_task.get_extra_links(
            test_date, "BigQuery Console #2")
        self.assertEqual(
            'https://console.cloud.google.com/bigquery?j=dummy_value_2',
            custom_inbuilt_link)

        # Test Deserialized link registered via Airflow Plugin
        google_link_from_plugin = simple_task.get_extra_links(
            test_date, GoogleLink.name)
        self.assertEqual("https://www.google.com", google_link_from_plugin)
Example #58
0
    def test_subdag_clear_parentdag_downstream_clear(self):
        dag = self.dagbag.get_dag('clear_subdag_test_dag')
        subdag_op_task = dag.get_task('daily_job')

        subdag = subdag_op_task.subdag

        executor = MockExecutor()
        job = BackfillJob(dag=dag,
                          start_date=DEFAULT_DATE,
                          end_date=DEFAULT_DATE,
                          executor=executor,
                          donot_pickle=True)

        with timeout(seconds=30):
            job.run()

        ti_subdag = TI(task=dag.get_task('daily_job'),
                       execution_date=DEFAULT_DATE)
        ti_subdag.refresh_from_db()
        self.assertEqual(ti_subdag.state, State.SUCCESS)

        ti_irrelevant = TI(task=dag.get_task('daily_job_irrelevant'),
                           execution_date=DEFAULT_DATE)
        ti_irrelevant.refresh_from_db()
        self.assertEqual(ti_irrelevant.state, State.SUCCESS)

        ti_downstream = TI(task=dag.get_task('daily_job_downstream'),
                           execution_date=DEFAULT_DATE)
        ti_downstream.refresh_from_db()
        self.assertEqual(ti_downstream.state, State.SUCCESS)

        sdag = subdag.sub_dag(task_regex='daily_job_subdag_task',
                              include_downstream=True,
                              include_upstream=False)

        sdag.clear(start_date=DEFAULT_DATE,
                   end_date=DEFAULT_DATE,
                   include_parentdag=True)

        ti_subdag.refresh_from_db()
        self.assertEqual(State.NONE, ti_subdag.state)

        ti_irrelevant.refresh_from_db()
        self.assertEqual(State.SUCCESS, ti_irrelevant.state)

        ti_downstream.refresh_from_db()
        self.assertEqual(State.NONE, ti_downstream.state)

        subdag.clear()
        dag.clear()
Example #59
0
    def get_link(self, operator, dttm):
        ti = TaskInstance(task=operator, execution_date=dttm)
        search_query = ti.xcom_pull(task_ids=operator.task_id,
                                    key='search_query')
        return 'http://google.com/custom_base_link?search={}'.format(
            search_query)
Example #60
0
def test_infer_predictions():
    ti = TaskInstance(task=infer_predictions, execution_date=datetime.now())
    result = infer_predictions.execute(ti.get_template_context())
    assert result == "succeeded"
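A self-contained variant of the same execute-with-template-context pattern, using a stub operator (hypothetical; the original infer_predictions operator is defined elsewhere) and an empty context:

from datetime import datetime

from airflow.operators.python_operator import PythonOperator

stub = PythonOperator(
    task_id="infer_predictions_stub",
    python_callable=lambda: "succeeded",
    start_date=datetime(2020, 1, 1),
)
# PythonOperator.execute() just invokes the callable, so an empty template
# context is enough for a unit-style check.
assert stub.execute(context={}) == "succeeded"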